快捷键

动态形状 ¶

代码:symbolic_shapes.py

参见:动态形状手册

动机 ¶

深度学习编译器通常只适用于静态形状,也就是说,它们生成的编译程序只适用于单个特定的输入形状配置,并且如果任何输入形状发生变化,则必须重新编译。这种假设对于今天大多数常用的深度学习模型来说效果很好,但也有一些情况是不够的:

  • 一些维度,例如批量大小或序列长度,可能会变化。例如,一个执行自适应批处理的推理服务将根据其批处理窗口内接收到的请求数量执行具有不同批量大小的推理请求。我们还可以考虑只将可变大小序列填充到批次中的最大序列长度,这个长度可能会从批次到批次而变化。

  • 一些模型表现出数据依赖的输出形状,也就是说,它们的输出和中间结果的尺寸可能取决于实际输入数据,这些数据在运行中可能有所不同。例如,检测模型可能会首先生成可变数量的潜在边界框,然后再运行更昂贵的图像识别模型以确定主题是否在边界框内。边界框的数量是数据依赖的。

  • 在处理稀疏表示时,数据依赖形状的一个特别重要的例子是处理稀疏张量、锯齿形张量和图神经网络。在这些所有情况下,要处理的数据量取决于问题的稀疏结构,这通常以数据依赖的方式变化。

在支持动态形状时,我们选择不支持动态秩程序,例如输入张量在维度上变化的程序,因为在现实世界的深度学习程序中这种模式很少出现,并且它避免了需要对形状的符号列表进行归纳推理的需要。

简化版公共 API ¶

PyTorch 2.1 的默认动态行为是:

  • PT2 默认假设一切都是静态的

  • 如果我们因为大小变化而重新编译,我们将尝试将那个大小作为动态的重新编译(大小变化的可能性很大,未来可能会变化)。这种泛化可能会失败(例如,因为用户代码在相关大小上做了条件分支或 PT2 缺少动态形状支持)。如果您想了解为什么 PT2 对某些代码进行了过度专业化,请使用 TORCH_LOGS=dynamic 运行并查找表示何时添加守卫以及原因的“eval”条目。

  • 如果您事先知道某物将是动态的,您可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳过第一次重新编译。如果您事先知道该维度可以取的 minmax 值,您可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果你输入 torch.compile(dynamic=False) ,我们将关闭重新编译时的自动动态形状,并且总是为每个不同的尺寸重新编译。相反,如果你输入 torch.compile(dynamic=True) ,我们将尽可能使一切尽可能动态。这主要适用于小型操作符;如果你在一个大型模型上尝试它,它可能会(1)导致 PT2 崩溃,并且(2)没有明显原因地运行缓慢。

守护模型 ¶

在考虑如何为 TorchDynamo 和 TorchInductor 添加动态形状支持时,我们做出了一个重要的设计决策:为了重用针对 PyTorch API 编写的 Python/C++分解和其他现有代码,我们必须能够追踪动态形状。与可能捕获条件分支的两种情况的完全符号系统不同,我们总是选择一个分支,并在假设我们将来只会使用这个追踪来选择该分支的情况下进行追踪。为此,我们为每个符号尺寸维护一个“提示”,说明它在编译时的具体值(因为 TorchDynamo 是一个即时编译器,它总是知道实际的输入尺寸。)当我们对张量执行条件时,我们只需查阅提示以确定要选择哪个分支。

这极大地简化了我们产生的符号形状公式,但意味着我们有一个更复杂的系统来管理守卫。例如,考虑以下程序:

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)torch.cat([x, y]).mul(2) (条件被展开),但为了确定我们处于哪个分支,我们需要知道中间变量 z 的大小。因为 TorchDynamo 必须事先知道编译的跟踪是否有效(我们不支持像一些 JIT 编译器那样的退出,),我们必须能够将 z.size(0) 作为一个关于输入 x.size(0) + y.size(0) 的表达式来减少。这是通过为 PyTorch 中的所有操作符编写元函数来完成的,这些元函数可以将大小信息传播到张量的输出,而无需在节点上实际执行计算。

总体架构

符号形状工作流程

  1. 当我们在 Dynamo 中开始编译一个帧时,我们会分配一个 ShapeEnv(附加到 FakeTensorMode),它负责跟踪符号形状的状态。

  2. 我们在输入时为张量分配符号大小(静态或动态是一个策略决策,有一些可调节的旋钮)。

  3. 我们通过算子传播符号大小,同时维护(1)FX IR,以便我们可以忠实地导出符号计算,以及(2)表示大小变量的 Sympy 表达式,这样我们就可以对它们进行推理。

  4. 当我们在 Dynamo 跟踪或 Inductor 优化中对符号大小进行条件判断时,我们会根据条件添加守卫。这些守卫可以从 Python 和 C++中诱导出来。

  5. 这些守卫可以对符号变量进行进一步简化。例如,如果您断言 s0 == 4 ,我们现在可以替换所有 s0 的出现为 4

  6. 当我们完成跟踪和优化后,我们将所有这些守卫安装到编译后的代码中;只有当所有守卫评估为真时,编译后的代码才是可重用的。

重要文件:

  • C++ SymInt API: c10/core/SymInt.hSymFloat.hSymBool.h

  • Python SymInt API: torch/__init__.py (查找 SymInt/SymFloat/SymBool

  • C++管道: c10/core/SymNodeImpl.htorch/csrc/utils/python_symnode.htorch/csrc/jit/python/init.cpp

  • Python 基础设施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要文件: torch/_subclasses/fake_tensor.pytorch/_meta_registrations.py ,decomps,PrimTorch 引用

简化版内部 API

理解 Python 类层次结构:

  • SymInt/SymFloat/SymBool:这些是用户可见的类,它们模拟其 int/float/bool 对应物。如果你将两个 SymInt 相加,我们会给你一个新的 SymInt,该 SymInt 可以符号化跟踪整数加法操作。

  • SymNode:这是内部结构(例如,通过 symint.node 访问),它包含实际的符号跟踪信息。SymNode 是类型擦除的;这使得表示混合类型操作更加方便。请注意,技术上你不必从 SymInt 调用 Python SymNode;例如,XLA 的 C++ SymNodeImpl 将取代 SymNode。

  • ShapeEnv:每次编译的上下文状态,用于跟踪我们迄今为止积累的所有自由符号和守卫。每个 SymNode 都记录其 ShapeEnv(但反之则不然;只有参与守卫的 SymNodes 才会被使用)。

C++相当相似:

  • c10::SymInt/SymFloat/SymBool:模拟 int/float/bool 的用户可见类。

  • c10::SymNode/SymNodeImpl:与 SymNode 类似。

  • C++中没有 ShapeEnv;为了便于调试,整个符号推理装置都在 Python 中。

当您编写可以用 make_fx 跟踪的代码时,它必须能够处理 SymInt/SymFloat/SymBool 流经它。动态形状手册提供了一些如何做到这一点的指导。

DimDynamic 策略¶

符号推理:

  • 值范围

  • Sympy 使用说明

  • 约束

  • DimDynamic/约束

无后盾的 SymInts ¶

为了解决控制流,我们需要检查符号整数的提示,即实际值,以确定要进入哪个分支。然而,在某些情况下,我们可能没有提示:当大小变量从像 .nonzero().item() 这样的数据相关操作中出现时,就会出现所谓的无后盾符号整数。对这些符号整数执行控制流是不合法的,因此我们必须在这些操作上执行图断点。

草率实现的话,这太过限制了:如果你尝试对无后盾符号整数进行任何操作,大多数 PyTorch 程序将立即失败。以下是实现这一功能的最重要增强:

  • 在张量创建时,PyTorch 预先计算了大量关于张量的数据;例如,如果你使用 empty_strided 来创建张量,我们将急切地对步长进行排序,并确定张量是否非重叠且密集。排序会产生许多保护措施。然而,更常见的是使用像 empty 这样的高级 API 直接生成张量,这保证了生成的张量是非重叠且密集的。我们对 PyTorch 进行了修改,以避免不必要地重新计算这些属性。

  • 即使需要非平凡的计算,有时某个属性实际上根本不会被查询。将这些预计算的属性设置为延迟计算,可以让我们在不需要的情况下避免对未支持符号整数的保护。

  • 整数张量中的数据通常不知道是非负的。然而,我们提供了一个 API constrain_range ,用户可以通过它指定一个大小被已知的上下限所限制。

在 PT2 的未来版本(超越 PT2.1)中,我们将扩展我们的推理系统,根据使用情况推断未支持符号整数具有大小属性。例如,如果您将 .item() 调用的结果传递给工厂函数 torch.empty ,我们将自动推断该结果是大小(因为如果不是,它将失败。)这个假设将在运行时得到验证,如果未满足,将引发错误。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源