• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • torch.func.linearize
快捷键

torch.func.linearize

torch.func.linearize(func, *primals)[source]

返回 funcprimals 处的值和线性近似值 primals

参数:
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。

  • primals (Tensors) – 位置参数 func 必须都是张量。这些是在函数进行线性近似时的值。

返回值:

返回一个包含 func 应用于 primals 的结果以及一个计算 funcprimals 处 jvp 的函数的 (output, jvp_fn) 元组。

返回类型:

Any 类型的元组

当需要在 primals 处多次计算 jvp 时,linearize 非常有用。然而,为了实现这一点,linearize 会保存中间计算,并且比直接应用 jvp 有更高的内存需求。因此,如果所有 tangents 都已知,计算 vmap(jvp)可能更有效率。

注意

linearize 会评估 func 两次。请提交一个实现单次评估的实现的 issue。

示例::
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源