torch.func.linearize¶
- torch.func.linearize(func, *primals)[source]¶
返回
func
在primals
处的值和线性近似值primals
。- 参数:
func (Callable) – 一个接受一个或多个参数的 Python 函数。
primals (Tensors) – 位置参数
func
必须都是张量。这些是在函数进行线性近似时的值。
- 返回值:
返回一个包含
func
应用于primals
的结果以及一个计算func
在primals
处 jvp 的函数的(output, jvp_fn)
元组。- 返回类型:
Any 类型的元组
当需要在
primals
处多次计算 jvp 时,linearize 非常有用。然而,为了实现这一点,linearize 会保存中间计算,并且比直接应用 jvp 有更高的内存需求。因此,如果所有tangents
都已知,计算 vmap(jvp)可能更有效率。注意
linearize 会评估
func
两次。请提交一个实现单次评估的实现的 issue。- 示例::
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>