快捷键

torch.func API 参考 ¶

函数转换 ¶

vmap

vmap 是向量化映射; vmap(func) 返回一个对输入的某个维度应用 func 的新函数。

grad

grad 运算符有助于计算 func 关于由 argnums 指定的输入的梯度。

grad_and_value

返回一个函数,用于计算梯度与原函数(或正向函数)的计算元组。

vjp

代表向量-雅可比乘积,返回一个包含 func 应用到 primals 的结果以及一个函数的元组,该函数接受 cotangents 作为输入,并计算 func 关于 primals 的反向模式雅可比乘以 cotangents 的结果。

jvp

代表雅可比向量积,返回一个元组,包含 func(*primals) 的输出以及 "在 primals 处评估的 func 的雅可比矩阵" 乘以 tangents

linearize

返回 funcprimals 处的值和线性近似值 primals

jacrev

使用反向模式自动微分计算 func 关于索引 argnum 的 arg(s) 的雅可比矩阵。

jacfwd

使用前向模式自动微分计算 func 关于索引 argnum 的 arg(s) 的雅可比矩阵。

hessian

通过正向-反向策略计算 func 关于索引 argnum 的 arg(s) 的 Hessian。

functionalize

functionalize 是一个转换,可以用来从函数中移除(中间)突变和别名,同时保留函数的语义。

torch.nn.Modules 的实用工具。

通常,你可以转换调用 torch.nn.Module 的函数。例如,以下是一个计算函数的雅可比矩阵的例子,该函数接受三个值并返回三个值:

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

然而,如果您想对模型的参数进行计算雅可比矩阵,那么需要有一种方法来构造一个函数,其中参数是该函数的输入。这就是 functional_call() 的作用:它接受一个 nn.Module,变换后的 parameters ,以及模块前向传播的输入。它返回使用替换参数运行模块前向传播的值。

下面是如何计算参数的雅可比矩阵

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

通过替换模块参数和缓冲区,在模块上执行功能调用。

stack_module_state

使用 vmap() 准备一个 torch.nn.Modules 列表,用于集成。

replace_all_batch_norm_modules_

通过将 running_meanrunning_var 设置为 None,并将任何 nn.BatchNorm 模块的 track_running_stats 设置为 False 来实现就地更新 root

如果您正在寻找修复批归一化模块的信息,请遵循此处指南

调试工具 ¶

debug_unwrap

解包 functorch 张量(例如


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源