快捷键

控制流 - 条件 ¶

torch.cond 是一个结构化控制流操作符。它可以用来指定类似于 if-else 的控制流,在逻辑上可以看作是以下实现。

def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor]
):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)

它的独特之处在于能够表达数据依赖的控制流:它降低为条件运算符(torch.ops.higher_order.cond),这保留了谓词、真函数和假函数。这为根据输入值或张量运算的中间输出值或形状更改模型架构的模型编写和部署提供了极大的灵活性。

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。更多关于功能分类的信息请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

示例 ¶

下面是一个使用 cond 根据输入形状进行分支的示例:

import torch

def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()

def false_fn(x: torch.Tensor):
    return x.sin()

class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

dyn_shape_mod = DynamicShapeCondPredicate()

我们可以积极运行模型,并期望结果根据输入形状而变化:

inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

我们可以将模型导出以进行进一步转换和部署:

inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)

如下所示,我们得到了一个导出的程序:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin

注意,torch.cond 已降低为 torch.ops.higher_order.cond,其谓词成为输入形状的符号表达式,分支函数成为顶层图模块的两个子图属性。

下面是一个展示如何表达数据相关控制流的另一个示例:

class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on data dependent predicate.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

导出后的程序:我们得到的导出程序:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None

        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin

torch.ops.higher_order.cond 的不变量

对于 torch.ops.higher_order.cond 有几个有用的不变量:

  • 对于谓词:
    • 谓词的动态性得到保留(例如,上述示例中的 gt)

    • 如果用户程序中的谓词是常量(例如,Python 布尔常量),则操作符的谓词将是一个常量。

  • 对于分支:
    • 输入和输出签名将是一个扁平化的元组。

    • 它们是 torch.fx.GraphModule。

    • 原函数中的闭包变为显式输入。没有闭包。

    • 不允许对输入或全局变量进行修改。

  • 对于操作数:
    • 它也将是一个扁平的元组。

  • 在用户程序中 torch.cond 的嵌套成为嵌套图模块。

API 参考¶

torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())[source]

条件性地应用 true_fn 或 false_fn。

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

cond 是一个结构化控制流操作符。也就是说,它类似于 Python 的 if 语句,但对 true_fn、false_fn 和操作数有约束,使其能够被 torch.compile 和 torch.export 捕获。

假设满足 cond 的参数约束,cond 等价于以下内容:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数:
  • pred (Union[bool, torch.Tensor]) – 表示要应用哪个分支函数的布尔表达式或只有一个元素的张量。

  • true_fn (Callable) – 被追踪作用域内的可调用函数(a -> b)。

  • false_fn (Callable) – 被追踪作用域内的可调用函数(a -> b)。真分支和假分支必须具有一致的输入和输出,即输入必须相同,输出必须是相同类型和形状。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 真假函数的输入元组。如果 true_fn/false_fn 不需要输入,则可以为空。默认为()。

返回类型:

任何

示例:

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制条件:
  • 条件语句(又称 pred)必须满足以下约束之一:

    • 它是一个只有一个元素的 torch.Tensor,且数据类型为 torch.bool

    • 它是一个布尔表达式,例如 x.shape[0] > 10 或 x.dim() > 1 且 x.shape[1] > 10

  • 分支函数(又称 true_fn/false_fn)必须满足以下所有约束条件:

    • 函数签名必须与操作数匹配。

    • 函数必须返回具有相同元数据(例如形状、数据类型等)的张量。

    • 函数不能在输入或全局变量上执行就地修改。(注意:在分支中允许就地张量操作,如 add_用于中间结果。)


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源