由 Evgeni Burovski、Ralf Gommers 和 Mario Lezcano 撰写

Quansight 工程师已在 PyTorch 2.1 中实现了通过 torch.compile 追踪 NumPy 代码的支持。此功能利用 PyTorch 的编译器生成高效的融合向量化代码,无需修改您的原始 NumPy 代码。更重要的是,它还允许通过在 torch.device("cuda") 下运行 torch.compile 来在 CUDA 上执行 NumPy 代码!

在本文中,我们将介绍如何使用此功能,并提供一些技巧和窍门,帮助您充分利用它。

将 NumPy 代码编译成并行 C++

我们以 K-Means 算法中的一步作为我们的运行示例。这段代码是从这本 NumPy 书中借用的

import numpy as np

def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

我们创建了一个包含 2000 万个随机 2-D 点的合成数据集。我们可以看到,如果选择合适的均值,该函数将为所有这些点返回正确的簇

npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = kmeans(X, means)

在 AMD 3970X CPU 上基准测试此函数,我们得到 1.26 秒的基线。

现在编译此函数就像用 torch.compile 包装它并用示例输入执行它一样简单

import torch

compiled_fn = torch.compile(kmeans)
compiled_pred = compiled_fn(X, means)
assert np.allclose(np_pred, compiled_pred)

编译后的函数在单核上运行时速度提升 9 倍。更棒的是,与 NumPy 相比,我们生成的代码确实利用了处理器的所有核心。因此,当我们将其在 32 核上运行时,可以获得 57 倍的速度提升。请注意,PyTorch 默认会使用所有可用核心,除非明确限制,所以使用 torch.compile 时这是默认行为。

我们可以通过运行带有环境变量 TORCH_LOGS=output_code 的脚本来检查生成的 C++代码。这样做时,我们可以看到 torch.compile 能够将广播和两个缩减操作编译成一个 for 循环,并使用 OpenMP 进行并行化。

extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity

将 NumPy 代码编译成 CUDA

将我们的代码编译成在 CUDA 上运行,只需将默认设备设置为 CUDA 即可。

with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)

通过检查通过 TORCH_LOGS=output_code 生成的代码,我们发现, torch.compile 并不是直接生成 CUDA 代码,而是生成相对易读的 triton 代码

def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    // Rest of the kernel omitted for brevity

在 RTX 2060 上运行这个小片段,比原始 NumPy 代码快 8 倍。这已经很不错了,但考虑到我们在 CPU 上看到的加速效果,这并不特别令人印象深刻。让我们看看如何通过一些小的改动来最大限度地发挥我们的 GPU 性能。

float64float32 。许多 GPU,尤其是消费级 GPU,在运行 float64 上的操作时相当缓慢。因此,将数据生成改为 float32 ,原始 NumPy 代码只是稍微快一点,大约快 9%,但我们的 CUDA 代码快了 40%,比纯 NumPy 代码快 11 倍。

torch.compile 默认情况下,遵循 NumPy 语义,因此它使用 np.float64 作为其所有创建操作的默认 dtype。正如所讨论的,这可能会影响性能,因此可以通过设置来更改此默认值

from torch._dynamo import config
config.numpy_default_float = "float32"

CPU <> CUDA 复制。11 倍的速度提升是不错的,但与 CPU 的数字相比还远远不够。这是由于 torch.compile 在幕后进行的一个小转换造成的。上面的代码接收 NumPy 数组并返回 NumPy 数组。所有这些数组都在 CPU 上,但计算是在 GPU 上进行的。这意味着每次调用该函数时, torch.compile 都必须将这些数组从 CPU 复制到 GPU,然后再将结果复制回 CPU 以保持原始语义。在 NumPy 中没有针对此问题的原生解决方案,因为 NumPy 没有 device 的概念。话虽如此,我们可以通过创建一个包装器来绕过这个问题,使其接受 PyTorch 张量并返回 PyTorch 张量。

@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = kmeans(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)

此函数现在接收 CUDA 内存中的张量,并返回 CUDA 内存中的张量,但该函数本身是用 NumPy 编写的! torch.compile 使用 numpy()from_numpy() 调用作为提示,并优化它们,并在内部仅使用 PyTorch 张量而不移动内存。当我们保持张量在 CUDA 中并在 float32 中执行计算时,我们在 float32 数组上比初始 NumPy 实现快 200 倍。

混合 NumPy 和 PyTorch。在这个例子中,我们不得不编写一个小型适配器来将张量转换为 ndarray,然后再将 ndarray 转换回张量。在混合 PyTorch 和 NumPy 的程序中,将张量转换为 ndarray 通常实现为 x.detach().cpu().numpy() ,或者简单地实现为 x.numpy(force=True) 。由于在 torch.compile 下运行时,我们可以运行 CUDA 中的 NumPy 代码,因此我们可以将这种转换模式实现为对 x.numpy() 的调用,就像我们上面做的那样。这样做并在 device("cuda") 下运行生成的代码将生成从原始 NumPy 调用中直接生成高效的 CUDA 代码,而无需将数据从 CUDA 复制到 CPU。请注意,生成的代码在没有 torch.compile 的情况下无法运行。为了在急切模式下运行,需要回滚到 x.numpy(force=True)

进一步的加速技巧

一般建议。我们展示的 CUDA 代码已经相当高效,但确实,运行示例相当简短。当处理更大的程序时,我们可能需要调整其部分内容以提高效率。一个好的起点是 torch.compile 的多个教程和常见问题解答。这展示了检查跟踪过程和识别可能导致性能下降的问题代码的多种方法。

编译 NumPy 代码时的建议。尽管 NumPy 与 PyTorch 相当相似,但它们的使用方式往往大相径庭。在 NumPy 中进行计算后,根据数组中的值执行 if/else 判断,或者通过布尔掩码进行原地操作,这种情况相当常见。这些构造虽然由 torch.compile 支持,但会降低其性能。将代码以无分支方式编写以避免图断开,或避免原地操作等更改可以大有裨益。

要编写快速的 NumPy 代码,最好避免使用循环,但有时它们是不可避免的。在遍历循环时, torch.compile 会尝试完全展开它。这有时是可取的,但有时可能甚至不可能,例如在 while 循环中,我们有一个动态的停止条件。在这些情况下,最好只是编译循环体,也许一次编译几个迭代(循环展开)。

调试 NumPy 代码。当涉及到编译器时,调试相当棘手。为了确定你遇到的错误是 torch.compile 错误,还是程序错误,你可以执行你的 NumPy 程序,不使用 torch.compile ,通过将 NumPy 导入替换为 import torch._numpy as np 。这仅应用于调试目的,绝对不能替代 PyTorch API,因为它速度较慢,并且作为私有 API,可能会随时更改。有关其他技巧,请参阅此常见问题解答。

NumPy 与 torch.compile NumPy 之间的区别

NumPy 标量。在几乎任何 PyTorch 会返回 0-D 张量(例如从 np.sum )的情况下,NumPy 都会返回 NumPy 标量。在大多数情况下,这没问题。唯一它们行为不同的情况是当 NumPy 标量被隐式用作 Python 标量时。例如,

>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]              # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])               # acts as a 0-D array, not as a scalar ?!?!

如果我们编译前两行,我们会看到 torch.compileu 视为 0-D 数组。为了恢复急切语义,我们只需要显式进行类型转换

>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]

类型提升和版本控制。NumPy 的类型提升规则有时可能会让人有些惊讶

>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)

NumPy 2.0 正在改变这些规则,使其更接近 PyTorch 的规则。相关技术文档是 NEP 50。 torch.compile 已经实施了 NEP 50,而不是即将被弃用的规则。

通常情况下,在 torch.compile 中的 NumPy 遵循 NumPy 2.0 的预发布版。

除了 NumPy 之外:SciPy 和 scikit-learn

与努力让 torch.compile 理解 NumPy 代码并行,其他 Quansight 工程师已经设计并提出了一个在 scikit-learn 和 SciPy 中支持 PyTorch 张量的方法。这一举措得到了这些库维护者的热情响应,因为已经证明使用 PyTorch 作为后端通常会带来显著的加速。这两个项目现在已经在多个 API 和子模块中合并了对 PyTorch 张量的初始支持。

这为未来 PyTorch 张量可以在 Python 数据生态系统的其他库中使用奠定了基础。更重要的是,这将使这些库在 GPU 上运行成为可能,甚至可以编译混合这些库和 PyTorch 的代码,就像我们在本文中讨论的那样。

如果你想了解更多关于这项工作的信息,了解如何使用它,或了解如何帮助推动它的发展,请参阅这篇其他博客文章。

结论

PyTorch 自诞生以来就致力于成为与 Python 生态系统兼容的框架。使 NumPy 程序可编译,以及为其他知名库建立必要的工具是这一方向上的两个更多步骤。Quansight 和 Meta 继续携手合作,提高 PyTorch 与生态系统其他部分的兼容性。

我们想感谢来自 Quansight 的 Mengwei、Voz 和 Ed 在将我们的工作与 torch.compile 集成方面提供的宝贵帮助。我们还要感谢 Meta 为该项目以及之前改善 PyTorch 中 NumPy 兼容性的工作,以及支持 scikit-learn 和 SciPy 中 PyTorch 的项目提供资金。这些都是巩固 PyTorch 作为开源 Python 数据生态系统中首选框架的巨大飞跃。