快捷键

UX 限制 ¶

torch.func 与 JAX 类似,在可转换的内容上存在限制。一般来说,JAX 的限制是转换仅适用于纯函数:即输出完全由输入决定且不涉及副作用(如突变)的函数。

我们也有类似的保证:我们的转换与纯函数配合良好。然而,我们确实支持某些就地操作。一方面,编写与函数转换兼容的代码可能需要改变您编写 PyTorch 代码的方式,另一方面,您可能会发现我们的转换让您能够表达之前在 PyTorch 中难以表达的内容。

通用限制

所有 torch.func 转换都存在一个限制,即函数不应向全局变量赋值。相反,函数的所有输出都必须从函数中返回。这种限制源于 torch.func 的实现方式:每个转换都会将 Tensor 输入包装在特殊的 torch.func Tensor 子类中,以方便转换。

因此,不要这样做:

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

请将 f 重写为返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API

如果您尝试在由 vmap() 或 torch.func 的 AD 转换( vjp()jvp()jacrev()jacfwd() )转换的函数中使用 torch.autograd API(如 torch.autograd.gradtorch.autograd.backward ),转换可能无法覆盖它。如果无法完成转换,您将收到错误消息。

这是在 PyTorch 的 AD 支持实现中的基本设计限制,也是我们设计 torch.func 库的原因。请改用 torch.autograd API 的 torch.func 等效函数:- torch.autograd.gradTensor.backward -> torch.func.vjptorch.func.grad - torch.autograd.functional.jvp -> torch.func.jvp - torch.autograd.functional.jacobian -> torch.func.jacrevtorch.func.jacfwd - torch.autograd.functional.hessian -> torch.func.hessian

vmap 限制

注意

vmap() 是我们最严格的转换。与 grad 相关的转换( grad()vjp()jvp() )没有这些限制。 jacfwd() (以及 hessian() ,它通过 jacfwd() 实现)是 vmap()jvp() 的组合,因此也具有这些限制。

vmap(func) 是一个返回函数的转换,该函数将 func 映射到每个输入 Tensor 的一些新维度上。vmap 的心理模型就像运行一个 for-loop:对于纯函数(即在无副作用的情况下), vmap(f)(x) 等同于:

torch.stack([f(x_i) for x_i in x.unbind(0)])

变异:Python 数据结构的任意变异

在存在副作用的情况下, vmap() 就不再像是在运行一个 for-loop。例如,以下函数:

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

只会打印一次“hello!”并从 lst 中弹出只有一个元素。

vmap() 只执行 f 一次,因此所有副作用只发生一次。

这是由 vmap 的实现方式导致的。torch.func 有一个特殊的、内部的 BatchedTensor 类。 vmap(f)(*inputs) 接受所有 Tensor 输入,将它们转换为 BatchedTensors,并调用 f(*batched_tensor_inputs) 。BatchedTensor 重写了 PyTorch API,为每个 PyTorch 操作产生批处理(即向量化)行为。

变更:原地 PyTorch 操作

你可能是因为收到了关于 vmap 不兼容的 in-place 操作的错误而在这里。 vmap() 如果遇到不支持的 PyTorch in-place 操作,将引发错误,否则将成功。不支持的操作是那些会导致将具有更多元素的 Tensor 写入具有较少元素的 Tensor 的操作。以下是一个这种情况的例子:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一个只有一个元素的 Tensor, y 是一个有三个元素的 Tensor。 x + y 有三个元素(由于广播),但尝试将三个元素写回到只有一个元素的 x 中,将引发错误,因为试图将三个元素写入只有一个元素的 Tensor。

如果写入的 Tensor 在 vmap() 下批处理(即正在被 vmapped),则没有问题。

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

对于这个问题的一个常见修复是将对工厂函数的调用替换为它们的“new_*”等效函数。例如:

  • torch.zeros() 替换为 Tensor.new_zeros()

  • torch.empty() 替换为 Tensor.new_empty()

要了解这为什么有帮助,请考虑以下内容。

def diag_embed(vec):
  assert vec.dim() == 1
  result = torch.zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

vmap() 内部, result 是一个形状为[3, 3]的张量。然而,尽管 vec 看起来像它有形状[3],但 vec 实际上有[2, 3]的底层形状。由于元素过多,无法将 vec 复制到形状为[3]的 result.diagonal() 中。

def diag_embed(vec):
  assert vec.dim() == 1
  result = vec.new_zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

torch.zeros() 替换为 Tensor.new_zeros() ,使得 result 具有形状为[2, 3, 3]的底层 Tensor,因此现在可以将具有形状[2, 3]的 vec 复制到 result.diagonal() 中。

变异:out= PyTorch 操作

vmap() 不支持 PyTorch 操作中的 out= 关键字参数。如果它在您的代码中遇到该参数,将优雅地出错。

这不是根本性的限制;从理论上讲,我们本可以支持这一点,但我们现在选择不这样做。

数据依赖的 Python 控制流

目前我们还不支持对数据依赖控制流的 vmap 操作。数据依赖控制流是指 if 语句、while 循环或 for 循环的条件是一个被 vmap 操作的 Tensor。例如,以下代码将引发错误信息:

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

然而,任何不依赖于被 vmap 操作的 Tensor 值的控制流都将正常工作:

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支持使用特殊的控制流运算符(例如 jax.lax.condjax.lax.while_loop )转换数据依赖控制流。我们正在研究为 PyTorch 添加这些运算符的等效功能。

数据依赖操作 (.item())

我们不支持(并且将不会支持)在调用 .item() 的 Tensor 上的 vmap。例如,以下代码将引发错误信息:

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

请尝试重写您的代码,以避免使用 .item() 调用。

您可能还会遇到关于使用 .item() 的错误信息,但您可能并没有使用它。在这些情况下,PyTorch 内部可能调用了 .item() ,请向 GitHub 上提交问题,我们将修复 PyTorch 内部问题。

动态形状操作(非零及其相关操作)

vmap(f) 要求将 f 应用于输入中的每个“示例”时返回具有相同形状的张量。不支持 torch.nonzerotorch.is_nonzero 等操作,将导致错误。

为了理解原因,请考虑以下示例:

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 返回形状为 2 的张量;但 torch.nonzero(xs[1]) 返回形状为 1 的张量。我们无法构造一个单独的输出张量;输出需要是一个稀疏张量(但 PyTorch 尚未有稀疏张量的概念)。

随机性 ¶

用户调用随机操作时的意图可能不明确。具体来说,一些用户可能希望随机行为在批次之间保持一致,而另一些用户可能希望在不同批次之间有所不同。为了解决这个问题, vmap 接受一个随机性标志。

该标志只能传递给 vmap,可以取三个值:“错误”、“不同”或“相同”,默认为“错误”。在“错误”模式下,任何对随机函数的调用都会产生一个错误,提示用户根据他们的用例使用其他两个标志之一。

在“不同”随机性下,批次中的元素会产生不同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在“相同”的随机性下,一批中的元素会产生相同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我们的系统只能确定 PyTorch 操作符的随机性行为,无法控制其他库(如 numpy)的行为。这与 JAX 解决方案的限制类似。

注意

使用任何一种支持的随机性进行多次 vmap 调用都不会产生相同的结果。就像标准 PyTorch 一样,用户可以通过在 vmap 外部使用 torch.manual_seed() 或使用生成器来获得随机性的可重现性。

注意

最后,我们的随机性与 JAX 不同,因为我们没有使用无状态 PRNG,部分原因是因为 PyTorch 不支持无状态 PRNG。相反,我们引入了标志系统,以允许最常见的随机性形式。如果您的用例不适用于这些随机性形式,请提交问题。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源