• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • torch.func.vjp
快捷键

torch.func.vjp

torch.func.vjp(func, *primals, has_aux=False)[source]

代表向量-雅可比乘积,返回一个包含 func 应用到 primals 的结果以及一个函数的元组,该函数接受 cotangents 作为输入,并计算 func 关于 primals 的反向模式雅可比乘以 cotangents 的结果。

参数:
  • func (Callable) – 一个 Python 函数,可以接受一个或多个参数。必须返回一个或多个张量。

  • primals(张量)- func 的定位参数,必须全部是张量。返回的函数也将计算对这些参数的导数

  • has_aux(布尔值)- 标志表示 func 返回一个 (output, aux) 元组,其中第一个元素是要进行微分的函数的输出,第二个元素是其他不会进行微分的辅助对象。默认:False。

返回值:

返回一个包含 func 应用于 primals 的结果以及一个计算 func 相对于所有 primals 的 vjp 的函数的 (output, vjp_fn) 元组。如果 has_aux is True ,则返回一个 (output, vjp_fn, aux) 元组。返回的 vjp_fn 函数将返回每个 VJP 的元组。

当在简单情况下使用时, vjp() 的行为与 grad() 相同。

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

然而, vjp() 可以通过为每个输出传递余切来支持具有多个输出的函数。

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 甚至可以支持输出为 Python 结构体。

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

该函数返回的 vjp() 将计算相对于每个 primals 的偏导数

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primalsf 的位置参数。所有 kwargs 都使用它们的默认值

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

注意

使用 PyTorch 的 torch.no_gradvjp 结合。情况 1:在函数中使用 torch.no_grad

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在这种情况下, vjp(f)(x) 将尊重内部 torch.no_grad

情况 2:在 torch.no_grad 上下文管理器中使用 vjp

>>> with torch.no_grad():
>>>     vjp(f)(x)

在这种情况下, vjp 将尊重内部 torch.no_grad ,但不尊重外部的一个。这是因为 vjp 是一个“函数转换”:其结果不应依赖于 f 外部的上下文管理器的结果。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源