• 文档 >
  • torch >
  • torch.cond
快捷键

torch.cond

torch.cond(pred, true_fn, false_fn, operands=())[source]

条件性地应用 true_fn 或 false_fn。

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

cond 是一个结构化控制流操作符。也就是说,它类似于 Python 的 if 语句,但对 true_fn、false_fn 和操作数有约束,使其能够被 torch.compile 和 torch.export 捕获。

假设满足 cond 的参数约束,cond 等价于以下内容:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数:
  • pred (Union[bool, torch.Tensor]) – 表示要应用哪个分支函数的布尔表达式或只有一个元素的张量。

  • true_fn (Callable) – 被追踪作用域内的可调用函数(a -> b)。

  • false_fn (Callable) – 被追踪作用域内的可调用函数(a -> b)。真分支和假分支必须具有一致的输入和输出,即输入必须相同,输出必须是相同类型和形状。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 真假函数的输入元组。如果 true_fn/false_fn 不需要输入,则可以为空。默认为()。

返回类型:

任何

示例:

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制条件:
  • 条件语句(又称 pred)必须满足以下约束之一:

    • 它是一个只有一个元素的 torch.Tensor,且数据类型为 torch.bool

    • 它是一个布尔表达式,例如 x.shape[0] > 10 或 x.dim() > 1 且 x.shape[1] > 10

  • 分支函数(又称 true_fn/false_fn)必须满足以下所有约束条件:

    • 函数签名必须与操作数匹配。

    • 函数必须返回具有相同元数据(例如形状、数据类型等)的张量。

    • 函数不能在输入或全局变量上执行就地修改。(注意:在分支中允许就地张量操作,如 add_用于中间结果。)


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源