torch.fx¶
概述 ¶
FX 是开发者用来转换 nn.Module
实例的工具包。FX 包含三个主要组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件在行动中的演示:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器执行 Python 代码的“符号执行”。它通过代码传递称为代理的假值。记录这些代理的操作。有关符号追踪的更多信息,请参阅 symbolic_trace()
和 Tracer
文档。
中间表示是记录在符号追踪期间的操作的容器。它由表示函数输入、调用点(到函数、方法或 torch.nn.Module
实例)和返回值的节点列表组成。有关 IR 的更多信息,请参阅 Graph
的文档。IR 是应用转换的格式。
Python 代码生成是 FX 成为 Python 到 Python(或模块到模块)转换工具包的原因。对于每个图 IR,我们可以创建与图语义匹配的有效 Python 代码。此功能封装在 GraphModule
中,它是一个 torch.nn.Module
实例,包含一个 Graph
以及从图生成的 forward
方法。
综合来看,这个组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件也可以单独使用。例如,符号跟踪可以单独使用来捕获代码的一种形式以供分析(而非转换)目的。代码生成可以用于程序化生成模型,例如从配置文件中生成。FX 有很多用途!
几个示例转换可以在示例仓库中找到。
编写转换
什么是 FX 转换?本质上,它是一个看起来像这样的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换将接受一个 torch.nn.Module
,从中获取一个 Graph
,进行一些修改,然后返回一个新的 torch.nn.Module
。您应该将您的 FX 转换返回的 torch.nn.Module
视为与常规的 torch.nn.Module
相同 – 您可以将其传递给另一个 FX 转换,可以传递给 TorchScript,或者可以运行它。确保您的 FX 转换的输入和输出是 torch.nn.Module
将允许进行组合。
注意
也可以修改现有的 GraphModule
而不是创建一个新的,如下所示:
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
注意,您必须调用 GraphModule.recompile()
来使生成的 forward()
方法与修改后的 GraphModule
同步。
既然您已经传递了一个已经被追踪到 Graph
的 torch.nn.Module
,现在有两种主要方法可以构建一个新的 Graph
。
图论快速入门指南
图的语义的全面介绍可以在 Graph
文档中找到,但在这里我们将介绍基础知识。 Graph
是一种数据结构,它表示在 GraphModule
上的方法。这需要的信息包括:
该方法有哪些输入?
方法内部运行的操作有哪些?
该方法输出的(即返回的)值是什么?
这三个概念都用 Node
实例表示。让我们用一个简短的例子来看看我们是什么意思:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
在这里,我们定义一个用于演示的模块 MyModule
,实例化它,进行符号跟踪,然后调用 Graph.print_tabular()
方法打印出显示此 Graph
节点的表格:
指令码
名称
目标
参数
关键字参数
占位符
x
x
()
{}
获取属性
线性权重
线性.weight
()
{}
调用函数
加 1
<内置函数加>
(x, 线性权重)
{}
调用模块
线性_1
线性
(add_1,)
{}
调用方法
relu_1
relu
(线性_1,)
{}
调用函数
求和 1
<内置方法 sum …>
(relu_1,)
{‘维度’:-1}
调用函数
topk_1
<内置方法 topk …>
(和_1, 3)
{}
输出
输出
输出
(topk_1,)
{}
我们可以使用这些信息来回答我们上面提出的问题。
该方法有哪些输入?在 FX 中,方法输入通过特殊的
placeholder
节点指定。在这种情况下,我们有一个单独的placeholder
节点,其target
为x
,这意味着我们有一个名为 x 的单个(非自身)参数。方法中包含哪些操作?
get_attr
、call_function
、call_module
和call_method
节点代表方法中的操作。所有这些操作的语义的全面解释可以在Node
文档中找到。该方法返回值是什么?在
Graph
中,返回值由一个特殊的output
节点指定。
既然我们已经了解了 FX 中代码表示的基础知识,现在我们可以探索如何编辑 Graph
。
图形操作 ¶
直接图形操作 ¶
建立这种新 Graph
的一种方法就是直接操作你的旧版本。为了帮助实现这一点,我们可以简单地从符号跟踪中获取的 Graph
进行修改。例如,假设我们希望用 torch.mul()
调用替换 torch.add()
调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph
重写,例如删除或添加节点。为了帮助这些转换,FX 提供了用于转换图的实用函数,这些函数可以在 Graph
文档中找到。以下是一个使用这些 API 添加 torch.relu()
调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于只包含替换的简单转换,你还可以使用子图重写器。
使用 replace_pattern() 进行子图重写
FX 还在直接图操作之上提供另一层自动化。 replace_pattern()
API 实质上是一个用于编辑 Graph
的“查找/替换”工具。它允许您指定一个 pattern
和 replacement
函数,然后它会追踪这些函数,在 pattern
图中找到操作组的实例,并将这些实例替换为 replacement
图的副本。这可以帮助极大地自动化繁琐的图操作代码,随着变换变得更加复杂,这些代码可能会变得难以控制。
图操作示例 ¶
代理/重绘
另一种操作 Graph
的方法是重用符号跟踪中使用的 Proxy
机制。例如,让我们想象我们想要编写一个将 PyTorch 函数分解成更小操作的转换。它将把每个 F.relu(x)
调用转换成 (x > 0) * x
。一种可能性是在 F.relu
之后插入比较和乘法所需的图重写,然后清理原始的 F.relu
。然而,我们可以通过使用 Proxy
对象来自动记录操作到 Graph
来自动化此过程。
使用此方法,我们将要插入的操作以常规 PyTorch 代码的形式编写,并使用 Proxy
对象作为参数调用该代码。这些 Proxy
对象将捕获对它们的操作并将它们追加到 Graph
。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图操作外,使用 Proxy
还可以让您将重写规则指定为原生 Python 代码。对于需要大量重写规则(如 vmap 或 grad)的转换,这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy
时,我们还传递了一个指向底层变量图的跟踪器。这样做是为了如果图中的操作是 n 元(例如,add 是一个二元运算符)时,调用 Proxy
不会创建多个图跟踪器实例,这可能导致意外的运行时错误。我们建议在底层运算符不能安全假设为一元时,特别使用此方法使用 Proxy
。
使用 Proxy
进行 Graph
操作的示例可以在此处找到。
解释器模式 ¶
在 FX 中,一种有用的代码组织模式是遍历一个模块中的所有 Node
,并执行它们。这可以用于多种用途,包括对通过图流动的值的运行时分析或通过 Proxy
进行回溯来转换代码。例如,假设我们想要运行一个 GraphModule
并记录节点在运行时看到的 torch.Tensor
形状和 dtype 属性。这可能看起来像:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如您所见,FX 的完整解释器并不复杂,但它非常有用。为了简化使用这种模式,我们提供了一个 Interpreter
类,它以这种方式封装了上述逻辑,使得解释器执行的一些方面可以通过方法重写来覆盖。
除了执行操作外,我们还可以通过将 Proxy
值通过解释器传递来生成一个新的图。同样,我们提供了一个 Transformer
类来封装这种模式。 Transformer
的行为与 Interpreter
类似,但您不是调用 run
方法从 Module 获取具体的输出值,而是调用 Transformer.transform()
方法返回一个新的 GraphModule
,该新图已应用您安装的任何转换规则作为重写方法。
解释器模式的示例
调试
简介
在编写变换的过程中,我们的代码可能并不完全正确。在这种情况下,我们可能需要进行一些调试。关键是要逆向工作:首先,检查调用生成的模块的结果以证明或反驳正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的变换过程。
如果您不熟悉调试器,请参阅辅助部分“可用的调试器”。
检查模块的正确性 ¶
由于大多数深度学习模块的输出由浮点 torch.Tensor
实例组成,检查两个 torch.nn.Module
的结果之间的等价性并不像进行简单的相等性检查那样简单。为了说明这一点,让我们用一个例子来说明:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用 ==
等价运算符检查两个深度学习模型的值是否相等。然而,这并不明确,一方面是因为该运算符返回一个张量而不是布尔值,另一方面是因为浮点数的比较应该使用误差范围(或 epsilon)来考虑浮点运算的非交换性(更多详情请见此处)。我们可以使用 torch.allclose()
,它将给出一个近似比较,考虑到相对和绝对容差阈值:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们的工具箱中第一个检查转换后的模块是否按预期与参考实现相比表现良好的工具。
生成代码的调试
因为 FX 在 GraphModule
上生成 forward()
函数,所以使用传统的调试技术,如 print
语句或 pdb
,并不那么直接。幸运的是,我们有几种可以用来调试生成代码的技术。
使用 pdb
¶
使用 pdb
进入正在运行的程序。虽然代表 Graph
的代码不在任何源文件中,但在调用前向传递时,我们仍然可以使用 pdb
手动进入。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码 ¶
如果你想多次运行相同的代码,那么使用 pdb
步进到正确的代码可能会有些繁琐。在这种情况下,一种方法是将生成的 forward
传递复制粘贴到你的代码中,并从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 to_folder
函数从 GraphModule
¶
GraphModule.to_folder()
是 GraphModule
中的一个方法,允许您将生成的 FX 代码导出到文件夹。尽管将前向传递复制到代码中通常就足够了,如在“打印生成的代码”中,但使用 to_folder
可能更容易检查模块和参数。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
在运行上述示例之后,我们就可以查看 foo/module.py
中的代码并根据需要对其进行修改(例如添加 print
语句或使用 pdb
)以调试生成的代码。
调试转换 ¶
现在我们已经确定一个转换正在生成错误的代码,现在是时候调试这个转换本身了。首先,我们将检查文档中的“符号跟踪的限制”部分。一旦我们验证跟踪是否按预期工作,目标就变成了找出我们的 GraphModule
转换中出了什么问题。在《编写转换》中可能有一个快速的答案,但如果没有,我们有几种方法可以检查我们的跟踪模块:
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的实用函数,我们可以比较我们在应用转换前后跟踪的模块。有时,简单的视觉比较就足以追踪到错误。如果仍然不清楚出了什么问题,使用 pdb
这样的调试器可以是一个好的下一步。
根据上面的例子,考虑以下代码:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假设 print(traced)
的调用显示我们的转换中存在错误。我们想通过调试器找出错误所在。我们启动一个 pdb
会话。我们可以通过在 transform_graph(traced)
处中断,然后按 s
“进入” transform_graph(traced)
的调用来查看转换过程中的情况。
我们也可以通过编辑 print_tabular
方法来打印图中节点不同的属性。(例如,我们可能想看到节点的 input_nodes
和 users
。)
可用调试器 §
最常用的 Python 调试器是 pdb。您可以通过在命令行中输入 python -m pdb FILENAME.py
以“调试模式”启动您的程序,其中 FILENAME
是要调试的文件名。之后,您可以使用 pdb
调试器命令逐步执行您的程序。通常,在开始 pdb
时设置一个断点( b LINE-NUMBER
),然后调用 c
运行程序直到该点。这可以防止您必须逐行执行(使用 s
或 n
)才能到达想要检查的代码部分。或者,您可以在想要中断的行之前写入 import pdb; pdb.set_trace()
。如果添加 pdb.set_trace()
,则您的程序将在运行时自动进入调试模式。(换句话说,您可以直接在命令行中输入 python FILENAME.py
而不是 python -m pdb FILENAME.py
。)一旦您以调试模式运行文件,您就可以使用某些命令逐步执行代码并检查程序的内部状态。网上有许多关于 pdb
的优秀教程,包括 RealPython 的“使用 pdb 进行 Python 调试”。
PyCharm 或 VSCode 等集成开发环境通常内置了调试器。在您的 IDE 中,您可以选择以下方式之一:a) 通过在 IDE 中打开终端窗口(例如 VSCode 中的“视图→终端”)来使用 pdb
;b) 使用内置的调试器(通常是一个围绕 pdb
的图形包装器)。
符号跟踪的限制
FX 使用符号跟踪(也称为符号执行)系统来以可转换/可分析的形式捕获程序的语义。该系统是跟踪的,因为它执行程序(实际上是 torch.nn.Module
或函数)以记录操作。它是符号的,因为在执行过程中通过程序的数据不是真实数据,而是符号(在 FX 术语中为 Proxy
)。
虽然符号跟踪对于大多数神经网络代码都有效,但它有一些局限性。
动态控制流
符号跟踪的主要局限性是它目前不支持动态控制流。也就是说,循环或 if
语句的条件可能依赖于程序输入的值。
例如,让我们分析以下程序:
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
语句的条件依赖于 x.sum()
的值,而 x.sum()
的值又依赖于 x
的值,这是一个函数输入。由于 x
可以改变(即如果你向跟踪函数传递新的输入张量),这就是动态控制流。回溯会沿着你的代码向上走,以显示这种情况发生的位置。
静态控制流 ¶
另一方面,所谓的静态控制流得到了支持。静态控制流是指值在调用之间不能改变的循环或 if
语句。通常,在 PyTorch 程序中,这种控制流出现在根据超参数做出决策的代码中。作为一个具体的例子:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if-语句 if self.do_activation
不依赖于任何函数输入,因此它是静态的。 do_activation
可以被视为一个超参数,而 MyModule
的不同实例的痕迹,该参数有不同的值,有不同的代码。这是一个有效的模式,也是符号跟踪所支持的。
许多动态控制流的实例在语义上是静态控制流。通过移除对输入值的依赖,例如通过将值移动到 Module
属性或绑定具体值到符号跟踪期间的参数,可以将这些实例转换为支持符号跟踪:
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
对于真正的动态控制流,包含此代码的程序部分可以被视为对方法(请参阅使用 Tracer 类自定义跟踪)或函数(请参阅 wrap()
)的调用,而不是通过它们进行跟踪。
非函数 torch
FX 使用 __torch_function__
作为拦截调用的机制(有关此内容的更多信息,请参阅技术概述)。一些函数,例如内置的 Python 函数或 math
模块中的函数,不受 __torch_function__
的保护,但我们仍然希望将它们捕获在符号跟踪中。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误提示我们内置函数 len
不受支持。我们可以使此类函数在跟踪中以直接调用方式记录,使用 wrap()
API:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer
类定制跟踪 ¶
Tracer
类是实现 symbolic_trace
的底层类。可以通过子类化 Tracer 来自定义跟踪的行为,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶模块
叶模块是出现在符号跟踪中的调用模块,而不是通过跟踪实现的模块。默认的叶模块集是标准 torch.nn
模块实例的集合。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过覆盖 Tracer.is_leaf_module()
来自定义叶模块集。
杂项 ¶
索引构造函数(例如
torch.zeros
,torch.ones
,torch.rand
,torch.randn
,torch.sparse_coo_tensor
)目前无法追踪。确定性构造函数(
zeros
,ones
)可以使用,它们产生的值将被嵌入到追踪中作为常量。这只有在这些构造函数的参数引用动态输入大小时才会成为问题。在这种情况下,ones_like
或zeros_like
可能是可行的替代品。非确定性构造函数(
rand
,randn
)将在追踪中嵌入单个随机值。这很可能不是预期的行为。一种解决方案是将torch.randn
包裹在torch.fx.wrap
函数中,并调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
该行为可能在未来的版本中得到修复。
类型注解
支持 Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor
),并且符号跟踪将保留这些注解。目前不支持 Python 2 风格的注释类型注解
# type: (torch.Tensor, int) -> torch.Tensor
。函数内部对局部名称的注释目前不支持。
关于
training
标志和子模块的注意事项。当使用
torch.nn.functional.dropout
等函数时,训练参数通常会被传递为self.training
。在 FX 追踪期间,这很可能会被固化为一个常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
然而,当使用标准
nn.Dropout()
子模块时,训练标志被封装,并且由于保留了nn.Module
对象模型,可以更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由于这种差异,请考虑将动态交互
training
标志的模块标记为叶模块。
API 参考指南
- torch.fx.symbolic_trace(root, concrete_args=None)[source][source] 参考指南
符号追踪 API
给定一个
nn.Module
或函数实例root
,此函数将返回一个通过记录在root
中看到的操作所构建的GraphModule
。concrete_args
允许您部分专门化您的函数,无论是要移除控制流还是数据结构。例如:
def f(a, b): if b == True: return a else: return a * 2
由于存在控制流,FX 通常无法追踪此操作。但是,我们可以使用 concrete_args 来专门化 b 的值以追踪此操作:
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
注意,尽管您仍然可以传入不同的 b 值,但它们将被忽略。
我们还可以使用 concrete_args 来消除函数中的数据结构处理。这将使用 pytrees 来简化您的输入。为了避免过度专业化,对于不应专业化的值,请传入 fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 参数:
root (Union[torch.nn.Module, Callable]) – 要追踪和转换为图表示的模块或函数。
concrete_args (Optional[Dict[str, any]]) – 部分专业化的输入
- 返回:
由
root
记录的操作创建的模块- 返回类型:
注意
本 API 向后兼容性得到保证。
- torch.fx.wrap(fn_or_name)[source][source]¶
此函数可以在模块级作用域中调用,将 fn_or_name 注册为“叶子函数”。一个“叶子函数”将被保留为 FX 跟踪中的 CallFunction 节点,而不是被跟踪:
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以等效地用作装饰器:
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
被包装的函数可以被视为“叶子函数”,类似于“叶子模块”的概念,即它们是保留在 FX 跟踪中的调用,而不是被跟踪的函数。
- 参数:
fn_or_name (Union[str, Callable]) – 当调用时,将函数或全局函数的名称插入到图中的函数
注意
本 API 保证向后兼容性。
- class torch.fx.GraphModule(*args, **kwargs)[source][source]¶
GraphModule 是由 fx.Graph 生成的 nn.Module。Graphmodule 具有从
graph
属性生成的code
以及forward
和graph
属性。警告
当
graph
被重新赋值时,code
和forward
将会自动重新生成。然而,如果您编辑了graph
的内容而没有重新赋值graph
属性本身,您必须调用recompile()
来更新生成的代码。注意
此 API 保证向后兼容。
- __init__(root, graph, class_name='GraphModule')[source][source]¶
构建一个 GraphModule。
- 参数:
root (Union[torch.nn.Module, Dict[str, Any]) –
root
可以是 nn.Module 实例或映射字符串到任何属性类型的 Dict。如果root
是 Module,则 Graph 的 Nodes 的target
字段中通过限定名称引用的基于 Module 的对象将被从root
的 Module 层级复制到 GraphModule 的模块层级。如果root
是 dict,则 Node 的target
中找到的限定名称将直接在 dict 的键中查找。映射到 Dict 的对象将被复制到 GraphModule 的模块层级中相应的位置。图(图)-
graph
包含此 GraphModule 应使用的节点以进行代码生成class_name(字符串)-
name
表示此 GraphModule 的调试名称。如果未设置,所有错误消息都将报告为来自GraphModule
。将其设置为root
的原始名称或适合您转换上下文的名称可能很有帮助。
注意
此 API 向后兼容性得到保证。
- add_submodule(target, m)[源代码][源代码] ¶
将给定的子模块添加到
self
。如果不存在,则安装空的模块,这些模块是
target
的子路径。- 参数:
target (str) – 新子模块的完全限定字符串名称(参见
nn.Module.get_submodule
中的示例,了解如何指定完全限定字符串。)m (Module) – 子模块本身;我们想要安装到当前模块的实际对象
- 返回:
- 子模块是否可以插入。
此方法返回 True,链中的每个对象(由target
表示)必须满足以下条件之一:a)尚不存在,或 b)引用nn.Module
(不是参数或其他属性)
- 返回类型:
注意
此 API 的后向兼容性得到保证。
- 属性 codestr ¶
返回由
Graph
生成的 Python 代码。
- delete_all_unused_submodules()[source][source]¶
从
self
中删除所有未使用的子模块。模块被认为“被使用”,如果以下任何一个条件成立:1. 它有被使用的子模块 2. 它的前向操作通过一个
call_module
节点直接调用 3. 它有一个非模块属性,该属性从一个get_attr
节点被使用可以调用此方法来清理
nn.Module
,而无需手动对每个未使用的子模块调用delete_submodule
。注意
此 API 向后兼容性得到保证。
- delete_submodule(target)[source][source]¶
从
self
中删除指定的子模块。模块在
target
不是一个有效目标时不会被删除。- 参数:
目标(字符串)- 新子模块的完全限定字符串名称(参见
nn.Module.get_submodule
中的示例,了解如何指定完全限定字符串。)- 返回:
- 无论目标字符串是否引用了
子模块我们要删除。返回值False
表示target
不是一个有效的子模块引用。
- 返回类型:
注意
此 API 保证向后兼容。
- 属性图 Graph ¶
返回此
Graph
的底层GraphModule
- 打印可读输出(print_output=True,include_stride=False,include_device=False,colored=False)[source][source] ¶
返回当前 GraphModule 及其子 GraphModule 生成的 Python 代码
警告
此 API 为实验性,且不向后兼容。
- class torch.fx.Graph(拥有模块=None, 跟踪器类=None, 跟踪器额外参数=None)[source][source] ¶
Graph
是 FX 中间表示法中使用的最主要的数据结构。它由一系列Node
组成,每个Node
代表一个调用点(或其他语法结构)。Node
的列表,合在一起,构成一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
将产生以下 Graph:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1
关于在
Graph
中表示的操作的语义,请参阅Node
。注意
本 API 保证向下兼容。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]¶
构建一个空图。
注意
此 API 保证向后兼容。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[source][source]¶
在
Graph
中插入call_function
Node
。一个call_function
节点表示对 Python 可调用对象的调用,该对象由the_function
指定。- 参数:
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 操作符、Python 函数或
builtins
或operator
命名空间中的成员。args (可选[Tuple[Argument, ...]]) – 要传递给被调用函数的位置参数。
kwargs (可选[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数
type_expr (可选[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。
- 返回:
新创建并插入的
call_function
节点。- 返回类型:
注意
此方法的插入点类型表达式规则与
Graph.create_node()
相同。注意
此 API 向后兼容性得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source][source]¶
在
Graph
中插入call_method
Node
。一个call_method
节点表示对args
的第 0 个元素调用给定方法。- 参数:
method_name (str) – 要应用到 self 参数上的方法名称。例如,如果 args[0]是一个表示
Tensor
的Node
,那么要调用Tensor
上的relu()
,需要将relu
传递给method_name
。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的定位参数。注意,这应该包括一个
self
参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的键值参数
type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。
- 返回:
新创建并插入的
call_method
节点。- 返回类型:
注意
此方法的插入点和类型表达式规则与
Graph.create_node()
相同。注意
此 API 保证向后兼容。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source][source]¶
在
Graph
中插入call_module
Node
。一个call_module
节点代表对Module
层次结构中Module
的 forward()函数的调用。- 参数:
module_name (str) – 要调用的
Module
在Module
层次结构中的限定名称。例如,如果被跟踪的Module
有一个名为foo
的子模块,该子模块有一个名为bar
的子模块,则应将限定名称foo.bar
作为module_name
传递以调用该模块。args(可选[Tuple[Argument, ...]])- 要传递给调用方法的定位参数。请注意,此参数不应包括
self
参数。kwargs(可选[Dict[str, Argument]])- 要传递给调用方法的键值参数
type_expr(可选[Any])- 表示此节点输出将具有的 Python 类型的可选类型注解。
- 返回:
新创建并插入的
call_module
节点。- 返回类型:
注意
此方法的插入点类型表达式规则与
Graph.create_node()
相同。注意
此 API 向后兼容性得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source][source]¶
创建一个
Node
并将其添加到当前插入点处的Graph
。请注意,当前插入点可以通过Graph.inserting_before()
和Graph.inserting_after()
设置。- 参数:
op (str) – 此节点的操作码。可以是 'call_function'、'call_method'、'get_attr'、'call_module'、'placeholder' 或 'output' 之一。这些操作码的语义在
Graph
文档字符串中描述。args (Optional[Tuple[Argument, ...]]) – 是一个包含此节点参数的元组。
kwargs (Optional[Dict[str, Argument]]) – 此节点的 kwargs
可选的字符串名称(Optional[str]) - 为
Node
. 分配的值提供可选的字符串名称。这将影响 Python 生成的代码中分配的值名称。类型表达式(Optional[Any]) - 表示此节点输出将具有的 Python 类型的可选类型注解。
- 返回:
新创建并插入的节点。
- 返回类型:
注意
此 API 保证向后兼容。
- eliminate_dead_code(is_impure_node=None)[source][source]¶
根据每个节点的用户数量以及节点是否有副作用,从图中删除所有死代码。调用之前必须对图进行拓扑排序。
- 参数:
is_impure_node (Optional[Callable[[Node], bool]]) – 一个返回
(节点是否不纯。如果是) –
(那么默认行为是) –
(使用)Node.is_impure. –
- 返回:
(该图是否因为该遍历而改变。)
- 返回类型:
示例:
在消除死代码之前,一个如下的 a = x + 1 没有用户,因此可以从图中删除而不会产生影响。
def forward(self, x): a = x + 1 return x + self.attr_1
在消除死代码之后,a = x + 1 已被删除,其余的前向操作仍然保留。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法来避免删除有副作用的节点(见 Node.is_impure),但总体覆盖率非常差,因此你应该假设除非你知道你的 FX 图完全由函数操作组成,或者你提供了自己的自定义函数来检测有副作用的节点,否则这种方法是不可靠的。
注意
本 API 的向后兼容性得到保证。
- erase_node(to_erase)[源代码][源代码]
从
Graph
中删除Node
。如果该节点在Graph
中仍有用户,则抛出异常。- 参数:
to_erase (节点) – 从
Graph
中删除的Node
。
注意
本 API 向后兼容性有保证。
- find_nodes(*, op, target=None, sort=True)[source][source]¶
允许快速查询节点
- 参数:
op (str) – 操作名称
target (Optional[Target]) – 节点的目标。对于 call_function,目标为必需。对于其他操作,目标为可选。
是否按在图上出现的顺序返回节点。
- 返回:
具有请求操作和目标的节点可迭代。
警告
此 API 为实验性,且不向后兼容。
- get_attr(qualified_name, type_expr=None)[source][source]¶
在图中插入一个
get_attr
节点。Aget_attr
Node
表示从Module
层次中获取属性。- 参数:
完整名称(str)- 要检索的属性的完整名称。例如,如果跟踪的模块有一个名为
foo
的子模块,该子模块有一个名为bar
的子模块,该子模块有一个名为baz
的属性,则应将合格名称foo.bar.baz
传递为qualified_name
。type_expr(可选[Any])- 表示此节点输出将具有的 Python 类型的可选类型注解。
- 返回:
新创建并插入的
get_attr
节点。- 返回类型:
注意
此方法与
Graph.create_node
的插入点及类型表达式规则相同。注意
此 API 向后兼容性得到保证。
- graph_copy(g, val_map, return_output_node=False)[source][source]¶
将给定图中的所有节点复制到
self
。- 参数:
g(图)- 从中复制节点的源图。
val_map(Dict[Node, Node])- 一个字典,将填充从
g
到self
的节点映射。注意,如果val_map
已经包含值,则可以传入以覆盖某些值的复制。
- 返回:
如果
g
有一个output
节点,则self
中的值现在等同于g
的输出值。否则为None
。- 返回类型:
Optional[Union[ tuple[Union[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[ str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], …], Sequence[Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[ str, Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 向后兼容性得到保证。
- inserting_after(n=None)[source][source]¶
- 设置 create_node 和伴随方法将插入到图中的点。
当在“with”语句中使用时,这将临时设置插入点,并在退出“with”语句时恢复:with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数:
- n (可选[节点]): 要插入其前的节点。如果为 None,则将在整个图的开始之后插入。
的位置。
- 返回:
一个将恢复插入点在__exit__
的资源管理器。
注意
本 API 向后兼容性得到保证。
- inserting_before(n=None)[source][source]¶
- 设置 create_node 和伴随方法将插入到图中的点。
当在“with”语句中使用时,这将临时设置插入点,并在退出“with”语句时恢复:with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数:
- n (可选[节点]): 要插入其前的节点。如果为 None,则将在整个图的开始处插入。
的位置。
- 返回:
一个将恢复__exit__
.插入点的资源管理器。
注意
本 API 向后兼容性得到保证。
- lint()[源码][源码] ¶
对此图进行各种检查以确保其格式正确。特别是:- 检查节点拥有正确的所有权(由此图拥有)- 检查节点按拓扑顺序出现- 如果此图有一个拥有 GraphModule,检查该 GraphModule 中存在目标
注意
此 API 的向后兼容性得到保证。
- node_copy(node, arg_transform=>)[源码][源码] ¶
将一个节点从一张图复制到另一张图。
arg_transform
需要将节点所在图的参数转换为自身图的参数。示例:# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 参数:
节点(Node)- 要复制的节点
self
。arg_transform (Callable[[Node], Argument]) - 一个函数,用于将节点
args
和kwargs
中的参数转换为self
中的等效参数。在简单情况下,这应该从将原始图中的节点映射到self
的表中检索值。
- 返回类型:
注意
本 API 向后兼容性得到保证。
- 属性节点_node_list _
获取构成此图的节点列表。
注意,此
Node
列表表示形式是双向链表。迭代期间的变异(例如删除节点、添加节点)是安全的。- 返回:
双向链表的节点。注意,可以通过
reversed
来切换此列表的迭代顺序。
- on_generate_code(make_transformer)[source][source]
在生成 Python 代码时注册转换器函数
- 参数:
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。
此函数还接收当前已注册的代码转换器(如果没有注册则为 None),以防不希望覆盖它。这对于将代码转换器链接在一起很有用。
- 返回:
一个上下文管理器,当在 with 语句中使用时,会自动恢复之前注册的代码转换器。
示例:
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) gm.recompile() gm(*inputs) # drops into pdb
此功能也可以用作上下文管理器,具有自动恢复先前注册的代码转换器的优势:
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 为实验性,且不向后兼容。
- output(result, type_expr=None)[source][source]¶
在
Graph
中插入output
Node
。一个output
节点代表 Python 代码中的一个return
语句。result
是应该返回的值。- 参数:
result (参数) – 要返回的值。
type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。
注意
此方法与
Graph.create_node
适用于相同的插入点和类型表达式规则。注意
本 API 保证向后兼容性。
- placeholder(名称, 类型表达式=None, 默认值)[源代码][源代码] ¶
在图中插入一个
placeholder
节点。placeholder
代表函数输入。- 参数:
名称 (str) – 输入值的名称。这对应于该
Graph
代表的函数的位置参数名称。type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。在某些情况下,这对于正确的代码生成是必需的(例如,当函数在 TorchScript 编译中随后使用时)。
default_value (Any) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 传递给此参数,以指定该参数没有默认值。
- 返回类型:
注意
此方法与
Graph.create_node
一样适用相同的插入点和类型表达式规则。注意
本 API 保证向下兼容。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source][source]¶
将这个
Graph
转换为有效的 Python 代码。- 参数:
root_module(字符串)- 要查找合格名称目标的根模块名称。这通常是‘self’。
- 返回:
src:表示对象的 Python 源代码 globals:src 中的全局名称的字典 -> 它们引用的对象。
- 返回类型:
一个 PythonCode 对象,包含两个字段
注意
此 API 保证向后兼容。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source][source]¶
Node
是表示Graph
中单个操作的抽象数据类型。大部分情况下,节点代表对各种实体(如算子、方法、模块等)的调用点(一些例外包括指定函数输入和输出的节点)。每个节点都有一个由其op
属性指定的函数。Node
的语义如下:placeholder
代表函数输入。name
属性指定该值将采用的名称。target
类似地是参数的名称。args
包含以下两种情况之一:1)无内容,或 2)一个表示函数输入默认参数的单个参数。kwargs
是无关紧要的。占位符对应于图打印中的函数参数(例如x
)。get_attr
从模块层次结构中检索参数。name
类似地是分配给检索结果的名称。target
是参数在模块层次结构中的完全限定名称。args
和kwargs
是无关紧要的。call_function
将一个自由函数应用于一些值。name
类似地是分配给值的名称。target
是要应用的功能。args
和kwargs
代表函数的参数,遵循 Python 调用约定。call_module
将模块层次结构中的forward()
方法应用于给定的参数。name
如前所述。target
是要调用的模块在模块层次结构中的完全限定名称。args
和kwargs
代表调用模块时要传递的参数,不包括 self 参数。call_method
调用值上的方法。name
与此类似。target
是要应用到self
参数上的方法的字符串名称。args
和kwargs
代表调用模块时的参数,包括 self 参数output
包含被跟踪函数的输出,存储在其args[0]
属性中。这对应于图打印输出中的“return”语句。
注意
本 API 向后兼容性得到保证。
- 属性 all_input_nodeslist[torch.fx.node.Node] ¶
返回所有是该节点输入的节点。这相当于遍历
args
和kwargs
,并仅收集值为节点的值。- 返回:
列出在
args
和kwargs
中出现的Nodes
,顺序如下。
- append(x)[source][source]¶
在图中节点列表中在此节点之后插入
x
。相当于self.next.prepend(x)
- 参数:
x(节点)- 要放在此节点之后的节点。必须是同一图中的成员。
注意
本 API 向后兼容性得到保证。
- 属性 argstuple[typing.Union[tuple[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...], collections.abc.Sequence[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]], collections.abc.Mapping[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]] torch.SymFloat, NoneType]], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...] ¶
此函数的参数元组。参数的解释取决于节点的操作码。有关更多信息,请参阅
Node
的文档字符串。允许对此属性进行赋值。在赋值时,所有使用情况和用户的会计信息将自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source][source]¶
返回对
self
的描述性字符串表示。此方法可以无参数使用,作为调试工具。
此函数还用于
Graph
的__str__
方法内部。placeholder_names
和maybe_return_typename
中的字符串共同构成了此 Graph 周围 GraphModule 中自动生成的forward
函数的签名。placeholder_names
和maybe_return_typename
不应在其他情况下使用。- 参数:
placeholder_names(可选[字符串列表])- 一个列表,将存储表示生成的
forward
函数中占位符的格式化字符串。仅限内部使用。maybe_return_typename (Optional[list[str]]) – 存储生成
forward
函数输出的格式化字符串的单元素列表。仅内部使用。
- 返回:
- 如果 1) 我们使用
format_node
作为内部辅助函数
在Graph
的__str__
方法中,并且 2)self
是一个占位符节点,则返回None
。否则,返回当前节点的描述性字符串表示。
- 如果 1) 我们使用
- 返回类型:
注意
本 API 的向后兼容性得到保证。
- insert_arg(idx, arg)[source][source]¶
在指定索引处向参数列表插入一个位置参数。
- 参数:
idx (int) – 要插入到
self.args
之前的位置元素的索引。arg (Argument) – 要插入到
args
中的新参数值。
注意
此 API 保证向后兼容。
- is_impure()[source][source]¶
返回此操作是否为不纯,即其操作是否为占位符或输出,或者是否为不纯的 call_function 或 call_module。
- 返回:
判断操作是否为不纯。
- 返回类型:
警告
此 API 为实验性,且不向后兼容。
- 属性 kwargsdict[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...], collections.abc.Sequence[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]], collections.abc.Mapping[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]] torch.SymFloat, NoneType]], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]] ¶
该函数的关键字参数字典。参数的解释取决于节点的操作码。请参阅
Node
的文档字符串以获取更多信息。允许对此属性进行赋值。在赋值时,所有使用情况和用户的会计信息将自动更新。
- 属性 nextNode
返回节点链表中的下一个
Node
。- 返回:
节点链表中的下一个
Node
。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source][source]¶
返回用于 Python 目标的标准化参数。这意味着 args/kwargs 将与模块/函数的签名匹配,如果 normalize_to_only_use_kwargs 为 true,则仅按位置顺序返回 kwargs。同时填充默认值。不支持位置只写参数或可变参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 来区分重载。
- 参数:
root (torch.nn.Module) – 要解析模块目标的模块。
arg_types (可选[Tuple[Any]]) – 参数类型的元组。
kwarg_types (Optional[Dict[str, Any]]) – 参数类型字典
normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用 kwargs。
- 返回:
返回 NamedTuple ArgsKwargsPair,或未成功时返回 None。
- 返回类型:
Optional[ArgsKwargsPair]
警告
此 API 为实验性,且不向后兼容。
- prepend(x)[来源][来源] ¶
在图中的节点列表中在此节点之前插入 x。例如:
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数:
x(节点)- 要插入此节点之前的节点。必须是同一图中的成员。
注意
本 API 兼容旧版本有保证。
- 属性 prevNode ¶
返回节点链表中的前一个
Node
- 返回:
节点链表中的前一个
Node
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source][source]¶
在 Graph 中将所有使用
self
的地方替换为 Nodereplace_with
。- 参数:
替换为(节点)- 要替换所有使用
self
的节点。delete_user_cb(可调用)- 当确定是否应该从 self 节点移除给定用户时被调用的回调。
propagate_meta(布尔值)- 是否将原始节点的.meta 字段上的所有属性复制到替换节点上。出于安全考虑,仅在替换节点尚未具有现有的.meta 字段时才有效。
- 返回:
已更改此更改的节点列表。
- 返回类型:
list[torch.fx.node.Node]
注意
此 API 向后兼容性得到保证。
- replace_input_with(old_input, new_input)[source][source]
遍历输入节点
self
,并将所有old_input
替换为new_input
。- 参数:
旧输入(节点)- 要替换的旧输入节点。
新输入(节点)- 替换
old_input
的新输入节点。
注意
本 API 向后兼容性得到保证。
- 属性 stack_traceOptional[str] ¶
返回在跟踪期间记录的 Python 调用栈,如果有的话。当使用 fx.Tracer 跟踪时,此属性通常由 Tracer.create_proxy 填充。为了在跟踪期间记录调用栈以进行调试,请在 Tracer 实例上设置 record_stack_traces = True。当使用 dynamo 跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。
stack_trace 将在字符串末尾具有最内层的帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source][source]¶
Tracer
是实现torch.fx.symbolic_trace
符号跟踪功能的类。对symbolic_trace(m)
的调用等同于Tracer().trace(m)
。Tracer 可以被继承以覆盖跟踪过程的多种行为。可以覆盖的不同行为在类上方法的文档字符串中描述。
注意
此 API 保证向后兼容。
- call_module(m, forward, args, kwargs)[source][source]¶
指定此
Tracer
在遇到对nn.Module
实例的调用时的行为的方法。默认情况下,行为是检查被调用的模块是否是叶子模块通过
is_leaf_module
。如果是,则在Graph
中引用m
发射一个call_module
节点。否则,正常调用Module
,通过其forward
函数中的操作进行跟踪。此方法可以被重写,例如创建嵌套的已跟踪 GraphModules,或者跨
Module
边界时你想要的任何其他行为。- 参数:
m(模块)- 正在发出调用的模块
forward(可调用)- 要调用的
Module
的 forward()方法args(元组)- 模块调用位置的参数
kwargs (字典) – 模块调用位置的 kwargs
- 返回:
模块的返回值。如果生成了
call_module
节点,则此为Proxy
值。否则,就是Module
调用返回的任何值。- 返回类型:
注意
本 API 向后兼容性得到保证。
- create_arg(a)[source][source]
一种指定在准备用作节点参数的值时的跟踪行为的方法。
默认情况下,行为包括:
遍历集合类型(例如元组、列表、字典)并对元素递归调用
create_args
。给定一个代理对象,返回底层 IR 的引用
Node
。对于非代理张量对象,针对各种情况发出 IR
对于参数,发出指向该参数的
get_attr
节点对于非参数张量,将该张量存储在特殊属性中,以引用该属性
此方法可以被重写以支持更多类型
- 参数:
一个(任何)- 要作为
Graph
发出的值Argument
。- 返回:
将
a
转换为适当的Argument
的值。- 返回类型:
参数
注意
本 API 向后兼容性得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source][source]¶
创建对应于
root
模块签名的placeholder
节点。此方法会检查 root 的签名并相应地发出这些节点,同时支持*args
和**kwargs
。警告
此 API 为实验性,且不向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]¶
插入一个图节点,给定目标、参数、关键字参数和名称。
此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录原地操作。
注意
此 API 向后兼容性得到保证。
- 返回类型:
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]¶
从给定的参数创建一个节点,然后返回一个包装在代理对象中的节点。
如果 kind = ‘placeholder’,则我们正在创建一个表示函数参数的节点。如果需要编码默认参数,我们使用
args
元组。对于placeholder
节点,args
为空。注意
本 API 向后兼容性得到保证。
- get_fresh_qualname(prefix)[source][source]¶
获取一个前缀的新名称并返回它。此函数确保它不会与图上的现有属性冲突。
注意
本 API 向后兼容性得到保证。
- 返回类型:
- getattr(attr, attr_val, parameter_proxy_cache)[source][source]¶
指定在调用
nn.Module
实例的 getattr 时此Tracer
的行为的方法。默认情况下,行为是返回属性的代理值。它还会将代理值存储在
parameter_proxy_cache
中,以便未来的调用将重用代理而不是创建一个新的。此方法可以被覆盖,例如,在查询参数时不返回代理。
- 参数:
查询的属性名称(字符串)
属性值(任意类型)
parameter_proxy_cache(字典[str, 任意类型])- 属性名称到代理的缓存
- 返回:
getattr 调用返回值
警告
此 API 为实验性,且不向后兼容。
- is_leaf_module(m, module_qualified_name)[source][source]¶
一个用于指定给定的
nn.Module
是否为“叶子”模块的方法。叶子模块是出现在 IR 中的原子单元,通过
call_module
调用进行引用。默认情况下,PyTorch 标准库命名空间(torch.nn)中的模块都是叶子模块。除非通过此参数指定,否则所有其他模块都会进行跟踪,并记录其构成的操作。- 参数:
m (模块) – 正在被查询的模块
模块限定名称(字符串)- 此模块根路径。例如,如果您有一个模块层次结构,其中子模块
foo
包含子模块bar
,该子模块又包含子模块baz
,则该模块将在此处以限定名称foo.bar.baz
出现。
- 返回类型:
注意
此 API 向后兼容性得到保证。
- iter(obj)[source]
- 当代理对象正在被迭代时调用,例如
当用于控制流时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。
注意
此 API 的向后兼容性得到保证。
- 返回类型:
- keys(obj)[source]
- 当代理对象调用 keys()方法时触发。
当在代理上调用**时会发生这种情况。这应该返回一个迭代器 it**,在你的自定义跟踪器中应该正常工作。
注意
本 API 向后兼容性得到保证。
- 返回类型:
- path_of_module(mod)[source][source]¶
在模块层次结构中查找
mod
的限定名称的辅助方法。例如,如果root
有一个子模块名为foo
,而foo
又有子模块名为bar
,将bar
传递给此函数将返回字符串“foo.bar”。- 参数:
mod (str) – 获取限定名称的
Module
。- 返回类型:
注意
本 API 向后兼容性得到保证。
- proxy(node)[source]
注意
此 API 保证向后兼容。
- 返回类型:
- to_bool(obj)[源码] ¶
- 当代理对象被转换为布尔值时调用,例如
当用于控制流时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
本 API 向后兼容性得到保证。
- 返回类型:
- trace(root, concrete_args=None)[source][source]¶
跟踪
root
并返回相应的 FXGraph
表示。root
可以是一个nn.Module
实例或 Python 可调用对象。注意,在这次调用之后,
self.root
可能与传入此处的root
不同。例如,当将自由函数传递给trace()
时,我们将创建一个nn.Module
实例作为根实例,并添加嵌入的常量。- 参数:
root (Union[Module, Callable]) – 可以是
Module
或要追踪的函数。此参数向后兼容性得到保证。concrete_args (Optional[Dict[str, any]]) – 应视为非代理的具体参数。此参数为实验性,其向后兼容性不保证。
- 返回:
代表传入的
root
的语义的Graph
。- 返回类型:
注意
本 API 保证向后兼容。
- class torch.fx.Proxy(node, tracer=None)[source][source]¶
Proxy
对象是Node
包装器,在符号跟踪期间通过程序流动并记录它们接触到的所有操作(torch
函数调用、方法调用、运算符)到不断增长的 FX 图。如果您正在进行图转换,您可以将自己的
Proxy
方法包装在原始的Node
之上,这样您就可以使用重载的运算符向Graph
添加额外的内容。对象无法迭代。换句话说,如果在一个循环或作为
Proxy
/*args
/**kwargs
函数参数中使用Proxy
,符号追踪器将抛出错误。有两种主要的解决方案:1. 将不可追踪的逻辑提取到顶级函数中,并在其上使用
fx.wrap
。2. 如果控制流是静态的(即循环迭代次数基于某些超参数),则可以将代码保留在原始位置,并重构为类似以下内容:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
想要更详细地了解代理内部机制,请查看 torch/fx/README.md 中的“Proxy”部分。
注意
本 API 向后兼容性得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source][source]¶
解释器逐节点执行 FX 图。这种模式对于许多事情都很有用,包括编写代码转换和分析过程。
可以覆盖解释器类中的方法来自定义执行行为。可覆盖方法的调用层次映射:
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想要交换所有
torch.neg
的实例与torch.sigmoid
,反之亦然(包括它们的Tensor
方法等效)。我们可以这样子类化解释器:class NegSigmSwapInterpreter(Interpreter): def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数:
要执行的模块(torch.nn.Module)
garbage_collect_values(布尔值)- 是否在模块执行过程中删除其最后使用后的值。这确保了执行过程中的最佳内存使用。可以通过查看
Interpreter.env
属性来禁用此功能,以检查执行过程中的所有中间值。graph(可选[Graph])- 如果传入,解释器将使用提供的模块参数执行此图,而不是 module.graph,以满足对状态的任何请求。
注意
本 API 向后兼容性得到保证。
- boxed_run(args_list)[source][source]¶
通过解释执行模块并返回结果。这使用“boxed”调用约定,其中传递一个参数列表,该列表将被解释器清除。这确保了输入张量能够及时释放。
注意
本 API 向后兼容性得到保证。
- call_function(target, args, kwargs)[source][source]¶
执行一个
call_function
节点并返回结果。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。
args(元组)- 此调用的位置参数元组。
kwargs(字典)- 此调用的关键字参数字典。
- 返回类型:
- 返回
Any: 函数调用的返回值
注意
本 API 向后兼容性得到保证。
- call_method(target, args, kwargs)[source][source]¶
执行一个
call_method
节点并返回结果。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。
args(元组)- 此调用的位置参数元组。
kwargs(字典)- 此调用的关键字参数字典。
- 返回类型:
- 返回
Any: 方法调用的返回值
注意
本 API 向后兼容性得到保证。
- call_module(target, args, kwargs)[source][source]¶
执行一个
call_module
节点并返回结果。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。
args(元组)- 此调用的位置参数元组。
kwargs(字典)- 此调用的关键字参数字典。
- 返回类型:
- 返回
Any: 模块调用返回的值
注意
本 API 向后兼容性得到保证。
- fetch_args_kwargs_from_env(n)[source][source]¶
从当前执行环境中获取节点
n
的args
和kwargs
的具体值。- 参数:
n (节点) – 需要获取
args
和kwargs
的节点@2#。- 返回:
args
和kwargs
的具体值为n
。- 返回类型:
Tuple[Tuple, Dict]
注意
此 API 保证向后兼容。
- fetch_attr(target)[source][source]¶
从
self.module
的Module
层次结构中获取属性。- 参数:
target (str) – 要获取的属性的完全限定名称
- 返回:
属性的值。
- 返回类型:
任何
注意
本 API 向后兼容性得到保证。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr
节点。将从Module
层级的self.module
中检索属性值。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。
args(元组)- 此调用位置参数的元组。
kwargs(字典)- 此调用关键字参数的字典。
- 返回:
所检索到的属性值
- 返回类型:
任何
注意
此 API 保证向下兼容。
- 将节点映射到值(args, n)[source][source] ¶
递归遍历
args
并在当前执行环境中查找每个Node
的具体值。- 参数:
args(参数)- 用于查找具体值的内部数据结构。
n(节点)-
args
所属的节点。这仅用于错误报告。
- 返回类型:
Optional[Union[ tuple[Union[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[ str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], …], Sequence[Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[ str, Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 向后兼容性得到保证。
- output(target, args, kwargs)[source][source]¶
执行一个
output
节点。这实际上只是检索由output
节点引用的值并返回它。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。
args(元组)- 此调用位置参数的元组。
kwargs(字典)- 此调用关键字参数的字典。
- 返回:
输出节点引用的返回值
- 返回类型:
任何
注意
本 API 向后兼容性得到保证。
- placeholder(target, args, kwargs)[源代码][源代码] ¶
执行一个
placeholder
节点。请注意,这是有状态的:Interpreter
维护一个对传递给run
的参数的内部迭代器,此方法返回该迭代器的 next()。- 参数:
目标(Target)- 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args(元组)- 此调用中位置参数的元组
kwargs(字典)- 此调用中关键字参数的字典
- 返回:
获取的参数值。
- 返回类型:
任何
注意
本 API 保证向后兼容。
- run(*args, initial_env=None, enable_io_processing=True)[source][source]¶
通过解释执行模块并返回结果。
- 参数:
*args – 要运行的模块的参数,按位置顺序排列。
initial_env(可选[Dict[Node, Any]])– 执行的可选起始环境。这是一个将 Node 映射到任何值的字典。这可以用来预先填充某些节点的结果,以便在解释器中进行部分评估。
enable_io_processing(bool)– 如果为 true,我们在使用之前首先使用图的 process_inputs 和 process_outputs 函数处理输入和输出。
- 返回:
执行模块返回的值
- 返回类型:
任何
注意
此 API 保证向下兼容。
- 运行节点(n)[source][source] ¶
在
n
运行特定节点并返回结果。根据node.op
调用占位符、get_attr、call_function、call_method、call_module 或 output- 参数:
n (节点) – 要执行的节点
- 返回:
执行
n
的结果- 返回类型:
任何
注意
本 API 保证向后兼容。
- class torch.fx.Transformer(module)[source][source]¶
Transformer
是一种特殊的解释器,它产生一个新的Module
。它公开了一个transform()
方法,该方法返回转换后的Module
。Transformer
不需要任何参数即可运行,而Interpreter
需要。Transformer
完全以符号方式工作。示例
假设我们想要交换所有
torch.neg
的实例与torch.sigmoid
以及它们的Tensor
方法等效物(包括它们的Tensor
方法等效物)。我们可以像这样子类化Transformer
:class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数:
模块(GraphModule)- 要转换的
Module
注意
本 API 向后兼容性得到保证。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr
节点。在Transformer
中,这被覆盖以在输出图中插入一个新的get_attr
节点。- 参数:
目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点
args(元组)- 此调用位置参数的元组
kwargs(字典)- 此调用关键字参数的字典
- 返回类型:
注意
本 API 向后兼容性得到保证。
- placeholder(target, args, kwargs)[源码][源码] ¶
执行一个
placeholder
节点。在Transformer
中,这被覆盖以向输出图插入一个新的placeholder
。- 参数:
目标(Target)- 此节点的调用目标。有关语义的详细信息,请参阅节点。
args(元组)- 本调用中位置参数的元组
kwargs(字典)- 本调用中关键字参数的字典
- 返回类型:
注意
本 API 向后兼容性得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[source][source]¶
匹配所有可能的非重叠操作符及其数据依赖集(
pattern
)在 GraphModule(gm
)的图中,然后将这些匹配的子图替换为另一个子图(replacement
)。- 参数:
gm (GraphModule) – 包装图的 GraphModule 以进行操作
pattern (Union[Callable, GraphModule]) – 要在
gm
中匹配以进行替换的子图替换(Union[Callable, GraphModule])- 要替换
pattern
的子图
- 返回:
代表原始图中与
pattern
匹配的地点的Match
对象列表。如果没有匹配项,则列表为空。Match
定义如下:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型:
Match 列表
示例:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]) def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码将首先在
traced_module
的forward
方法中匹配pattern
。模式匹配基于 use-def 关系,而不是节点名称。例如,如果您有pattern
在p = torch.cat([a, b])
中,则可以在原始forward
函数中匹配m = torch.cat([a, b])
,尽管变量名称不同(p
与m
)。return
语句在pattern
中仅根据其值进行匹配;它可能或可能不匹配到更大图中的return
语句。换句话说,模式不必扩展到更大图的末尾。当模式匹配时,它将从更大函数中移除,并由
replacement
替换。如果更大函数中有多个pattern
匹配,则每个非重叠匹配都将被替换。在匹配重叠的情况下,重叠匹配集中找到的第一个匹配将被替换。(这里的“第一个”是指节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self
之后的参数,而最后一个节点是函数返回的内容。)有一个重要的事情需要注意,那就是
pattern
可调用函数的参数必须在可调用函数本身中使用,而replacement
可调用函数的参数必须匹配该模式。第一条规则是为什么在上面的代码块中,forward
函数有参数x, w1, w2
,但pattern
函数只有参数w1, w2
。pattern
没有使用x
,因此不应将其指定为参数。作为第二条规则的例子,考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
替换为
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement
需要和pattern
相同数量的参数(x
和y
都需要),即使参数y
在replacement
中没有被使用。调用
subgraph_rewriter.replace_pattern
之后,生成的 Python 代码看起来像这样:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
该 API 的向后兼容性得到保证。