快捷键

torch.export IR 规范¶

Export IR 是一种中间表示(IR),类似于 MLIR 和 TorchScript,专门用于表达 PyTorch 程序的语义。Export IR 主要表示为一系列精简的操作列表,对动态性(如控制流)的支持有限。

要创建 Export IR 图,可以使用前端,通过跟踪专门化机制可靠地捕获 PyTorch 程序。然后,后端可以对生成的 Export IR 进行优化和执行。这可以通过 torch.export.export() 实现。

本文档将涵盖的关键概念包括:

  • ExportedProgram:包含导出 IR 程序的数结构

  • 图:由节点列表组成。

  • 节点:表示存储在此节点上的操作、控制流和元数据。

  • 值由节点产生和消费。

  • 类型与值和节点相关联。

  • 值的大小和内存布局也进行了定义。

假设 §

本文假设读者对 PyTorch 有足够的了解,特别是对 torch.fx 及其相关工具。因此,它将停止描述 torch.fx 文档和论文中存在的内容。

什么是导出 IR?

导出 IR 是 PyTorch 程序的基于图的中间表示 IR。导出 IR 建立在 torch.fx.Graph 之上。换句话说,所有导出 IR 图也都是有效的 FX 图,如果使用标准的 FX 语义进行解释,导出 IR 可以正确地被解释。一个含义是,导出的图可以通过标准的 FX 代码生成转换为有效的 Python 程序。

本文档将主要关注突出导出 IR 与 FX 在严格性方面的差异,而跳过与 FX 相似的部分。

导出程序

顶级导出 IR 构造是一个 torch.export.ExportedProgram 类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module )及其消耗的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些显著属性包括:

  • graph_module ( torch.fx.GraphModule ):包含 PyTorch 模型展开的计算图的数据结构。可以通过 ExportedProgram.graph 直接访问该图。

  • 图签名,用于指定图内使用的参数和缓冲区名称,以及被修改的参数和缓冲区。与将参数和缓冲区作为图的属性存储不同,它们被提升为图的输入。图签名用于跟踪这些参数和缓冲区的附加信息。

  • 包含参数和缓冲区的数据结构。

  • 对于具有数据依赖行为的导出程序,每个节点的元数据将包含符号形状(类似于 s0i0 )。此属性将符号形状映射到它们的下限/上限范围。

导出 IR 图是表示为 DAG(有向无环图)的 PyTorch 程序。图中每个节点代表特定的计算或操作,图的边由节点之间的引用组成。

我们可以查看具有此模式的图:

class Graph:
  nodes: List[Node]

在实践中,导出 IR 的图被实现为 torch.fx.Graph Python 类。

导出 IR 图包含以下节点(节点将在下一节中详细介绍):

  • 0 个或多个操作类型为 placeholder 的节点

  • 0 个或多个操作类型为 call_function 的节点

  • 精确 1 个操作类型为 output 的节点

推论:最小的有效图将只有一个节点。即节点列表永远不会为空。

定义:图模块的图的 placeholder 节点集合表示图的输入。图的输出节点表示图模块的输出。

示例:

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

上文是图的文本表示,每行代表一个节点。

节点

节点表示特定的计算或操作,在 Python 中使用 torch.fx.Node 类表示。节点之间的边通过节点类的 args 属性表示对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,如算子调用、占位符(即输入)、条件语句和循环。

节点具有以下模式:

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文本格式

如上例所示,注意每一行都遵循以下格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以紧凑格式捕获 Node 类中除 meta 之外的所有内容。

具体来说:

  • 是节点在 node.name 中的显示名称。

  • node.op 字段,它必须是以下之一:、、 或

  • 是节点的目标,作为 node.target 。此字段的含义取决于 op_name

  • args1, … args 4… 是在 node.args 元组中列出的。如果列表中的值是 torch.fx.Node ,则将特别用前缀 % 标识。

例如,对加法运算符的调用将显示为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是具有名称 x 和 y 的其他两个节点。值得注意的是,字符串 torch.op.aten.add.Tensor 代表实际存储在目标字段中的可调用对象,而不仅仅是它的字符串名称。

此文本格式的最后一行是:

return [add]

这是一个带有 op_name = output 的节点,表示我们正在返回这个一个元素。

调用函数 ¶

一个 call_function 节点代表对一个操作符的调用。

定义

  • 功能性:如果一个可调用对象满足以下所有要求,我们称其为“功能性”:

    • 非修改性:操作符不会修改其输入值(对于张量,这包括元数据和数据)。

    • 无副作用:操作符不会修改外部可见的状态,例如改变模块参数的值。

  • 操作符:是一个具有预定义模式的函数式可调用对象。此类操作符的例子包括函数式 ATen 操作符。

FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与 vanilla FX call_function 的区别

  1. 在 FX 图中,call_function 可以指代任何可调用对象,在 Export IR 中,我们将其限制为仅 ATen 运算符、自定义运算符和控制流运算符的选定子集。

  2. 在 Export IR 中,常数参数将被嵌入到图中。

  3. 在 FX 图中,get_attr 节点可以表示读取图模块中存储的任何属性。然而,在 Export IR 中,这仅限于读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。

元数据 ¶

Node.meta 是每个 FX 节点附加的字典。然而,FX 规范没有指定可以或将会包含哪些元数据。Export IR 提供了一个更强的合约,具体来说,所有 call_function 节点都将保证具有并且仅具有以下元数据字段:

  • node.meta["stack_trace"] 是一个字符串,包含引用原始 Python 源代码的 Python 堆栈跟踪。一个堆栈跟踪的例子如下:

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述运行操作的结果。它可以是 、、 List[Union[FakeTensor, SymInt]]None 之一。

  • node.meta["nn_module_stack"] 描述从其中节点来的 “堆栈跟踪” ,如果它是从 torch.nn.Module 调用来的。例如,如果包含 addmm op 的节点从一个 torch.nn.Linear 模块中的 torch.nn.Sequential 模块调用,那么 nn_module_stack 将看起来像:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含在分解之前调用此节点的 torch 函数或叶 torch.nn.Module 类。例如,从一个 addmm 模块调用中包含 torch.nn.Linear op 的节点将包含 torch.nn.Linear 在它们的 source_fn 中,而包含从 addmm 模块调用中 torch.nn.functional.Linear op 的节点将包含 torch.nn.functional.Linear 在它们的 source_fn 中。

占位符 ¶

占位符表示图的一个输入。其语义与 FX 中完全相同。占位符节点必须是图节点列表中的前 N 个节点。N 可以为零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

目标字段是一个字符串,表示输入的名称。

args ,如果非空,则应为大小为 1 的值,表示此输入的默认值。

元数据

占位符节点也具有 meta[‘val’] ,就像 call_function 节点一样。在这种情况下, val 字段表示图期望为此输入参数接收的输入形状/数据类型。

输出

输出调用表示函数中的返回语句;因此,它终止当前图。只有一个输出节点,并且它总是图中的最后一个节点。

在 FX 中的表示

output[](args = (%something, …))

这里的语义与 torch.fx 相同。 args 表示要返回的节点。

元数据

输出节点具有与 call_function 节点相同的元数据。

get_attr

get_attr 节点表示从封装的 torch.fx.GraphModule 中读取子模块。与从 torch.fx.symbolic_trace() 中使用的 get_attr 节点读取顶级 torch.fx.GraphModule 中的参数和缓冲区等属性的不同 FX 图相比,参数和缓冲区作为输入传递给图模块,并存储在顶级 torch.export.ExportedProgram 中。

外部效果表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

图:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取子模块 true_graph_0 ,其中包含 sin 运算符。

参考文献

SymInt

SymInt 是一个对象,它可以是一个字面整数值,也可以是一个表示整数的符号(在 Python 中由 sympy.Symbol class 表示)。当 SymInt 是符号时,它描述了一个在编译时对图来说是未知的整数类型的变量,也就是说,它的值仅在运行时才知道。

FakeTensor

FakeTensor 是一个包含张量元数据的对象。它可以被视为具有以下元数据。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的大小字段是一个整数或 SymInts 的列表。如果存在 SymInts,这意味着这个张量具有动态形状。如果存在整数,则假定该张量将具有该确切的静态形状。TensorMeta 的秩永远不会是动态的。dtype 字段表示该节点的输出 dtype。在 Edge IR 中没有隐式类型提升。FakeTensor 中没有步长。

换句话说:

  • 如果节点.target 中的运算符返回一个 Tensor,那么 node.meta['val'] 是描述该张量的 FakeTensor。

  • 如果节点.target 中的运算符返回一个 Tensor 的 n 元组,那么 node.meta['val'] 是描述每个张量的 n 元组 FakeTensor。

  • 如果节点.target 中的操作符返回的是在编译时已知的 int/float/scalar,则 node.meta['val'] 为 None。

  • 如果节点.target 中的操作符返回的是在编译时未知的 int/float/scalar,则 node.meta['val'] 为 SymInt 类型。

例如:

  • aten::add 返回一个 Tensor;因此其 spec 将是一个 FakeTensor,其 dtype 和 size 将与此操作符返回的 tensor 相同。

  • aten::sym_size 返回一个整数;因此其 val 将是一个 SymInt,因为其值仅在运行时可用。

  • 返回一个(Tensor,Tensor)的元组;因此规范也将是一个包含 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值的第一个元素等。

Python 代码:

def add_one(x):
  return torch.ops.aten(x, 1)

图:

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

可浸入的树形类型 ¶

我们定义一个类型“Pytree-able”,如果它是一个叶子类型,或者是一个包含其他 Pytree-able 类型的容器类型。

注意:

pytree 的概念与 JAX 中记录的概念相同:

以下类型被定义为叶类型:

类型

定义

张量

torch.Tensor

标量

Python 中的任何数值类型,包括整数类型、浮点数类型和零维张量。

int

Python 中的 int(在 C++中绑定为 int64_t)

float

Python 中的 float(在 C++中绑定为 double)

bool

Python 布尔值

str

Python 字符串

标量类型

torch.dtype

布局

torch.layout

内存格式

torch.memory_format

设备

torch.device

以下类型被定义为容器类型:

类型

定义

元组

Python 元组

列表

Python 列表

字典

Python 字典(标量键)

命名元组

Python 命名元组

数据类

必须通过 register_dataclass 进行注册

自定义类

使用_register_pytree_node 定义的任何自定义类


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源