• 文档 >
  • torch.fx
快捷键

torch.fx

概述 ¶

FX 是开发者用来转换 nn.Module 实例的工具包。FX 包含三个主要组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件在行动中的演示:

import torch


# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪器执行 Python 代码的“符号执行”。它通过代码传递称为代理的假值。记录这些代理的操作。有关符号追踪的更多信息,请参阅 symbolic_trace()Tracer 文档。

中间表示是记录在符号追踪期间的操作的容器。它由表示函数输入、调用点(到函数、方法或 torch.nn.Module 实例)和返回值的节点列表组成。有关 IR 的更多信息,请参阅 Graph 的文档。IR 是应用转换的格式。

Python 代码生成是 FX 成为 Python 到 Python(或模块到模块)转换工具包的原因。对于每个图 IR,我们可以创建与图语义匹配的有效 Python 代码。此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,包含一个 Graph 以及从图生成的 forward 方法。

综合来看,这个组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件也可以单独使用。例如,符号跟踪可以单独使用来捕获代码的一种形式以供分析(而非转换)目的。代码生成可以用于程序化生成模型,例如从配置文件中生成。FX 有很多用途!

几个示例转换可以在示例仓库中找到。

编写转换

什么是 FX 转换?本质上,它是一个看起来像这样的函数。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

您的转换将接受一个 torch.nn.Module ,从中获取一个 Graph ,进行一些修改,然后返回一个新的 torch.nn.Module 。您应该将您的 FX 转换返回的 torch.nn.Module 视为与常规的 torch.nn.Module 相同 – 您可以将其传递给另一个 FX 转换,可以传递给 TorchScript,或者可以运行它。确保您的 FX 转换的输入和输出是 torch.nn.Module 将允许进行组合。

注意

也可以修改现有的 GraphModule 而不是创建一个新的,如下所示:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

注意,您必须调用 GraphModule.recompile() 来使生成的 forward() 方法与修改后的 GraphModule 同步。

既然您已经传递了一个已经被追踪到 Graphtorch.nn.Module ,现在有两种主要方法可以构建一个新的 Graph

图论快速入门指南

图的语义的全面介绍可以在 Graph 文档中找到,但在这里我们将介绍基础知识。 Graph 是一种数据结构,它表示在 GraphModule 上的方法。这需要的信息包括:

  • 该方法有哪些输入?

  • 方法内部运行的操作有哪些?

  • 该方法输出的(即返回的)值是什么?

这三个概念都用 Node 实例表示。让我们用一个简短的例子来看看我们是什么意思:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

在这里,我们定义一个用于演示的模块 MyModule ,实例化它,进行符号跟踪,然后调用 Graph.print_tabular() 方法打印出显示此 Graph 节点的表格:

指令码

名称

目标

参数

关键字参数

占位符

x

x

()

{}

获取属性

线性权重

线性.weight

()

{}

调用函数

加 1

<内置函数加>

(x, 线性权重)

{}

调用模块

线性_1

线性

(add_1,)

{}

调用方法

relu_1

relu

(线性_1,)

{}

调用函数

求和 1

<内置方法 sum …>

(relu_1,)

{‘维度’:-1}

调用函数

topk_1

<内置方法 topk …>

(和_1, 3)

{}

输出

输出

输出

(topk_1,)

{}

我们可以使用这些信息来回答我们上面提出的问题。

  • 该方法有哪些输入?在 FX 中,方法输入通过特殊的 placeholder 节点指定。在这种情况下,我们有一个单独的 placeholder 节点,其 targetx ,这意味着我们有一个名为 x 的单个(非自身)参数。

  • 方法中包含哪些操作? get_attrcall_functioncall_modulecall_method 节点代表方法中的操作。所有这些操作的语义的全面解释可以在 Node 文档中找到。

  • 该方法返回值是什么?在 Graph 中,返回值由一个特殊的 output 节点指定。

既然我们已经了解了 FX 中代码表示的基础知识,现在我们可以探索如何编辑 Graph

图形操作 ¶

直接图形操作 ¶

建立这种新 Graph 的一种方法就是直接操作你的旧版本。为了帮助实现这一点,我们可以简单地从符号跟踪中获取的 Graph 进行修改。例如,假设我们希望用 torch.mul() 调用替换 torch.add() 调用。

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以进行更复杂的 Graph 重写,例如删除或添加节点。为了帮助这些转换,FX 提供了用于转换图的实用函数,这些函数可以在 Graph 文档中找到。以下是一个使用这些 API 添加 torch.relu() 调用的示例。

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于只包含替换的简单转换,你还可以使用子图重写器。

使用 replace_pattern() 进行子图重写

FX 还在直接图操作之上提供另一层自动化。 replace_pattern() API 实质上是一个用于编辑 Graph 的“查找/替换”工具。它允许您指定一个 patternreplacement 函数,然后它会追踪这些函数,在 pattern 图中找到操作组的实例,并将这些实例替换为 replacement 图的副本。这可以帮助极大地自动化繁琐的图操作代码,随着变换变得更加复杂,这些代码可能会变得难以控制。

图操作示例 ¶

代理/重绘

另一种操作 Graph 的方法是重用符号跟踪中使用的 Proxy 机制。例如,让我们想象我们想要编写一个将 PyTorch 函数分解成更小操作的转换。它将把每个 F.relu(x) 调用转换成 (x > 0) * x 。一种可能性是在 F.relu 之后插入比较和乘法所需的图重写,然后清理原始的 F.relu 。然而,我们可以通过使用 Proxy 对象来自动记录操作到 Graph 来自动化此过程。

使用此方法,我们将要插入的操作以常规 PyTorch 代码的形式编写,并使用 Proxy 对象作为参数调用该代码。这些 Proxy 对象将捕获对它们的操作并将它们追加到 Graph

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用 Proxy 还可以让您将重写规则指定为原生 Python 代码。对于需要大量重写规则(如 vmap 或 grad)的转换,这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy 时,我们还传递了一个指向底层变量图的跟踪器。这样做是为了如果图中的操作是 n 元(例如,add 是一个二元运算符)时,调用 Proxy 不会创建多个图跟踪器实例,这可能导致意外的运行时错误。我们建议在底层运算符不能安全假设为一元时,特别使用此方法使用 Proxy

使用 Proxy 进行 Graph 操作的示例可以在此处找到。

解释器模式 ¶

在 FX 中,一种有用的代码组织模式是遍历一个模块中的所有 Node ,并执行它们。这可以用于多种用途,包括对通过图流动的值的运行时分析或通过 Proxy 进行回溯来转换代码。例如,假设我们想要运行一个 GraphModule 并记录节点在运行时看到的 torch.Tensor 形状和 dtype 属性。这可能看起来像:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如您所见,FX 的完整解释器并不复杂,但它非常有用。为了简化使用这种模式,我们提供了一个 Interpreter 类,它以这种方式封装了上述逻辑,使得解释器执行的一些方面可以通过方法重写来覆盖。

除了执行操作外,我们还可以通过将 Proxy 值通过解释器传递来生成一个新的图。同样,我们提供了一个 Transformer 类来封装这种模式。 Transformer 的行为与 Interpreter 类似,但您不是调用 run 方法从 Module 获取具体的输出值,而是调用 Transformer.transform() 方法返回一个新的 GraphModule ,该新图已应用您安装的任何转换规则作为重写方法。

解释器模式的示例

调试

简介

在编写变换的过程中,我们的代码可能并不完全正确。在这种情况下,我们可能需要进行一些调试。关键是要逆向工作:首先,检查调用生成的模块的结果以证明或反驳正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的变换过程。

如果您不熟悉调试器,请参阅辅助部分“可用的调试器”。

常见在转换创作中的陷阱 ¶

  • 非确定性的 set 迭代顺序。在 Python 中, set 数据类型是无序的。使用 set 来包含对象集合,例如,可以导致意外的非确定性。一个例子是遍历一组 Node s 以将其插入到 Node 。因为 Graph 数据类型是无序的,输出程序中操作的顺序将是非确定性的,并且可以在程序调用之间发生变化。建议的替代方案是使用 set 数据类型,该数据类型自 Python 3.7(以及 cPython 3.6)起为插入顺序。可以使用 dict 相当于一个集合,通过将需要去重的值存储在 dict 的键中。

检查模块的正确性 ¶

由于大多数深度学习模块的输出由浮点 torch.Tensor 实例组成,检查两个 torch.nn.Module 的结果之间的等价性并不像进行简单的相等性检查那样简单。为了说明这一点,让我们用一个例子来说明:

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用 == 等价运算符检查两个深度学习模型的值是否相等。然而,这并不明确,一方面是因为该运算符返回一个张量而不是布尔值,另一方面是因为浮点数的比较应该使用误差范围(或 epsilon)来考虑浮点运算的非交换性(更多详情请见此处)。我们可以使用 torch.allclose() ,它将给出一个近似比较,考虑到相对和绝对容差阈值:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们的工具箱中第一个检查转换后的模块是否按预期与参考实现相比表现良好的工具。

生成代码的调试

因为 FX 在 GraphModule 上生成 forward() 函数,所以使用传统的调试技术,如 print 语句或 pdb ,并不那么直接。幸运的是,我们有几种可以用来调试生成代码的技术。

使用 pdb

使用 pdb 进入正在运行的程序。虽然代表 Graph 的代码不在任何源文件中,但在调用前向传递时,我们仍然可以使用 pdb 手动进入。

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用 to_folder 函数从 GraphModule

GraphModule.to_folder()GraphModule 中的一个方法,允许您将生成的 FX 代码导出到文件夹。尽管将前向传递复制到代码中通常就足够了,如在“打印生成的代码”中,但使用 to_folder 可能更容易检查模块和参数。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

在运行上述示例之后,我们就可以查看 foo/module.py 中的代码并根据需要对其进行修改(例如添加 print 语句或使用 pdb )以调试生成的代码。

调试转换 ¶

现在我们已经确定一个转换正在生成错误的代码,现在是时候调试这个转换本身了。首先,我们将检查文档中的“符号跟踪的限制”部分。一旦我们验证跟踪是否按预期工作,目标就变成了找出我们的 GraphModule 转换中出了什么问题。在《编写转换》中可能有一个快速的答案,但如果没有,我们有几种方法可以检查我们的跟踪模块:

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上面的实用函数,我们可以比较我们在应用转换前后跟踪的模块。有时,简单的视觉比较就足以追踪到错误。如果仍然不清楚出了什么问题,使用 pdb 这样的调试器可以是一个好的下一步。

根据上面的例子,考虑以下代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的例子,假设 print(traced) 的调用显示我们的转换中存在错误。我们想通过调试器找出错误所在。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 处中断,然后按 s “进入” transform_graph(traced) 的调用来查看转换过程中的情况。

我们也可以通过编辑 print_tabular 方法来打印图中节点不同的属性。(例如,我们可能想看到节点的 input_nodesusers 。)

可用调试器 §

最常用的 Python 调试器是 pdb。您可以通过在命令行中输入 python -m pdb FILENAME.py 以“调试模式”启动您的程序,其中 FILENAME 是要调试的文件名。之后,您可以使用 pdb 调试器命令逐步执行您的程序。通常,在开始 pdb 时设置一个断点( b LINE-NUMBER ),然后调用 c 运行程序直到该点。这可以防止您必须逐行执行(使用 sn )才能到达想要检查的代码部分。或者,您可以在想要中断的行之前写入 import pdb; pdb.set_trace() 。如果添加 pdb.set_trace() ,则您的程序将在运行时自动进入调试模式。(换句话说,您可以直接在命令行中输入 python FILENAME.py 而不是 python -m pdb FILENAME.py 。)一旦您以调试模式运行文件,您就可以使用某些命令逐步执行代码并检查程序的内部状态。网上有许多关于 pdb 的优秀教程,包括 RealPython 的“使用 pdb 进行 Python 调试”。

PyCharm 或 VSCode 等集成开发环境通常内置了调试器。在您的 IDE 中,您可以选择以下方式之一:a) 通过在 IDE 中打开终端窗口(例如 VSCode 中的“视图→终端”)来使用 pdb ;b) 使用内置的调试器(通常是一个围绕 pdb 的图形包装器)。

符号跟踪的限制

FX 使用符号跟踪(也称为符号执行)系统来以可转换/可分析的形式捕获程序的语义。该系统是跟踪的,因为它执行程序(实际上是 torch.nn.Module 或函数)以记录操作。它是符号的,因为在执行过程中通过程序的数据不是真实数据,而是符号(在 FX 术语中为 Proxy )。

虽然符号跟踪对于大多数神经网络代码都有效,但它有一些局限性。

动态控制流

符号跟踪的主要局限性是它目前不支持动态控制流。也就是说,循环或 if 语句的条件可能依赖于程序输入的值。

例如,让我们分析以下程序:

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if 语句的条件依赖于 x.sum() 的值,而 x.sum() 的值又依赖于 x 的值,这是一个函数输入。由于 x 可以改变(即如果你向跟踪函数传递新的输入张量),这就是动态控制流。回溯会沿着你的代码向上走,以显示这种情况发生的位置。

静态控制流 ¶

另一方面,所谓的静态控制流得到了支持。静态控制流是指值在调用之间不能改变的循环或 if 语句。通常,在 PyTorch 程序中,这种控制流出现在根据超参数做出决策的代码中。作为一个具体的例子:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if-语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。 do_activation 可以被视为一个超参数,而 MyModule 的不同实例的痕迹,该参数有不同的值,有不同的代码。这是一个有效的模式,也是符号跟踪所支持的。

许多动态控制流的实例在语义上是静态控制流。通过移除对输入值的依赖,例如通过将值移动到 Module 属性或绑定具体值到符号跟踪期间的参数,可以将这些实例转换为支持符号跟踪:

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

对于真正的动态控制流,包含此代码的程序部分可以被视为对方法(请参阅使用 Tracer 类自定义跟踪)或函数(请参阅 wrap() )的调用,而不是通过它们进行跟踪。

非函数 torch

FX 使用 __torch_function__ 作为拦截调用的机制(有关此内容的更多信息,请参阅技术概述)。一些函数,例如内置的 Python 函数或 math 模块中的函数,不受 __torch_function__ 的保护,但我们仍然希望将它们捕获在符号跟踪中。例如:

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误提示我们内置函数 len 不受支持。我们可以使此类函数在跟踪中以直接调用方式记录,使用 wrap() API:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用 Tracer 类定制跟踪 ¶

Tracer 类是实现 symbolic_trace 的底层类。可以通过子类化 Tracer 来自定义跟踪的行为,如下所示:

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶模块

叶模块是出现在符号跟踪中的调用模块,而不是通过跟踪实现的模块。默认的叶模块集是标准 torch.nn 模块实例的集合。例如:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过覆盖 Tracer.is_leaf_module() 来自定义叶模块集。

杂项 ¶

  • 索引构造函数(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor )目前无法追踪。

    • 确定性构造函数( zerosones )可以使用,它们产生的值将被嵌入到追踪中作为常量。这只有在这些构造函数的参数引用动态输入大小时才会成为问题。在这种情况下, ones_likezeros_like 可能是可行的替代品。

    • 非确定性构造函数( randrandn )将在追踪中嵌入单个随机值。这很可能不是预期的行为。一种解决方案是将 torch.randn 包裹在 torch.fx.wrap 函数中,并调用该函数。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 该行为可能在未来的版本中得到修复。

  • 类型注解

    • 支持 Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor ),并且符号跟踪将保留这些注解。

    • 目前不支持 Python 2 风格的注释类型注解 # type: (torch.Tensor, int) -> torch.Tensor

    • 函数内部对局部名称的注释目前不支持。

  • 关于 training 标志和子模块的注意事项。

    • 当使用 torch.nn.functional.dropout 等函数时,训练参数通常会被传递为 self.training 。在 FX 追踪期间,这很可能会被固化为一个常量值。

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 然而,当使用标准 nn.Dropout() 子模块时,训练标志被封装,并且由于保留了 nn.Module 对象模型,可以更改。

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,请考虑将动态交互 training 标志的模块标记为叶模块。

API 参考指南

torch.fx.symbolic_trace(root, concrete_args=None)[source][source] 参考指南

符号追踪 API

给定一个 nn.Module 或函数实例 root ,此函数将返回一个通过记录在 root 中看到的操作所构建的 GraphModule

concrete_args 允许您部分专门化您的函数,无论是要移除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a * 2

由于存在控制流,FX 通常无法追踪此操作。但是,我们可以使用 concrete_args 来专门化 b 的值以追踪此操作:

f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

注意,尽管您仍然可以传入不同的 b 值,但它们将被忽略。

我们还可以使用 concrete_args 来消除函数中的数据结构处理。这将使用 pytrees 来简化您的输入。为了避免过度专业化,对于不应专业化的值,请传入 fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
参数:
  • root (Union[torch.nn.Module, Callable]) – 要追踪和转换为图表示的模块或函数。

  • concrete_args (Optional[Dict[str, any]]) – 部分专业化的输入

返回:

root 记录的操作创建的模块

返回类型:

图模块

注意

本 API 向后兼容性得到保证。

torch.fx.wrap(fn_or_name)[source][source]

此函数可以在模块级作用域中调用,将 fn_or_name 注册为“叶子函数”。一个“叶子函数”将被保留为 FX 跟踪中的 CallFunction 节点,而不是被跟踪:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y


torch.fx.wrap("my_custom_function")


def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以等效地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

被包装的函数可以被视为“叶子函数”,类似于“叶子模块”的概念,即它们是保留在 FX 跟踪中的调用,而不是被跟踪的函数。

参数:

fn_or_name (Union[str, Callable]) – 当调用时,将函数或全局函数的名称插入到图中的函数

注意

本 API 保证向后兼容性。

class torch.fx.GraphModule(*args, **kwargs)[source][source]

GraphModule 是由 fx.Graph 生成的 nn.Module。Graphmodule 具有从 graph 属性生成的 code 以及 forwardgraph 属性。

警告

graph 被重新赋值时, codeforward 将会自动重新生成。然而,如果您编辑了 graph 的内容而没有重新赋值 graph 属性本身,您必须调用 recompile() 来更新生成的代码。

注意

此 API 保证向后兼容。

__init__(root, graph, class_name='GraphModule')[source][source]

构建一个 GraphModule。

参数:
  • root (Union[torch.nn.Module, Dict[str, Any]) – root 可以是 nn.Module 实例或映射字符串到任何属性类型的 Dict。如果 root 是 Module,则 Graph 的 Nodes 的 target 字段中通过限定名称引用的基于 Module 的对象将被从 root 的 Module 层级复制到 GraphModule 的模块层级。如果 root 是 dict,则 Node 的 target 中找到的限定名称将直接在 dict 的键中查找。映射到 Dict 的对象将被复制到 GraphModule 的模块层级中相应的位置。

  • 图(图)- graph 包含此 GraphModule 应使用的节点以进行代码生成

  • class_name(字符串)- name 表示此 GraphModule 的调试名称。如果未设置,所有错误消息都将报告为来自 GraphModule 。将其设置为 root 的原始名称或适合您转换上下文的名称可能很有帮助。

注意

此 API 向后兼容性得到保证。

add_submodule(target, m)[源代码][源代码] ¶

将给定的子模块添加到 self

如果不存在,则安装空的模块,这些模块是 target 的子路径。

参数:
  • target (str) – 新子模块的完全限定字符串名称(参见 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

  • m (Module) – 子模块本身;我们想要安装到当前模块的实际对象

返回:

子模块是否可以插入。

此方法返回 True,链中的每个对象(由 target 表示)必须满足以下条件之一:a)尚不存在,或 b)引用 nn.Module (不是参数或其他属性)

返回类型:

布尔型

注意

此 API 的后向兼容性得到保证。

属性 codestr ¶

返回由 Graph 生成的 Python 代码。

delete_all_unused_submodules()[source][source]

self 中删除所有未使用的子模块。

模块被认为“被使用”,如果以下任何一个条件成立:1. 它有被使用的子模块 2. 它的前向操作通过一个 call_module 节点直接调用 3. 它有一个非模块属性,该属性从一个 get_attr 节点被使用

可以调用此方法来清理 nn.Module ,而无需手动对每个未使用的子模块调用 delete_submodule

注意

此 API 向后兼容性得到保证。

delete_submodule(target)[source][source]

self 中删除指定的子模块。

模块在 target 不是一个有效目标时不会被删除。

参数:

目标(字符串)- 新子模块的完全限定字符串名称(参见 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

返回:

无论目标字符串是否引用了

子模块我们要删除。返回值 False 表示 target 不是一个有效的子模块引用。

返回类型:

布尔型

注意

此 API 保证向后兼容。

属性图 Graph ¶

返回此 Graph 的底层 GraphModule

打印可读输出(print_output=True,include_stride=False,include_device=False,colored=False)[source][source] ¶

返回当前 GraphModule 及其子 GraphModule 生成的 Python 代码

警告

此 API 为实验性,且不向后兼容。

recompile()[source][source]

从其 graph 属性重新编译此 GraphModule。编辑包含的 graph 之后,应该调用此操作,否则此 GraphModule 生成的代码将过时。

注意

本 API 向后兼容性有保证。

返回类型:

Python 代码

to_folder(folder, module_name='FxModule')[source][source]
输出模块到 folder ,使用 module_name 以便可以

使用 from <folder> import <module_name> 导入

Args:

文件夹(Union[str, os.PathLike]):输出代码的文件夹

module_name (str):用于 Module while 的顶级名称

输出代码

警告

此 API 为实验性,且不向后兼容。

class torch.fx.Graph(拥有模块=None, 跟踪器类=None, 跟踪器额外参数=None)[source][source] ¶

Graph 是 FX 中间表示法中使用的最主要的数据结构。它由一系列 Node 组成,每个 Node 代表一个调用点(或其他语法结构)。 Node 的列表,合在一起,构成一个有效的 Python 函数。

例如,以下代码

import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(
            torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
        )


m = MyModule()
gm = torch.fx.symbolic_trace(m)

将产生以下 Graph:

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

关于在 Graph 中表示的操作的语义,请参阅 Node

注意

本 API 保证向下兼容。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]

构建一个空图。

注意

此 API 保证向后兼容。

call_function(the_function, args=None, kwargs=None, type_expr=None)[source][source]

Graph 中插入 call_function Node 。一个 call_function 节点表示对 Python 可调用对象的调用,该对象由 the_function 指定。

参数:
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 操作符、Python 函数或 builtinsoperator 命名空间中的成员。

  • args (可选[Tuple[Argument, ...]]) – 要传递给被调用函数的位置参数。

  • kwargs (可选[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数

  • type_expr (可选[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。

返回:

新创建并插入的 call_function 节点。

返回类型:

节点

注意

此方法的插入点类型表达式规则与 Graph.create_node() 相同。

注意

此 API 向后兼容性得到保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[source][source]

Graph 中插入 call_method Node 。一个 call_method 节点表示对 args 的第 0 个元素调用给定方法。

参数:
  • method_name (str) – 要应用到 self 参数上的方法名称。例如,如果 args[0]是一个表示 TensorNode ,那么要调用 Tensor 上的 relu() ,需要将 relu 传递给 method_name

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的定位参数。注意,这应该包括一个 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的键值参数

  • type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。

返回:

新创建并插入的 call_method 节点。

返回类型:

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 保证向后兼容。

call_module(module_name, args=None, kwargs=None, type_expr=None)[source][source]

Graph 中插入 call_module Node 。一个 call_module 节点代表对 Module 层次结构中 Module 的 forward()函数的调用。

参数:
  • module_name (str) – 要调用的 ModuleModule 层次结构中的限定名称。例如,如果被跟踪的 Module 有一个名为 foo 的子模块,该子模块有一个名为 bar 的子模块,则应将限定名称 foo.bar 作为 module_name 传递以调用该模块。

  • args(可选[Tuple[Argument, ...]])- 要传递给调用方法的定位参数。请注意,此参数不应包括 self 参数。

  • kwargs(可选[Dict[str, Argument]])- 要传递给调用方法的键值参数

  • type_expr(可选[Any])- 表示此节点输出将具有的 Python 类型的可选类型注解。

返回:

新创建并插入的 call_module 节点。

返回类型:

节点

注意

此方法的插入点类型表达式规则与 Graph.create_node() 相同。

注意

此 API 向后兼容性得到保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source][source]

创建一个 Node 并将其添加到当前插入点处的 Graph 。请注意,当前插入点可以通过 Graph.inserting_before()Graph.inserting_after() 设置。

参数:
  • op (str) – 此节点的操作码。可以是 'call_function'、'call_method'、'get_attr'、'call_module'、'placeholder' 或 'output' 之一。这些操作码的语义在 Graph 文档字符串中描述。

  • args (Optional[Tuple[Argument, ...]]) – 是一个包含此节点参数的元组。

  • kwargs (Optional[Dict[str, Argument]]) – 此节点的 kwargs

  • 可选的字符串名称(Optional[str]) - 为 Node . 分配的值提供可选的字符串名称。这将影响 Python 生成的代码中分配的值名称。

  • 类型表达式(Optional[Any]) - 表示此节点输出将具有的 Python 类型的可选类型注解。

返回:

新创建并插入的节点。

返回类型:

节点

注意

此 API 保证向后兼容。

eliminate_dead_code(is_impure_node=None)[source][source]

根据每个节点的用户数量以及节点是否有副作用,从图中删除所有死代码。调用之前必须对图进行拓扑排序。

参数:
  • is_impure_node (Optional[Callable[[Node], bool]]) – 一个返回

  • (节点是否不纯。如果是) –

  • (那么默认行为是) –

  • (使用)Node.is_impure. –

返回:

(该图是否因为该遍历而改变。)

返回类型:

布尔型

示例:

在消除死代码之前,一个如下的 a = x + 1 没有用户,因此可以从图中删除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

在消除死代码之后,a = x + 1 已被删除,其余的前向操作仍然保留。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法来避免删除有副作用的节点(见 Node.is_impure),但总体覆盖率非常差,因此你应该假设除非你知道你的 FX 图完全由函数操作组成,或者你提供了自己的自定义函数来检测有副作用的节点,否则这种方法是不可靠的。

注意

本 API 的向后兼容性得到保证。

erase_node(to_erase)[源代码][源代码]

Graph 中删除 Node 。如果该节点在 Graph 中仍有用户,则抛出异常。

参数:

to_erase (节点) – 从 Graph 中删除的 Node

注意

本 API 向后兼容性有保证。

find_nodes(*, op, target=None, sort=True)[source][source]

允许快速查询节点

参数:
  • op (str) – 操作名称

  • target (Optional[Target]) – 节点的目标。对于 call_function,目标为必需。对于其他操作,目标为可选。

  • 是否按在图上出现的顺序返回节点。

返回:

具有请求操作和目标的节点可迭代。

警告

此 API 为实验性,且不向后兼容。

get_attr(qualified_name, type_expr=None)[source][source]

在图中插入一个 get_attr 节点。A get_attr Node 表示从 Module 层次中获取属性。

参数:
  • 完整名称(str)- 要检索的属性的完整名称。例如,如果跟踪的模块有一个名为 foo 的子模块,该子模块有一个名为 bar 的子模块,该子模块有一个名为 baz 的属性,则应将合格名称 foo.bar.baz 传递为 qualified_name

  • type_expr(可选[Any])- 表示此节点输出将具有的 Python 类型的可选类型注解。

返回:

新创建并插入的 get_attr 节点。

返回类型:

节点

注意

此方法与 Graph.create_node 的插入点及类型表达式规则相同。

注意

此 API 向后兼容性得到保证。

graph_copy(g, val_map, return_output_node=False)[source][source]

将给定图中的所有节点复制到 self

参数:
  • g(图)- 从中复制节点的源图。

  • val_map(Dict[Node, Node])- 一个字典,将填充从 gself 的节点映射。注意,如果 val_map 已经包含值,则可以传入以覆盖某些值的复制。

返回:

如果 g 有一个 output 节点,则 self 中的值现在等同于 g 的输出值。否则为 None

返回类型:

Optional[Union[ tuple[Union[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[ str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], …], Sequence[Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[ str, Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 向后兼容性得到保证。

inserting_after(n=None)[source][source]
设置 create_node 和伴随方法将插入到图中的点。

当在“with”语句中使用时,这将临时设置插入点,并在退出“with”语句时恢复:

with g.inserting_after(n):
    ...  # inserting after node n
...  # insert point restored to what it was previously
g.inserting_after(n)  #  set the insert point permanently

参数:

n (可选[节点]): 要插入其前的节点。如果为 None,则将在整个图的开始之后插入。

的位置。

返回:

一个将恢复插入点在 __exit__ 的资源管理器。

注意

本 API 向后兼容性得到保证。

inserting_before(n=None)[source][source]
设置 create_node 和伴随方法将插入到图中的点。

当在“with”语句中使用时,这将临时设置插入点,并在退出“with”语句时恢复:

with g.inserting_before(n):
    ...  # inserting before node n
...  # insert point restored to what it was previously
g.inserting_before(n)  #  set the insert point permanently

参数:

n (可选[节点]): 要插入其前的节点。如果为 None,则将在整个图的开始处插入。

的位置。

返回:

一个将恢复 __exit__ .插入点的资源管理器。

注意

本 API 向后兼容性得到保证。

lint()[源码][源码] ¶

对此图进行各种检查以确保其格式正确。特别是:- 检查节点拥有正确的所有权(由此图拥有)- 检查节点按拓扑顺序出现- 如果此图有一个拥有 GraphModule,检查该 GraphModule 中存在目标

注意

此 API 的向后兼容性得到保证。

node_copy(node, arg_transform=>)[源码][源码] ¶

将一个节点从一张图复制到另一张图。 arg_transform 需要将节点所在图的参数转换为自身图的参数。示例:

# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
参数:
  • 节点(Node)- 要复制的节点 self

  • arg_transform (Callable[[Node], Argument]) - 一个函数,用于将节点 argskwargs 中的参数转换为 self 中的等效参数。在简单情况下,这应该从将原始图中的节点映射到 self 的表中检索值。

返回类型:

节点

注意

本 API 向后兼容性得到保证。

属性节点_node_list _

获取构成此图的节点列表。

注意,此 Node 列表表示形式是双向链表。迭代期间的变异(例如删除节点、添加节点)是安全的。

返回:

双向链表的节点。注意,可以通过 reversed 来切换此列表的迭代顺序。

on_generate_code(make_transformer)[source][source]

在生成 Python 代码时注册转换器函数

参数:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):

一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。

此函数还接收当前已注册的代码转换器(如果没有注册则为 None),以防不希望覆盖它。这对于将代码转换器链接在一起很有用。

返回:

一个上下文管理器,当在 with 语句中使用时,会自动恢复之前注册的代码转换器。

示例:

gm: fx.GraphModule = ...


# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]


# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(current_trans(body) if current_trans else body)
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此功能也可以用作上下文管理器,具有自动恢复先前注册的代码转换器的优势:

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 为实验性,且不向后兼容。

output(result, type_expr=None)[source][source]

Graph 中插入 output Node 。一个 output 节点代表 Python 代码中的一个 return 语句。 result 是应该返回的值。

参数:
  • result (参数) – 要返回的值。

  • type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。

注意

此方法与 Graph.create_node 适用于相同的插入点和类型表达式规则。

注意

本 API 保证向后兼容性。

output_node()[source][source]

警告

此 API 为实验性,且不向后兼容。

返回类型:

Node

placeholder(名称, 类型表达式=None, 默认值)[源代码][源代码] ¶

在图中插入一个 placeholder 节点。 placeholder 代表函数输入。

参数:
  • 名称 (str) – 输入值的名称。这对应于该 Graph 代表的函数的位置参数名称。

  • type_expr (Optional[Any]) – 表示此节点输出将具有的 Python 类型的可选类型注解。在某些情况下,这对于正确的代码生成是必需的(例如,当函数在 TorchScript 编译中随后使用时)。

  • default_value (Any) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 传递给此参数,以指定该参数没有默认值。

返回类型:

节点

注意

此方法与 Graph.create_node 一样适用相同的插入点和类型表达式规则。

注意

本 API 保证向下兼容。

print_tabular()[source][source]

以表格形式打印图的中间表示。请注意,此 API 需要安装 tabulate 模块。

注意

本 API 保证向下兼容。

process_inputs(*args)[source][source]

处理参数,以便将它们传递给 FX 图。

警告

此 API 为实验性,且不向后兼容。

process_outputs(out)[source][source]

警告

此 API 为实验性,且不向后兼容。

python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source][source]

将这个 Graph 转换为有效的 Python 代码。

参数:

root_module(字符串)- 要查找合格名称目标的根模块名称。这通常是‘self’。

返回:

src:表示对象的 Python 源代码 globals:src 中的全局名称的字典 -> 它们引用的对象。

返回类型:

一个 PythonCode 对象,包含两个字段

注意

此 API 保证向后兼容。

set_codegen(codegen)[source][source]

警告

此 API 为实验性,且不向后兼容。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source][source]

Node 是表示 Graph 中单个操作的抽象数据类型。大部分情况下,节点代表对各种实体(如算子、方法、模块等)的调用点(一些例外包括指定函数输入和输出的节点)。每个节点都有一个由其 op 属性指定的函数。 Node 的语义如下:

  • placeholder 代表函数输入。 name 属性指定该值将采用的名称。 target 类似地是参数的名称。 args 包含以下两种情况之一:1)无内容,或 2)一个表示函数输入默认参数的单个参数。 kwargs 是无关紧要的。占位符对应于图打印中的函数参数(例如 x )。

  • get_attr 从模块层次结构中检索参数。 name 类似地是分配给检索结果的名称。 target 是参数在模块层次结构中的完全限定名称。 argskwargs 是无关紧要的。

  • call_function 将一个自由函数应用于一些值。 name 类似地是分配给值的名称。 target 是要应用的功能。 argskwargs 代表函数的参数,遵循 Python 调用约定。

  • call_module 将模块层次结构中的 forward() 方法应用于给定的参数。 name 如前所述。 target 是要调用的模块在模块层次结构中的完全限定名称。 argskwargs 代表调用模块时要传递的参数,不包括 self 参数。

  • call_method 调用值上的方法。 name 与此类似。 target 是要应用到 self 参数上的方法的字符串名称。 argskwargs 代表调用模块时的参数,包括 self 参数

  • output 包含被跟踪函数的输出,存储在其 args[0] 属性中。这对应于图打印输出中的“return”语句。

注意

本 API 向后兼容性得到保证。

属性 all_input_nodeslist[torch.fx.node.Node] ¶

返回所有是该节点输入的节点。这相当于遍历 argskwargs ,并仅收集值为节点的值。

返回:

列出在 argskwargs 中出现的 Nodes ,顺序如下。

append(x)[source][source]

在图中节点列表中在此节点之后插入 x 。相当于 self.next.prepend(x)

参数:

x(节点)- 要放在此节点之后的节点。必须是同一图中的成员。

注意

本 API 向后兼容性得到保证。

属性 argstuple[typing.Union[tuple[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...], collections.abc.Sequence[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]], collections.abc.Mapping[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]] torch.SymFloat, NoneType]], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...] ¶

此函数的参数元组。参数的解释取决于节点的操作码。有关更多信息,请参阅 Node 的文档字符串。

允许对此属性进行赋值。在赋值时,所有使用情况和用户的会计信息将自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)[source][source]

返回对 self 的描述性字符串表示。

此方法可以无参数使用,作为调试工具。

此函数还用于 Graph__str__ 方法内部。 placeholder_namesmaybe_return_typename 中的字符串共同构成了此 Graph 周围 GraphModule 中自动生成的 forward 函数的签名。 placeholder_namesmaybe_return_typename 不应在其他情况下使用。

参数:
  • placeholder_names(可选[字符串列表])- 一个列表,将存储表示生成的 forward 函数中占位符的格式化字符串。仅限内部使用。

  • maybe_return_typename (Optional[list[str]]) – 存储生成 forward 函数输出的格式化字符串的单元素列表。仅内部使用。

返回:

如果 1) 我们使用 format_node 作为内部辅助函数

Graph__str__ 方法中,并且 2) self 是一个占位符节点,则返回 None 。否则,返回当前节点的描述性字符串表示。

返回类型:

str

注意

本 API 的向后兼容性得到保证。

insert_arg(idx, arg)[source][source]

在指定索引处向参数列表插入一个位置参数。

参数:
  • idx (int) – 要插入到 self.args 之前的位置元素的索引。

  • arg (Argument) – 要插入到 args 中的新参数值。

注意

此 API 保证向后兼容。

is_impure()[source][source]

返回此操作是否为不纯,即其操作是否为占位符或输出,或者是否为不纯的 call_function 或 call_module。

返回:

判断操作是否为不纯。

返回类型:

布尔型

警告

此 API 为实验性,且不向后兼容。

属性 kwargsdict[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...], collections.abc.Sequence[typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]], collections.abc.Mapping[str, typing.Union[tuple[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]] torch.SymFloat, NoneType]], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]] ¶

该函数的关键字参数字典。参数的解释取决于节点的操作码。请参阅 Node 的文档字符串以获取更多信息。

允许对此属性进行赋值。在赋值时,所有使用情况和用户的会计信息将自动更新。

属性 nextNode

返回节点链表中的下一个 Node

返回:

节点链表中的下一个 Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source][source]

返回用于 Python 目标的标准化参数。这意味着 args/kwargs 将与模块/函数的签名匹配,如果 normalize_to_only_use_kwargs 为 true,则仅按位置顺序返回 kwargs。同时填充默认值。不支持位置只写参数或可变参数。

支持模块调用。

可能需要 arg_types 和 kwarg_types 来区分重载。

参数:
  • root (torch.nn.Module) – 要解析模块目标的模块。

  • arg_types (可选[Tuple[Any]]) – 参数类型的元组。

  • kwarg_types (Optional[Dict[str, Any]]) – 参数类型字典

  • normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用 kwargs。

返回:

返回 NamedTuple ArgsKwargsPair,或未成功时返回 None。

返回类型:

Optional[ArgsKwargsPair]

警告

此 API 为实验性,且不向后兼容。

prepend(x)[来源][来源] ¶

在图中的节点列表中在此节点之前插入 x。例如:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数:

x(节点)- 要插入此节点之前的节点。必须是同一图中的成员。

注意

本 API 兼容旧版本有保证。

属性 prevNode ¶

返回节点链表中的前一个 Node

返回:

节点链表中的前一个 Node

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source][source]

在 Graph 中将所有使用 self 的地方替换为 Node replace_with

参数:
  • 替换为(节点)- 要替换所有使用 self 的节点。

  • delete_user_cb(可调用)- 当确定是否应该从 self 节点移除给定用户时被调用的回调。

  • propagate_meta(布尔值)- 是否将原始节点的.meta 字段上的所有属性复制到替换节点上。出于安全考虑,仅在替换节点尚未具有现有的.meta 字段时才有效。

返回:

已更改此更改的节点列表。

返回类型:

list[torch.fx.node.Node]

注意

此 API 向后兼容性得到保证。

replace_input_with(old_input, new_input)[source][source]

遍历输入节点 self ,并将所有 old_input 替换为 new_input

参数:
  • 旧输入(节点)- 要替换的旧输入节点。

  • 新输入(节点)- 替换 old_input 的新输入节点。

注意

本 API 向后兼容性得到保证。

属性 stack_traceOptional[str] ¶

返回在跟踪期间记录的 Python 调用栈,如果有的话。当使用 fx.Tracer 跟踪时,此属性通常由 Tracer.create_proxy 填充。为了在跟踪期间记录调用栈以进行调试,请在 Tracer 实例上设置 record_stack_traces = True。当使用 dynamo 跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。

stack_trace 将在字符串末尾具有最内层的帧。

update_arg(idx, arg)[source][source]

更新现有位置参数以包含新值 arg 。调用后, self.args[idx] == arg

参数:
  • idx (int) – 更新元素的索引 self.args

  • arg (Argument) – 要写入 args 的新参数值

注意

本 API 向后兼容性得到保证。

update_kwarg(key, arg)[source][source]

更新现有的关键字参数以包含新值 arg 。调用后, self.kwargs[key] == arg

参数:
  • key (str) – 要更新的元素的 self.kwargs 中的键

  • arg (Argument) – 要写入 kwargs 的新参数值

注意

本 API 保证向后兼容。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source][source]

Tracer 是实现 torch.fx.symbolic_trace 符号跟踪功能的类。对 symbolic_trace(m) 的调用等同于 Tracer().trace(m)

Tracer 可以被继承以覆盖跟踪过程的多种行为。可以覆盖的不同行为在类上方法的文档字符串中描述。

注意

此 API 保证向后兼容。

call_module(m, forward, args, kwargs)[source][source]

指定此 Tracer 在遇到对 nn.Module 实例的调用时的行为的方法。

默认情况下,行为是检查被调用的模块是否是叶子模块通过 is_leaf_module 。如果是,则在 Graph 中引用 m 发射一个 call_module 节点。否则,正常调用 Module ,通过其 forward 函数中的操作进行跟踪。

此方法可以被重写,例如创建嵌套的已跟踪 GraphModules,或者跨 Module 边界时你想要的任何其他行为。

参数:
  • m(模块)- 正在发出调用的模块

  • forward(可调用)- 要调用的 Module 的 forward()方法

  • args(元组)- 模块调用位置的参数

  • kwargs (字典) – 模块调用位置的 kwargs

返回:

模块的返回值。如果生成了 call_module 节点,则此为 Proxy 值。否则,就是 Module 调用返回的任何值。

返回类型:

任何

注意

本 API 向后兼容性得到保证。

create_arg(a)[source][source]

一种指定在准备用作节点参数的值时的跟踪行为的方法。

默认情况下,行为包括:

  1. 遍历集合类型(例如元组、列表、字典)并对元素递归调用 create_args

  2. 给定一个代理对象,返回底层 IR 的引用 Node

  3. 对于非代理张量对象,针对各种情况发出 IR

    • 对于参数,发出指向该参数的 get_attr 节点

    • 对于非参数张量,将该张量存储在特殊属性中,以引用该属性

此方法可以被重写以支持更多类型

参数:

一个(任何)- 要作为 Graph 发出的值 Argument

返回:

a 转换为适当的 Argument 的值。

返回类型:

参数

注意

本 API 向后兼容性得到保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[source][source]

创建对应于 root 模块签名的 placeholder 节点。此方法会检查 root 的签名并相应地发出这些节点,同时支持 *args**kwargs

警告

此 API 为实验性,且不向后兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]

插入一个图节点,给定目标、参数、关键字参数和名称。

此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录原地操作。

注意

此 API 向后兼容性得到保证。

返回类型:

Node

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]

从给定的参数创建一个节点,然后返回一个包装在代理对象中的节点。

如果 kind = ‘placeholder’,则我们正在创建一个表示函数参数的节点。如果需要编码默认参数,我们使用 args 元组。对于 placeholder 节点, args 为空。

注意

本 API 向后兼容性得到保证。

get_fresh_qualname(prefix)[source][source]

获取一个前缀的新名称并返回它。此函数确保它不会与图上的现有属性冲突。

注意

本 API 向后兼容性得到保证。

返回类型:

str

getattr(attr, attr_val, parameter_proxy_cache)[source][source]

指定在调用 nn.Module 实例的 getattr 时此 Tracer 的行为的方法。

默认情况下,行为是返回属性的代理值。它还会将代理值存储在 parameter_proxy_cache 中,以便未来的调用将重用代理而不是创建一个新的。

此方法可以被覆盖,例如,在查询参数时不返回代理。

参数:
  • 查询的属性名称(字符串)

  • 属性值(任意类型)

  • parameter_proxy_cache(字典[str, 任意类型])- 属性名称到代理的缓存

返回:

getattr 调用返回值

警告

此 API 为实验性,且不向后兼容。

is_leaf_module(m, module_qualified_name)[source][source]

一个用于指定给定的 nn.Module 是否为“叶子”模块的方法。

叶子模块是出现在 IR 中的原子单元,通过 call_module 调用进行引用。默认情况下,PyTorch 标准库命名空间(torch.nn)中的模块都是叶子模块。除非通过此参数指定,否则所有其他模块都会进行跟踪,并记录其构成的操作。

参数:
  • m (模块) – 正在被查询的模块

  • 模块限定名称(字符串)- 此模块根路径。例如,如果您有一个模块层次结构,其中子模块 foo 包含子模块 bar ,该子模块又包含子模块 baz ,则该模块将在此处以限定名称 foo.bar.baz 出现。

返回类型:

布尔型

注意

此 API 向后兼容性得到保证。

iter(obj)[source]
当代理对象正在被迭代时调用,例如

当用于控制流时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。

注意

此 API 的向后兼容性得到保证。

返回类型:

迭代器

keys(obj)[source]
当代理对象调用 keys()方法时触发。

当在代理上调用**时会发生这种情况。这应该返回一个迭代器 it**,在你的自定义跟踪器中应该正常工作。

注意

本 API 向后兼容性得到保证。

返回类型:

任何

path_of_module(mod)[source][source]

在模块层次结构中查找 mod 的限定名称的辅助方法。例如,如果 root 有一个子模块名为 foo ,而 foo 又有子模块名为 bar ,将 bar 传递给此函数将返回字符串“foo.bar”。

参数:

mod (str) – 获取限定名称的 Module

返回类型:

str

注意

本 API 向后兼容性得到保证。

proxy(node)[source]

注意

此 API 保证向后兼容。

返回类型:

代理

to_bool(obj)[源码] ¶
当代理对象被转换为布尔值时调用,例如

当用于控制流时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。

注意

本 API 向后兼容性得到保证。

返回类型:

布尔型

trace(root, concrete_args=None)[source][source]

跟踪 root 并返回相应的 FX Graph 表示。 root 可以是一个 nn.Module 实例或 Python 可调用对象。

注意,在这次调用之后, self.root 可能与传入此处的 root 不同。例如,当将自由函数传递给 trace() 时,我们将创建一个 nn.Module 实例作为根实例,并添加嵌入的常量。

参数:
  • root (Union[Module, Callable]) – 可以是 Module 或要追踪的函数。此参数向后兼容性得到保证。

  • concrete_args (Optional[Dict[str, any]]) – 应视为非代理的具体参数。此参数为实验性,其向后兼容性不保证。

返回:

代表传入的 root 的语义的 Graph

返回类型:

注意

本 API 保证向后兼容。

class torch.fx.Proxy(node, tracer=None)[source][source]

Proxy 对象是 Node 包装器,在符号跟踪期间通过程序流动并记录它们接触到的所有操作( torch 函数调用、方法调用、运算符)到不断增长的 FX 图。

如果您正在进行图转换,您可以将自己的 Proxy 方法包装在原始的 Node 之上,这样您就可以使用重载的运算符向 Graph 添加额外的内容。

对象无法迭代。换句话说,如果在一个循环或作为 Proxy / *args / **kwargs 函数参数中使用 Proxy ,符号追踪器将抛出错误。

有两种主要的解决方案:1. 将不可追踪的逻辑提取到顶级函数中,并在其上使用 fx.wrap 。2. 如果控制流是静态的(即循环迭代次数基于某些超参数),则可以将代码保留在原始位置,并重构为类似以下内容:

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

想要更详细地了解代理内部机制,请查看 torch/fx/README.md 中的“Proxy”部分。

注意

本 API 向后兼容性得到保证。

class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source][source]

解释器逐节点执行 FX 图。这种模式对于许多事情都很有用,包括编写代码转换和分析过程。

可以覆盖解释器类中的方法来自定义执行行为。可覆盖方法的调用层次映射:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们想要交换所有 torch.neg 的实例与 torch.sigmoid ,反之亦然(包括它们的 Tensor 方法等效)。我们可以这样子类化解释器:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数:
  • 要执行的模块(torch.nn.Module)

  • garbage_collect_values(布尔值)- 是否在模块执行过程中删除其最后使用后的值。这确保了执行过程中的最佳内存使用。可以通过查看 Interpreter.env 属性来禁用此功能,以检查执行过程中的所有中间值。

  • graph(可选[Graph])- 如果传入,解释器将使用提供的模块参数执行此图,而不是 module.graph,以满足对状态的任何请求。

注意

本 API 向后兼容性得到保证。

boxed_run(args_list)[source][source]

通过解释执行模块并返回结果。这使用“boxed”调用约定,其中传递一个参数列表,该列表将被解释器清除。这确保了输入张量能够及时释放。

注意

本 API 向后兼容性得到保证。

call_function(target, args, kwargs)[source][source]

执行一个 call_function 节点并返回结果。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。

  • args(元组)- 此调用的位置参数元组。

  • kwargs(字典)- 此调用的关键字参数字典。

返回类型:

任何

返回

Any: 函数调用的返回值

注意

本 API 向后兼容性得到保证。

call_method(target, args, kwargs)[source][source]

执行一个 call_method 节点并返回结果。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。

  • args(元组)- 此调用的位置参数元组。

  • kwargs(字典)- 此调用的关键字参数字典。

返回类型:

任何

返回

Any: 方法调用的返回值

注意

本 API 向后兼容性得到保证。

call_module(target, args, kwargs)[source][source]

执行一个 call_module 节点并返回结果。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。

  • args(元组)- 此调用的位置参数元组。

  • kwargs(字典)- 此调用的关键字参数字典。

返回类型:

任何

返回

Any: 模块调用返回的值

注意

本 API 向后兼容性得到保证。

fetch_args_kwargs_from_env(n)[source][source]

从当前执行环境中获取节点 nargskwargs 的具体值。

参数:

n (节点) – 需要获取 argskwargs 的节点@2#。

返回:

argskwargs 的具体值为 n

返回类型:

Tuple[Tuple, Dict]

注意

此 API 保证向后兼容。

fetch_attr(target)[source][source]

self.moduleModule 层次结构中获取属性。

参数:

target (str) – 要获取的属性的完全限定名称

返回:

属性的值。

返回类型:

任何

注意

本 API 向后兼容性得到保证。

get_attr(target, args, kwargs)[source][source]

执行一个 get_attr 节点。将从 Module 层级的 self.module 中检索属性值。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。

  • args(元组)- 此调用位置参数的元组。

  • kwargs(字典)- 此调用关键字参数的字典。

返回:

所检索到的属性值

返回类型:

任何

注意

此 API 保证向下兼容。

将节点映射到值(args, n)[source][source] ¶

递归遍历 args 并在当前执行环境中查找每个 Node 的具体值。

参数:
  • args(参数)- 用于查找具体值的内部数据结构。

  • n(节点)- args 所属的节点。这仅用于错误报告。

返回类型:

Optional[Union[ tuple[Union[ForwardRef('Argument'), ...], collections.abc.Sequence[ForwardRef('Argument')], collections.abc.Mapping[ str, ForwardRef('Argument')], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], …], Sequence[Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[ str, Optional[Union[ tuple[ForwardRef('Argument'), …], Sequence[Argument], Mapping[ str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 向后兼容性得到保证。

output(target, args, kwargs)[source][source]

执行一个 output 节点。这实际上只是检索由 output 节点引用的值并返回它。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点。

  • args(元组)- 此调用位置参数的元组。

  • kwargs(字典)- 此调用关键字参数的字典。

返回:

输出节点引用的返回值

返回类型:

任何

注意

本 API 向后兼容性得到保证。

placeholder(target, args, kwargs)[源代码][源代码] ¶

执行一个 placeholder 节点。请注意,这是有状态的: Interpreter 维护一个对传递给 run 的参数的内部迭代器,此方法返回该迭代器的 next()。

参数:
  • 目标(Target)- 此节点的调用目标。有关语义的详细信息,请参阅 Node。

  • args(元组)- 此调用中位置参数的元组

  • kwargs(字典)- 此调用中关键字参数的字典

返回:

获取的参数值。

返回类型:

任何

注意

本 API 保证向后兼容。

run(*args, initial_env=None, enable_io_processing=True)[source][source]

通过解释执行模块并返回结果。

参数:
  • *args – 要运行的模块的参数,按位置顺序排列。

  • initial_env(可选[Dict[Node, Any]])– 执行的可选起始环境。这是一个将 Node 映射到任何值的字典。这可以用来预先填充某些节点的结果,以便在解释器中进行部分评估。

  • enable_io_processing(bool)– 如果为 true,我们在使用之前首先使用图的 process_inputs 和 process_outputs 函数处理输入和输出。

返回:

执行模块返回的值

返回类型:

任何

注意

此 API 保证向下兼容。

运行节点(n)[source][source] ¶

n 运行特定节点并返回结果。根据 node.op 调用占位符、get_attr、call_function、call_method、call_module 或 output

参数:

n (节点) – 要执行的节点

返回:

执行 n 的结果

返回类型:

任何

注意

本 API 保证向后兼容。

class torch.fx.Transformer(module)[source][source]

Transformer 是一种特殊的解释器,它产生一个新的 Module 。它公开了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 不需要任何参数即可运行,而 Interpreter 需要。 Transformer 完全以符号方式工作。

示例

假设我们想要交换所有 torch.neg 的实例与 torch.sigmoid 以及它们的 Tensor 方法等效物(包括它们的 Tensor 方法等效物)。我们可以像这样子类化 Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)

transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数:

模块(GraphModule)- 要转换的 Module

注意

本 API 向后兼容性得到保证。

call_function(target, args, kwargs)[source][source]

注意

本 API 向后兼容性得到保证。

返回类型:

任何

call_module(target, args, kwargs)[source][source]

注意

本 API 向后兼容性得到保证。

返回类型:

任何

get_attr(target, args, kwargs)[source][source]

执行一个 get_attr 节点。在 Transformer 中,这被覆盖以在输出图中插入一个新的 get_attr 节点。

参数:
  • 目标(目标)- 此节点的调用目标。有关语义详情,请参阅节点

  • args(元组)- 此调用位置参数的元组

  • kwargs(字典)- 此调用关键字参数的字典

返回类型:

代理

注意

本 API 向后兼容性得到保证。

placeholder(target, args, kwargs)[源码][源码] ¶

执行一个 placeholder 节点。在 Transformer 中,这被覆盖以向输出图插入一个新的 placeholder

参数:
  • 目标(Target)- 此节点的调用目标。有关语义的详细信息,请参阅节点。

  • args(元组)- 本调用中位置参数的元组

  • kwargs(字典)- 本调用中关键字参数的字典

返回类型:

代理

注意

本 API 向后兼容性得到保证。

transform()[source][source]

Transform self.module and return the transformed GraphModule.

注意

此 API 向后兼容性得到保证。

返回类型:

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[source][source]

匹配所有可能的非重叠操作符及其数据依赖集( pattern )在 GraphModule( gm )的图中,然后将这些匹配的子图替换为另一个子图( replacement )。

参数:
  • gm (GraphModule) – 包装图的 GraphModule 以进行操作

  • pattern (Union[Callable, GraphModule]) – 要在 gm 中匹配以进行替换的子图

  • 替换(Union[Callable, GraphModule])- 要替换 pattern 的子图

返回:

代表原始图中与 pattern 匹配的地点的 Match 对象列表。如果没有匹配项,则列表为空。 Match 定义如下:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回类型:

Match 列表

示例:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2])


def replacement(w1, w2):
    return torch.stack([w1, w2])


traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先在 traced_moduleforward 方法中匹配 pattern 。模式匹配基于 use-def 关系,而不是节点名称。例如,如果您有 patternp = torch.cat([a, b]) 中,则可以在原始 forward 函数中匹配 m = torch.cat([a, b]) ,尽管变量名称不同( pm )。

return 语句在 pattern 中仅根据其值进行匹配;它可能或可能不匹配到更大图中的 return 语句。换句话说,模式不必扩展到更大图的末尾。

当模式匹配时,它将从更大函数中移除,并由 replacement 替换。如果更大函数中有多个 pattern 匹配,则每个非重叠匹配都将被替换。在匹配重叠的情况下,重叠匹配集中找到的第一个匹配将被替换。(这里的“第一个”是指节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在 self 之后的参数,而最后一个节点是函数返回的内容。)

有一个重要的事情需要注意,那就是 pattern 可调用函数的参数必须在可调用函数本身中使用,而 replacement 可调用函数的参数必须匹配该模式。第一条规则是为什么在上面的代码块中, forward 函数有参数 x, w1, w2 ,但 pattern 函数只有参数 w1, w2pattern 没有使用 x ,因此不应将其指定为参数。作为第二条规则的例子,考虑替换

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

替换为

def replacement(x, y):
    return torch.relu(x)

在这种情况下, replacement 需要和 pattern 相同数量的参数( xy 都需要),即使参数 yreplacement 中没有被使用。

调用 subgraph_rewriter.replace_pattern 之后,生成的 Python 代码看起来像这样:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

该 API 的向后兼容性得到保证。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源