torch.func¶
torch.func,之前被称为“functorch”,是 PyTorch 的类似 JAX 的可组合函数转换。
注意
该库目前处于测试版。这意味着(除非另有说明)功能通常可以正常工作,并且(我们 PyTorch 团队)致力于推进这个库。然而,API 可能会根据用户反馈而更改,并且我们没有对 PyTorch 操作的全覆盖。
如果您对 API 或希望涵盖的使用案例有建议,请打开 GitHub 问题或联系。我们很乐意了解您如何使用这个库。
什么是可组合函数转换?
“函数转换”是一种高阶函数,它接受一个数值函数并返回一个新的函数,该函数计算不同的量。
torch.func
具有自动微分转换(grad(f)
返回一个计算f
梯度的函数),向量化/批处理转换(vmap(f)
返回一个计算输入批次f
的函数),以及其他转换。这些函数转换可以任意组合。例如,组合
vmap(grad(f))
计算一个名为 per-sample-gradients 的量,而 PyTorch 目前无法高效地计算这个量。
为什么需要可组合的函数转换?
目前在 PyTorch 中存在一些难以实现的用例:
计算每个样本的梯度(或其他每个样本的量)
在单台机器上运行模型集合
在 MAML 的内循环中高效地批处理任务
高效地计算雅可比矩阵和海森矩阵
高效地计算批处理的雅可比矩阵和海森矩阵
组合 vmap()
、 grad()
和 vjp()
变换可以让我们在不为每个变换设计单独子系统的情况下表达上述内容。这种可组合函数变换的想法来源于 JAX 框架。