torch.export¶
警告
此功能处于积极开发中的原型,未来将会有破坏性更改。
概述 ¶
torch.export.export()
接受一个 torch.nn.Module
并以即时编译(AOT)的方式生成一个仅表示函数中张量计算的跟踪图,该图可以随后用不同的输出或序列化来执行。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
# code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {}
生成一个干净的中间表示(IR),具有以下不变性。有关 IR 的更多规范请见此处。
可靠性:保证是原始程序的可靠表示,并保持原始程序的调用约定。
规范化:图中没有 Python 语义。原始程序中的子模块被内联,形成一个完整的平坦计算图。
图属性:该图是纯函数性的,意味着它不包含具有副作用的操作,如突变或别名。它不会修改任何中间值、参数或缓冲区。
元数据:图中包含在跟踪过程中捕获的元数据,例如用户代码的堆栈跟踪。
在底层, torch.export
利用以下最新技术:
TorchDynamo(torch._dynamo)是一个内部 API,它使用 CPython 的一个功能——帧评估 API 来安全地跟踪 PyTorch 图。这提供了大幅改进的图捕获体验,在完全跟踪 PyTorch 代码时需要的重写更少。
AOT Autograd 提供了一个函数化的 PyTorch 图,并确保图被分解/降低到 ATen 算子集。
火炬 FX(torch.fx)是图的底层表示,允许基于 Python 的灵活转换。
现有框架
torch.compile()
也使用了与 torch.export
相同的 PT2 堆栈,但略有不同:
JIT 与 AOT:
torch.compile()
是一个 JIT 编译器,而它不打算用于在部署之外生成编译工件。部分图与完整图捕获:当
torch.compile()
遇到模型中的不可追踪部分时,会发生“图断裂”并回退到在急切 Python 运行时中运行程序。相比之下,torch.export
旨在获取 PyTorch 模型的完整图表示,因此当遇到不可追踪的部分时会报错。由于torch.export
生成的图与任何 Python 功能或运行时都完全分离,因此该图可以保存、加载并在不同的环境和语言中运行。可用性权衡:由于
torch.compile()
能够在遇到不可追踪的部分时回退到 Python 运行时,因此它具有很高的灵活性。而torch.export
将要求用户提供更多信息或重写代码以使其可追踪。
与 torch.fx.symbolic_trace()
相比, torch.export
使用 TorchDynamo 进行跟踪,该工具在 Python 字节码级别运行,使其能够跟踪任意 Python 构造,不受 Python 运算符重载支持的限制。此外, torch.export
能够精细跟踪张量元数据,因此对张量形状等条件的条件判断不会导致跟踪失败。一般来说, torch.export
预计将在更多用户程序上运行,并生成更底层的图(在 torch.ops.aten
运算符级别)。请注意,用户仍然可以在 torch.fx.symbolic_trace()
之前将其用作预处理步骤。
与 torch.jit.script()
相比, torch.export
不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,只有直线控制流(除了显式的控制流运算符)。
与 torch.jit.trace()
相比, torch.export
是可靠的:它能够跟踪执行整数计算于大小并记录所有必要条件以证明特定跟踪对其他输入有效的代码。
导出 PyTorch 模型
一个示例
主入口是通过 torch.export.export()
,它接受一个可调用对象( torch.nn.Module
,函数或方法)和样本输入,并将计算图捕获到 torch.export.ExportedProgram
中。例如:
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# code: a = self.conv(x)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
# code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
# code: return self.maxpool(self.relu(a))
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
return (max_pool2d,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='constant'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='max_pool2d'),
target=None
)
]
)
Range constraints: {}
检查 ExportedProgram
,我们可以注意到以下内容:
torch.fx.Graph
包含原始程序的计算图,以及原始代码的记录,便于调试。图中只包含此处找到的
torch.ops.aten
运算符和自定义运算符,功能完全,没有任何torch.add_
内联运算符。参数(权重和偏置到卷积)被提升为图的输入,导致图中没有
get_attr
节点,这些节点之前存在于torch.fx.symbolic_trace()
的结果中。torch.export.ExportGraphSignature
模型输入和输出签名,并指定哪些输入是参数。记录图中每个节点产生的张量的结果形状和 dtype。例如,
convolution
节点将产生 dtype 为torch.float32
、形状为(1, 16, 256, 256)的张量。
非严格导出
在 PyTorch 2.3 版本中,我们引入了一种新的追踪模式,称为非严格模式。该模式仍在强化中,如果您遇到任何问题,请使用“oncall: export”标签在 GitHub 上提交。
在非严格模式下,我们使用 Python 解释器遍历程序。您的代码将像在急切模式中一样执行;唯一的区别是所有 Tensor 对象都将被替换为 ProxyTensors,这些对象将记录它们的所有操作到图中。
在严格模式中,目前是默认模式,我们首先使用 TorchDynamo 字节码分析引擎遍历程序。TorchDynamo 实际上并不执行您的 Python 代码。相反,它符号性地分析它并根据结果构建一个图。这种分析使 torch.export 能够提供更强大的安全性保证,但并非所有 Python 代码都受支持。
在某些情况下,人们可能会想要使用非严格模式,例如遇到可能难以解决的、不受支持的 TorchDynamo 功能,并且知道 Python 代码对于计算并非必需。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在这个例子中,使用非严格模式(通过 strict=False
标志)进行的第一次调用成功追踪,而使用严格模式(默认)进行的第二次调用则失败,因为 TorchDynamo 无法支持上下文管理器。一个选择是重写代码(参见 torch.export 的限制),但由于上下文管理器不影响模型中的张量计算,我们可以选择非严格模式的结果。
导出用于训练和推理
在 PyTorch 2.5 中,我们引入了一个名为 export_for_training()
的新 API。它仍在进行强化,所以如果您遇到任何问题,请将它们以“oncall: export”标签提交到 GitHub。
在此 API 中,我们生成最通用的 IR,其中包含所有 ATen 运算符(包括功能和非功能),可用于在急切 PyTorch Autograd 中进行训练。此 API 旨在用于急切训练用例,如 PT2 量化,并将很快成为 torch.export.export 的默认 IR。有关此更改背后的动机,请参阅 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206
当此 API 与 run_decompositions()
结合使用时,您应该能够获得具有任何所需分解行为的推理 IR。
以下是一些示例:
class ConvBatchnorm(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)
ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
return (batch_norm,)
从上述输出中,您可以看到 export_for_training()
产生的 ExportedProgram 与 export()
几乎相同,除了图中的运算符。您可以看到我们以最通用的形式捕获了 batch_norm。此运算符是非功能的,在推理时将被降低到不同的运算符。
您也可以通过 run_decompositions()
从这个索引推理(IR)转换到推理 IR,并进行任意自定义。
# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
return (getitem_3, getitem_4, add, getitem)
在这里,我们在分解其余部分的同时保留了 conv2d
操作符。现在,IR 是一个功能 IR,包含核心 aten 操作符,除了 conv2d
之外。
您还可以通过直接注册您选择的分解行为来进行更多自定义。
您还可以通过直接注册自定义分解行为来进行更多自定义。
# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
return (getitem_3, getitem_4, add, getitem)
表达动态性
默认情况下, torch.export
将根据所有输入形状都是静态的来跟踪程序,并将导出的程序专门化到这些维度。然而,某些维度,例如批处理维度,可能是动态的,并且可能随运行而变化。必须使用 torch.export.Dim()
API 来创建这些维度,并通过 torch.export.export()
通过 dynamic_shapes
参数将它们传递进去。例如:
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
# code: out1 = self.branch1(x1)
linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
# code: out2 = self.branch2(x2)
linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
# code: return (out1 + self.buffer, out2)
add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
return (add, relu_1)
Range constraints: {s0: VR[0, int_oo]}
需要注意的一些其他事项:
通过
torch.export.Dim()
API 和dynamic_shapes
参数,我们指定了每个输入的第一个维度是动态的。查看输入x1
和x2
,它们的符号形状为 (s0, 64) 和 (s0, 128),而不是我们作为示例输入传递的形状为 (32, 64) 和 (32, 128) 的张量。s0
是一个符号,表示这个维度可以是一系列值。描述了图中每个符号出现的范围。在这种情况下,我们看到
s0
的范围是[0, int_oo]。由于这里难以解释的技术原因,它们被假定为非 0 或 1。这并不是一个错误,也不一定意味着导出的程序在维度 0 或 1 时不会工作。关于这个话题的深入讨论,请参阅“0/1 特殊化问题”。
我们还可以指定更丰富的输入形状之间的关系,例如一对形状可能相差一个,一个形状可能是另一个的两倍,或者一个形状是偶数。例如:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
return (add,)
Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
需要注意的一些事项:
通过指定
{0: dimx}
作为第一个输入,我们看到第一个输入的结果形状现在是动态的,变为[s0]
。现在通过指定{0: dimy}
作为第二个输入,我们看到第二个输入的结果形状也是动态的。然而,因为我们表达了dimy = dimx + 1
,而不是y
的形状包含一个新符号,所以我们看到它现在使用与x
、s0
相同的符号来表示。我们可以看到dimy = dimx + 1
的关系是通过s0 + 1
来展示的。观察范围约束,我们看到
s0
的范围是[3, 6],这是最初指定的,我们可以看到s0 + 1
的解的范围是[4, 7]。
序列化 ¶
为了保存 ExportedProgram
,用户可以使用 torch.export.save()
和 torch.export.load()
API。一个惯例是使用 ExportedProgram
的 .pt2
文件扩展名来保存。
例子:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
专业化 §
理解 torch.export
行为的关键概念之一是静态值和动态值之间的区别。
动态值是指可能在不同运行之间发生变化的值。这些值的行为类似于 Python 函数的正常参数——你可以为参数传递不同的值,并期望函数能够正确执行。张量数据被视为动态值。
静态值是指在导出时固定不变的值,在导出的程序执行之间无法改变。当在跟踪过程中遇到该值时,导出器将其视为常量并将其硬编码到图中。
当执行操作(例如 x + y
)且所有输入都是静态值时,该操作的输出将直接硬编码到图中,并且该操作不会显示(即,它将被常数折叠)。
当一个值被硬编码到图中时,我们说该图已经针对该值进行了专业化。
以下值是静态的:
输入张量形状
默认情况下, torch.export
将追踪针对输入张量形状进行专业化的程序,除非通过 torch.export
的 dynamic_shapes
参数指定维度为动态。这意味着如果存在与形状相关的控制流, torch.export
将针对给定样本输入所采取的分支进行专业化。例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 2]"):
# code: return x + 1
add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
return (add,)
( x.shape[0] > 5
) 的条件在 ExportedProgram
中没有出现,因为示例输入具有静态形状 (10, 2)。由于 torch.export
专门针对输入的静态形状,else 分支 ( x - 1
) 将永远不会被执行。为了保留基于追踪图中张量形状的动态分支行为,需要使用 torch.export.Dim()
来指定输入张量的维度 ( x.shape[0]
) 为动态,并且需要重写源代码。
注意,模块状态中的张量(例如参数和缓冲区)始终具有静态形状。
Python 原始类型 ¶
torch.export
也专门针对 Python 原始类型,例如 int
, float
, bool
,和 str
。然而,它们也有动态变体,如 SymInt
, SymFloat
,和 SymBool
。
例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[2, 2]", const, times):
# code: x = x + const
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
return (add_2,)
因为整数是专门的,所以 torch.ops.aten.add.Tensor
操作都使用硬编码的常量 1
进行计算,而不是 const
。如果用户在运行时传递了与导出时不同的值 const
,比如 2,而不是 1,这将导致错误。此外,在 for
循环中使用的 times
迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor
调用“内联”在图中,而输入 times
从未使用。
Python 容器
Python 容器( List
, Dict
, NamedTuple
等)被认为具有静态结构。
torch.export 的限制
图断点
由于 torch.export
是一个从 PyTorch 程序捕获计算图的单次过程,它最终可能会遇到无法追踪的程序部分,因为支持所有 PyTorch 和 Python 特性的追踪几乎是不可能的。在 torch.compile
的情况下,不支持的运算将导致“图断点”,并且不支持的运算将以默认的 Python 评估方式运行。相比之下, torch.export
将需要用户提供额外的信息或重写代码的部分以使其可追踪。由于追踪基于 TorchDynamo,它在 Python 字节码级别进行评估,因此与之前的追踪框架相比,所需的重写将显著减少。
当遇到图断点时,ExportDB 是一个了解支持和不支持程序类型以及如何重写程序以使其可追踪的绝佳资源。
要绕过处理这种图断点的需求,可以使用非严格导出选项。
数据/形状相关控制流 ¶
在数据相关控制流( if
x.shape[0] > 2
)中也可能遇到图断裂,当形状没有被专门化时,追踪编译器无法处理,否则会生成大量路径的代码。在这种情况下,用户需要使用特殊的控制流运算符重写代码。目前,我们支持 torch.cond 来表示 if-else 类型的控制流(更多功能即将推出!)。
缺少操作符的伪/Fake/Meta/抽象内核 ¶
在追踪过程中,所有操作符都需要一个伪 Tensor 内核(也称为元内核、抽象实现)。这用于推理此操作符的输入/输出形状。
请参阅 torch.library.register_fake()
获取更多详情。
如果不幸地,您的模型使用了尚未实现 FakeTensor 内核的 ATen 运算符,请提交一个问题。
阅读更多 ¶
导出用户附加链接
PyTorch 开发者深度学习
API 参考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source][source]¶
export()
接受任何 nn.Module 及其示例输入,并以 AOT(提前编译)方式生成仅表示该函数 Tensor 计算的追踪图,该图随后可以用不同的输入或序列化来执行。追踪图(1)生成功能 ATen 操作符集中的标准化操作符(以及任何用户指定的自定义操作符),(2)消除了所有 Python 控制流和数据结构(存在某些例外),(3)记录了所需的形状约束集,以证明这种标准化和控制流消除对未来的输入是合理的。稳定性保证
在追踪过程中,
export()
会记录用户程序和底层 PyTorch 运算符内核所做的形状相关假设。只有当这些假设成立时,输出ExportedProgram
才被认为是有效的。追踪对输入张量的形状(而非值)做出假设。这些假设必须在图捕获时进行验证,以确保
export()
成功。具体来说:对输入张量静态形状的假设会自动验证,无需额外努力。
输入张量动态形状的假设需要通过使用
Dim()
API 显式指定来构造动态维度,并通过dynamic_shapes
参数将它们与示例输入关联。
如果任何假设无法验证,将引发致命错误。当发生这种情况时,错误消息将包括需要验证假设的指定建议修复。例如,
export()
可能会为以下修复动态维度定义,例如出现在与输入x
关联的形状中的dim0_x
,之前定义为Dim("dim0_x")
:dim = Dim("dim0_x", max=5)
这个例子意味着生成的代码要求输入
x
的第 0 维小于或等于 5 才能有效。您可以检查动态维度定义的建议修复,然后直接将它们复制到您的代码中,而无需更改export()
调用中的dynamic_shapes
参数。- 参数:
模块(Module)- 我们将跟踪此模块的前向方法。
args (tuple[Any, ...]) – 示例位置参数输入。
kwargs (Optional[dict[str, Any]]) – 可选示例关键字输入。
dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any], list[Any]]]) –
一个可选参数,其类型应为:1)一个字典,从参数名称
f
到它们的动态形状规格,2)一个元组,指定每个输入的动态形状规格,顺序与原始顺序相同。如果您在关键字参数上指定动态性,则需要按照原始函数签名中定义的顺序传递它们。张量参数的动态形状可以指定为以下两种形式之一:(1)一个从动态维度索引到
Dim()
类型的字典,其中不需要在此字典中包含静态维度索引,但如果包含,则应映射到 None;或者(2)一个Dim()
类型或 None 的元组/列表,其中Dim()
类型对应于动态维度,静态维度用 None 表示。通过使用映射或包含的指定序列,递归地指定字典或张量元组/列表类型的参数。严格(布尔值)- 当启用(默认)时,导出功能将通过 TorchDynamo 跟踪程序,以确保生成的图的正确性。否则,导出的程序将不会验证图内嵌入的隐含假设,可能会导致原始模型和导出模型之间的行为差异。这在用户需要绕过跟踪器中的错误或只想逐步在模型中启用安全性时很有用。请注意,这不会影响生成的 IR 规范不同,并且模型将以相同的方式序列化,无论传递的值是什么。警告:此选项是实验性的,请自行承担使用风险。
- 返回:
包含已跟踪的可调用对象的
ExportedProgram
。- 返回类型:
可接受的输入/输出类型
可接受的输入类型(对于
args
和kwargs
)以及输出包括:基本类型,即
torch.Tensor
、int
、float
、bool
和str
。数据类,但它们必须先通过调用
register_dataclass()
进行注册。由
dict
、list
、tuple
、namedtuple
和OrderedDict
组成的数据结构(嵌套),包含所有上述类型。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source][source]¶
警告
正在积极开发中,保存的文件可能在 PyTorch 的新版本中不可用。
将
ExportedProgram
保存到文件对象中。然后可以使用 Python APItorch.export.load
进行加载。- 参数:
ep (导出程序) – 要保存的导出程序。
f (str | os.PathLike[str] | IO[bytes]) – 实现写入和刷新) 或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为 f 的一部分存储。
opset_version (Optional[Dict[str, int]]) – opset 名称到该 opset 版本的映射
pickle_protocol (int) – 可以指定以覆盖默认协议
示例:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source][source]¶
警告
正在积极开发中,保存的文件可能在 PyTorch 的新版本中不可用。
加载之前使用
torch.export.save
保存的ExportedProgram
。- 参数:
f (str | os.PathLike[str] | IO[bytes]) – 一个文件对象(必须实现 write 和 flush)或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 在此映射中给出的额外文件名将被加载,其内容将存储在提供的映射中。
expected_opset_version (Optional[Dict[str, int]]) – 操作集名称到预期操作集版本的映射
- 返回:
一个
ExportedProgram
对象- 返回类型:
示例:
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source][source]¶
注册数据类为有效的输入/输出类型
torch.export.export()
。- 参数:
cls (类型:Any) – 要注册的数据类类型
serialized_type_name (可选[str]) – 数据类的序列化名称。这是
这(如果想要序列化包含的 pytree TreeSpec)是必需的——
数据类。
示例:
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source][source]¶
Dim()
构造一个类似于命名符号整数的类型,具有范围。它可以用来描述动态张量维度的多个可能值。请注意,同一张量的不同动态维度或不同张量的不同动态维度可以由相同的类型描述。- 参数:
调试时的人类可读名称。
min(可选[int])- 给定符号可能的最小值(包含)。
max(可选[int])- 给定符号可能的最大值(包含)。
- 返回:
可用于张量动态形状规范的类型。
- torch.export.exported_program.default_decompositions()[source][source]¶
这是默认分解表,其中包含所有 ATEN 算子的分解到核心 aten 算子集。请与
run_decompositions()
API 一起使用。- 返回类型:
- torch.export.dims(*names, min=None, max=None)[source][source]¶
创建多种
Dim()
类型的实用工具。- 返回:
Dim()
类型的元组。- 返回类型:
tuple[torch.export.dynamic_shapes._Dim, …]
- class torch.export.dynamic_shapes.ShapesCollection[source][source]
动态形状构建器。用于将动态形状规范分配给出现在输入中的张量。
尤其当
args()
是一个嵌套输入结构时,这很有用,并且比在dynamic_shapes()
规范中复制args()
的结构更容易索引输入张量。示例:
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source][source]¶
当使用
dynamic_shapes()
导出时,如果规格与从模型跟踪中推断出的约束不匹配,则可能会因 ConstraintViolation 错误而导出失败。错误消息可能会提供建议的修复方案 - 可以对dynamic_shapes()
进行更改以成功导出。示例 ConstraintViolation 错误消息:
Suggested fixes: dim = Dim('dim', min=3, max=6) # this just refines the dim's range dim = 4 # this specializes to a constant dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation
这是一个辅助函数,它接受 ConstraintViolation 错误消息和原始
dynamic_shapes()
规范,并返回一个新的dynamic_shapes()
规范,该规范包含了建议的修复方案。演示用法:
try: ep = export(mod, args, dynamic_shapes=dynamic_shapes) except torch._dynamo.exc.UserError as exc: new_shapes = refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) ep = export(mod, args, dynamic_shapes=new_shapes)
- 返回类型:
Union[dict[str, Any], tuple[Any], list[Any]]
- torch.export.Constraint¶
Union
的别名_Constraint
[_DerivedConstraint
,_RelaxedConstraint
]
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source][source]¶
包含由
export()
导出的程序。它包含一个torch.fx.Graph
代表张量计算,一个包含所有提升参数和缓冲区张量值的 state_dict,以及各种元数据。可以像由
export()
追踪的原始可调用对象一样调用 ExportedProgram,具有相同的调用约定。要对图进行转换,请使用
.module
属性访问torch.fx.GraphModule
。然后可以使用 FX 转换来重写图。之后,只需再次使用export()
即可构建正确的 ExportedProgram。- 模块()[源][源] ¶
返回一个包含所有参数/缓冲区的自包含 GraphModule。
- 返回类型:
- 缓冲区()[源][源] ¶
返回原始模块缓冲区的迭代器。
警告
此 API 为实验性,且不向后兼容。
- 返回类型:
迭代器[Tensor]
- named_buffers()[源][源] ¶
返回一个遍历原始模块缓冲区的迭代器,同时返回缓冲区的名称及其本身。
警告
此 API 为实验性,且不向后兼容。
- 返回类型:
Iterator[tuple[str, torch.Tensor]]
- named_parameters()[源][源] ¶
返回一个遍历原始模块参数的迭代器,同时返回参数的名称及其本身。
警告
此 API 为实验性,且不向后兼容。
- 返回类型:
Iterator[tuple[str, torch.nn.parameter.Parameter]]
- run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source][source]¶
在导出的程序上运行一系列分解,并返回一个新的导出程序。默认情况下,我们将运行 Core ATen 分解以获取 Core ATen 操作集中的算子。
目前,我们不分解关节图。
- 参数:
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 可选参数,用于指定 Aten 操作的分解行为(1)如果为 None,则分解为 core aten 分解(2)如果为空,则不分解任何操作
- 返回类型:
一些示例:
如果您不想分解任何内容
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果您想获取除某些操作符之外的核心 aten 操作符集,可以执行以下操作:
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- 类 torch.export.ExportBackwardSignature(gradients_to_parametersdict[strstr], gradients_to_user_inputsdict[strstr], loss_outputstr)[source][source] ¶
- 类 torch.export.ExportGraphSignature(input_specs, output_specs)[source][source] ¶
模型导出图的输入/输出签名,导出图是一个具有更强不变性保证的 fx.Graph。
导出图是功能性的,不通过
getattr
节点访问图内的“状态”,如参数或缓冲区。相反,export()
保证参数、缓冲区和常数张量被提升为图的输入。同样,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值作为导出图额外的输出。所有输入和输出的顺序如下:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果以下模块被导出:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
结果图将是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
结果的 ExportGraphSignature 将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- 类 torch.export.ModuleCallSignature(inputslist[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputslist[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spectorch.utils._pytree.TreeSpec, out_spectorch.utils._pytree.TreeSpec, forward_arg_namesOptional[list[str]]=None)[source][source] ¶
- 类 torch.export.ModuleCallEntry(fqnstr, signatureOptional[torch.export.exported_program.ModuleCallSignature]=None)[source][source] ¶
- class torch.export.decomp_utils.CustomDecompTable[source][source]¶
这是一个专门用于处理导出中的 decomp_table 的自定义字典。我们需要这个字典的原因是,在新世界中,您只能从 decomp 表删除操作以保留它。这对于自定义操作来说是有问题的,因为我们不知道自定义操作何时会被实际加载到调度器中。因此,我们需要记录自定义操作的运算,直到我们真正需要将其实例化(即在运行分解传递时。)
- 我们持有的不变量是:
所有 aten 分解在初始化时加载
当用户从表中读取时,我们实例化所有操作,以提高调度器选择自定义操作的可能性。
如果是写操作,我们不一定实例化。
我们在导出时加载最终时间,在调用 run_decompositions()之前。
- copy()[source][source]
- 返回类型:
- items()[来源][来源] ¶
- keys()[来源][来源] ¶
- class torch.export.graph_signature.InputKind(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]¶
- class torch.export.graph_signature.InputSpec(kind=torch.export.graph_signature.InputKind, argUnion[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], targetOptional[str], persistentOptional[bool]=None)[source][source] ¶
- class torch.export.graph_signature.OutputKind(value, names=<未提供>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source] ¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source][source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature
模型化 Export Graph 的输入/输出签名,Export Graph 是一个具有更强不变性保证的 fx.Graph。Export Graph 是功能性的,不通过
getattr
节点访问“状态”如参数或图内的缓冲区。相反,export()
保证参数、缓冲区和常数张量被提升为输入。同样,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值作为 Export Graph 的额外输出进行建模。所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果以下模块被导出:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
结果图将是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
结果的 ExportGraphSignature 将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- 替换所有对旧名称的使用为新的名称在签名中[source][source] ¶
将签名中的旧名称替换为新名称。
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source][source]¶
- class torch.export.unflatten.FlatArgsAdapter[source][source]¶
将输入参数适配为
input_spec
以与target_spec
对齐。
- class torch.export.unflatten.InterpreterModule(graph, ty=None)[source][source]¶
使用 torch.fx.Interpreter 执行,而不是 GraphModule 通常使用的代码生成。这提供了更好的堆栈跟踪信息,并使得调试执行更容易。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source][source]¶
携带对应模块调用序列的 InterpreterModules 序列的模块。每次调用该模块都会调度到下一个 InterpreterModule,并在最后一个之后循环回。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source][source]
将导出的 ExportedProgram 展平,生成与原始 eager 模块具有相同模块层次结构的模块。如果您尝试使用
torch.export
与另一个期望模块层次结构而不是torch.export
通常产生的平坦图的其他系统,这可能很有用。注意
展平的模块的 args/kwargs 不一定与 eager 模块匹配,因此进行模块交换(例如self.submod = new_mod
)不一定有效。如果您需要替换模块,需要设置torch.export.export()
的preserve_module_call_signature
参数。- 参数:
模块(ExportedProgram)- 用于反扁平化的 ExportedProgram。
flat_args_adapter(可选[FlatArgsAdapter])- 如果输入的 TreeSpec 与导出模块不匹配,则适配扁平参数。
- 返回:
与原始急切模块具有相同模块层次结构的
UnflattenedModule
实例。- 返回类型:
反扁平化模块