• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • torch.func.jvp
快捷键

torch.func.jvp

torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[source]

代表雅可比-向量积,返回一个包含 func(*primals)的输出和“ funcprimals 处的雅可比”乘以 tangents 的元组。这也就是前向自动微分模式。

参数:
  • func(函数)- 一个 Python 函数,接受一个或多个参数,其中必须有一个是 Tensor,并返回一个或多个 Tensor

  • primals(张量)- func 的定位参数,必须全部是张量。返回的函数也将计算对这些参数的导数

  • tangents(张量)- 计算雅可比-向量积的“向量”。必须与 func 的输入具有相同的结构和大小

  • has_aux(布尔值)- 标志表示 func 返回一个 (output, aux) 元组,其中第一个元素是要进行微分的函数的输出,第二个元素是其他不会进行微分的辅助对象。默认:False。

返回值:

返回一个包含 funcprimals 评估结果的输出和雅可比-向量积的 (output, jvp_out) 元组。如果 has_aux is True ,则返回一个 (output, jvp_out, aux) 元组。

注意

您可能会看到这个 API 因为“X 运算符未实现前向模式 AD”而出错。如果是这样,请提交一个错误报告,我们将优先处理。

jvp 在您希望计算函数 R^1 -> R^N 的梯度时很有用

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

jvp() 可以通过传递每个输入的切线来支持具有多个输入的函数

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源