torch.export 编程模型 ¶
本文档旨在解释 torch.export.export()
的行为和能力。目的是帮助您建立对 torch.export.export()
处理代码的直觉。
跟踪基础 ¶
通过在“example”输入上跟踪模型执行并记录沿跟踪路径观察到的 PyTorch 操作和条件来捕获表示您的模型的图。只要输入满足相同的条件,就可以在不同的输入上运行此图。
torch.export.export()
的基本输出是一个包含相关元数据的 PyTorch 操作的单个图。此输出的确切格式在 torch.export IR 规范中有详细说明。
严格跟踪与非严格跟踪
torch.export.export()
提供两种跟踪模式。
在非严格模式下,我们使用正常的 Python 解释器跟踪程序。您的代码将像在急切模式下一样执行;唯一的区别是所有 Tensor 都被替换为没有数据、但有形状和其他形式元数据的虚拟 Tensor,这些虚拟 Tensor 被包装在记录所有操作到图中的代理对象中。我们还捕获了 Tensor 形状上的条件,以保护生成代码的正确性。
在严格模式下,我们首先使用 TorchDynamo 跟踪程序,TorchDynamo 是一个 Python 字节码分析引擎。TorchDynamo 实际上并不执行您的 Python 代码。相反,它符号性地分析它并基于结果构建一个图。一方面,这种分析允许 torch.export.export()
提供额外的 Python 级别安全性保证(超越在非严格模式下捕获 Tensor 形状条件)。另一方面,并非所有 Python 特性都支持这种分析。
尽管目前跟踪的默认模式是严格的,但我们强烈建议使用非严格模式,该模式很快将成为默认模式。对于大多数模型,Tensor 形状的条件足以保证正确性,而额外的 Python 级别安全性保证没有影响;同时,在 TorchDynamo 中遇到不受支持的 Python 特性的可能性是一个不必要的风险。
在本文件的其余部分,我们假设我们在非严格模式下进行跟踪;特别是,我们假设所有 Python 特性都受支持。
值:静态与动态
理解 torch.export.export()
的行为的关键概念是静态值和动态值之间的区别。
静态值
静态值是指在导出时固定不变,在导出的程序执行过程中不能改变的值。当在跟踪过程中遇到该值时,我们将其视为常量,并将其硬编码到图中。
当执行一个操作(例如 x + y
)且所有输入都是静态值时,该操作的输出将直接硬编码到图中,并且该操作不会显示(即它被“常量折叠”)。
当一个值被硬编码到图中时,我们说该图已经针对该值进行了特殊化。例如:
import torch
class MyMod(torch.nn.Module):
def forward(self, x, y):
z = y + 7
return x + z
m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)
"""
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None
return (add,)
"""
这里,我们提供 3
作为 y
的追踪值;它被视为静态值,并添加到 7
中,在图中烧录静态值 10
。
动态值 ¶
动态值是指可以从一次运行改变到另一次运行的值。它的行为就像一个“正常”的函数参数:你可以传递不同的输入,并期望你的函数能正确地执行。
哪些值是静态的,哪些是动态的? ¶
一个值的静态或动态取决于其类型:
对于张量:
张量数据被视为动态的。
系统可以将张量的形状视为静态或动态。
默认情况下,所有输入张量的形状都被视为静态。用户可以通过指定动态形状来覆盖这种行为,适用于任何输入张量。
作为模块状态一部分的张量,即参数和缓冲区,始终具有静态形状。
其他形式的张量元数据(例如
device
,dtype
)是静态的。
Python 基本类型(
int
,float
,bool
,str
,None
)是静态的。一些基本类型(
SymInt
、SymFloat
、SymBool
)存在动态变体。通常用户无需处理它们。
对于 Python 标准容器(
list
、tuple
、dict
、namedtuple
):结构(即,
list
和tuple
的长度,以及dict
和namedtuple
的键序列)是静态的。包含的元素会递归地应用这些规则(基本上是 PyTree 方案),叶子节点为 Tensor 或基本类型。
其他类(包括数据类)可以使用 PyTree 进行注册(见下文),并遵循标准容器的相同规则。
输入类型 ¶
输入将被视为静态或动态,这取决于它们的类型(如上所述)。
静态输入将被硬编码到图中,如果在运行时传递不同的值将导致错误。回想一下,这些主要是原始类型的数据。
动态输入的行为类似于“正常”函数输入。回想一下,这些大多是张量类型的值。
默认情况下,您可以在程序中使用的输入类型有:
张量
Python 基本类型(
int
,float
,bool
,str
,None
)Python 标准容器(
list
,tuple
,dict
,namedtuple
)
自定义输入类型 ¶
此外,您还可以定义自己的(自定义)类并将其用作输入类型,但您需要将此类注册为 PyTree。
以下是一个使用实用工具注册用作输入类型的数据类的示例。
@dataclass
class Input:
f: torch.Tensor
p: torch.Tensor
torch.export.register_dataclass(Input)
class M(torch.nn.Module):
def forward(self, x: Input):
return x.f + 1
torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))
可选输入类型 ¶
对于程序中可选输入,如果没有传入, torch.export.export()
将专门化为它们的默认值。因此,导出的程序将需要用户显式传入所有参数,并失去默认行为。例如:
class M(torch.nn.Module):
def forward(self, x, y=None):
if y is not None:
return y * x
return x + x
# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
# File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None
return (mul,)
"""
# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y):
# File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
"""
控制流:静态与动态 ¶
控制流由 torch.export.export()
支持。控制流的行为取决于你分支的值是静态的还是动态的。
静态控制流 ¶
Python 对静态值的控制流支持是透明的。(请记住,静态值包括静态形状,因此对静态形状的控制流也包含在本例中。)
如上所述,我们将静态值“烧录”进去,因此导出的图将永远不会看到任何对静态值的控制流。
在 if
语句的情况下,我们将继续追踪导出时的分支。在 for
或 while
语句的情况下,我们将通过展开循环继续追踪。
动态控制流:形状相关与数据相关 ¶
当控制流中涉及到的值是动态的,它可能依赖于动态形状或动态数据。由于编译器跟踪的是形状信息而不是数据信息,因此在这些情况下对编程模型的影响是不同的。
动态形状相关控制流
当控制流中涉及的值是动态形状时,在大多数情况下,我们也会在跟踪过程中知道动态形状的具体值:请参阅以下章节,以了解更多关于编译器如何跟踪这些信息的内容。
在这种情况下,我们称控制流为形状相关。我们使用动态形状的具体值来评估条件是否为 True
或 False
,并继续跟踪(如上所述),此外还会发出与刚刚评估的条件相对应的守卫。
否则,控制流程被认为是数据依赖的。我们无法评估条件为 True
或 False
,因此无法继续跟踪,必须在导出时引发错误。请参阅下一节。
动态数据依赖控制流程
支持对动态值的数据依赖控制流程,但您必须使用 PyTorch 的显式运算符来继续跟踪。不允许使用 Python 控制流程语句来处理动态值,因为编译器无法评估继续跟踪所需的条件,因此必须在导出时引发错误。
我们提供了运算符来表示对动态值的通用条件和循环,例如 torch.cond、torch.map。请注意,只有当您确实需要数据依赖控制流程时才需要使用这些运算符。
这里是一个关于数据依赖条件的 if
语句的例子, x.sum() > 0
,其中 x
是一个输入张量,使用 torch.cond 重写。现在不需要决定要跟踪哪个分支,现在两个分支都会被跟踪。
class M_old(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
return torch.cond(
pred=x.sum() > 0,
true_fn=lambda x: x.sin(),
false_fn=lambda x: x.cos(),
operands=(x,),
)
数据依赖控制流的一个特殊情况是涉及数据依赖的动态形状:通常是一些中间张量的形状依赖于输入数据而不是输入形状(因此不是形状依赖的)。在这种情况下,你可以提供一个断言来决定条件是 True
还是 False
。给定这样的断言,我们可以继续跟踪,生成一个如上所述的守卫。
我们提供了操作符来表示对动态形状的断言,例如 torch._check。请注意,你只有在数据依赖的动态形状上有控制流时才需要使用这个操作符。
这里是一个关于涉及数据依赖动态形状的条件的 if
语句的例子, nz.shape[0] > 0
,其中 nz
是调用 torch.nonzero()
操作符的结果,该操作符的输出形状依赖于输入数据。你可以通过添加使用 torch._check 的断言来有效地决定要跟踪哪个分支,而不是重写它。
class M_old(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
torch._check(nz.shape[0] > 0)
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
符号形状基础 ¶
在追踪过程中,动态张量的形状及其条件被编码为“符号表达式”。(相比之下,静态张量的形状及其条件只是简单的 int
和 bool
值。)
符号就像变量一样;它描述动态张量的形状。
随着追踪的进行,中间张量的形状可能由更一般的表达式描述,通常涉及整数算术运算符。这是因为对于大多数 PyTorch 运算符,输出张量的形状可以描述为输入张量形状的函数。例如, torch.cat()
的输出形状是其输入形状之和。
此外,当我们遇到程序中的控制流时,我们会创建布尔表达式,通常涉及关系运算符,描述沿追踪路径的条件。这些表达式被评估以决定通过程序追踪哪条路径,并记录在形状环境中以保护追踪路径的正确性,并评估随后创建的表达式。
接下来,我们简要介绍这些子系统。
PyTorch 算子伪造实现
回想一下,在追踪过程中,我们使用的是没有数据的伪造张量执行程序。通常情况下,我们不能使用伪造张量调用 PyTorch 算子的实际实现。因此,每个算子都需要一个额外的伪造(也称为“元”)实现,该实现输入和输出伪造张量,以匹配实际实现的行为,包括形状和其他形式的元数据。
例如,注意 torch.index_select()
的虚假实现是如何使用输入的形状来计算输出形状的(同时忽略输入数据并返回空输出数据)。
def meta_index_select(self, dim, index):
result_size = list(self.size())
if self.dim() > 0:
result_size[dim] = index.numel()
return self.new_empty(result_size)
形状传播:支持与不支持动态形状 ¶
形状是通过 PyTorch 运算符的虚假实现进行传播的。
理解动态形状传播的一个关键概念是支持和不支持动态形状之间的区别:我们知道前者的具体值,但不知道后者的具体值。
形状的传播,包括跟踪带背和未带背的动态形状,过程如下:
表示输入的张量形状可以是静态的或动态的。当动态时,它们由符号描述;此外,这些符号是带背的,因为我们还知道它们在导出时的具体值。
运算符的输出形状由其伪实现计算得出,可以是静态的或动态的。当动态时,通常由一个符号表达式描述。此外:
如果输出形状仅取决于输入形状,则当所有输入形状都是静态的或带背动态的时,它要么是静态的,要么是带背动态的。
另一方面,如果输出形状取决于输入数据,它必然是动态的,而且由于我们无法知道其具体值,它是不受支持的。
控制流程:守卫和断言
当遇到形状的条件时,它要么只涉及静态形状,在这种情况下,它是一个 bool
,要么涉及动态形状,在这种情况下,它是一个符号布尔表达式。对于后者:
当条件只涉及受支持的动态形状时,我们可以使用这些动态形状的具体值来评估条件到
True
或False
。然后我们可以在形状环境中添加一个守卫,声明相应的符号布尔表达式是True
或False
,并继续跟踪。否则,条件涉及未受支持的动态形状。通常,没有额外信息我们无法评估此类条件;因此,我们无法继续跟踪,必须在导出时引发错误。用户应使用显式的 PyTorch 操作符进行跟踪以继续。此信息作为形状环境中的保护措施添加,也可能有助于评估随后遇到的
True
或False
条件。
模型导出后,对后端动态形状的任何守卫都可以理解为对输入动态形状的条件。这些条件将与导出时必须提供的动态形状规范进行验证,该规范描述了动态形状的条件,不仅包括示例输入,还包括所有未来输入都必须满足的条件,以便生成的代码正确。更确切地说,动态形状规范必须逻辑上蕴含生成的守卫,否则在导出时将引发错误(以及针对动态形状规范的修复建议)。另一方面,当后端动态形状没有生成守卫时(特别是当所有形状都是静态的),不需要提供动态形状规范即可导出。一般来说,动态形状规范被转换为生成代码输入的运行时断言。
最后,对未后端动态形状的任何守卫都转换为“内联”运行时断言。这些断言被添加到生成代码中,位置通常是在创建这些未后端动态形状的地方:通常是在数据相关运算符调用之后。
允许的 PyTorch 运算符
所有 PyTorch 运算符都是允许的。
自定义运算符 ¶
此外,您还可以定义和使用自定义运算符。定义自定义运算符包括为其定义一个虚构实现,就像任何其他 PyTorch 运算符一样(参见上一节)。
以下是一个自定义 sin
运算符的示例,它包装了 NumPy,以及其注册的(平凡的)虚构实现。
@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
return torch.empty_like(x)
有时您的自定义运算符的虚构实现将涉及数据相关的形状。以下是一个自定义 nonzero
的虚构实现的示例。
...
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
nnz = torch.library.get_ctx().new_dynamic_size()
shape = [nnz, x.dim()]
return x.new_empty(shape, dtype=torch.int64)
模块状态:读取与更新
模块状态包括参数、缓冲区和常规属性。
常规属性可以是任何类型。
另一方面,参数和缓冲区始终是张量。
模块状态可以是动态的或静态的,这取决于它们如上所述的类型。例如, self.training
是一个 bool
,这意味着它是静态的;另一方面,任何参数或缓冲区都是动态的。
模块状态中包含的任何张量的形状都不能是动态的,即这些形状在导出时是固定的,并且在导出的程序执行之间不能改变。
访问规则
所有模块状态都必须初始化。访问尚未初始化的模块状态会在导出时引发错误。
读取模块状态始终允许。
更新模块状态是可能的,但必须遵循以下规则:
静态常规属性(例如,原始类型)可以更新。读取和更新可以自由交织,并且如预期的那样,任何读取都将始终看到最新更新的值。因为这些属性是静态的,所以我们也会将其值烧录进去,因此生成的代码将不会有实际的“获取”或“设置”此类属性的指令。
动态常规属性(例如,张量类型)不能更新。要更新它,必须在模块初始化期间将其注册为缓冲区。
缓冲区可以更新,更新可以是就地的(例如,
self.buffer[:] = ...
)或不是(例如,self.buffer = ...
)。参数不能更新。通常参数仅在训练期间更新,不在推理期间更新。我们建议使用
torch.no_grad()
导出,以避免在导出时更新参数。
功能化效应
任何被读取和/或更新的动态模块状态将被“提升”为生成的代码的输入和/或输出。
导出的程序存储了生成的代码,以及参数和缓冲区的初始值,以及其他 Tensor 属性的常量值。