• 文档 >
  • torch.func
快捷键

torch.func

torch.func,之前被称为“functorch”,是 PyTorch 的类似 JAX 的可组合函数转换。

注意

该库目前处于测试版。这意味着(除非另有说明)功能通常可以正常工作,并且(我们 PyTorch 团队)致力于推进这个库。然而,API 可能会根据用户反馈而更改,并且我们没有对 PyTorch 操作的全覆盖。

如果您对 API 或希望涵盖的使用案例有建议,请打开 GitHub 问题或联系。我们很乐意了解您如何使用这个库。

什么是可组合函数转换?

  • “函数转换”是一种高阶函数,它接受一个数值函数并返回一个新的函数,该函数计算不同的量。

  • torch.func 具有自动微分转换( grad(f) 返回一个计算 f 梯度的函数),向量化/批处理转换( vmap(f) 返回一个计算输入批次 f 的函数),以及其他转换。

  • 这些函数转换可以任意组合。例如,组合 vmap(grad(f)) 计算一个名为 per-sample-gradients 的量,而 PyTorch 目前无法高效地计算这个量。

为什么需要可组合的函数转换?

目前在 PyTorch 中存在一些难以实现的用例:

  • 计算每个样本的梯度(或其他每个样本的量)

  • 在单台机器上运行模型集合

  • 在 MAML 的内循环中高效地批处理任务

  • 高效地计算雅可比矩阵和海森矩阵

  • 高效地计算批处理的雅可比矩阵和海森矩阵

组合 vmap()grad()vjp() 变换可以让我们在不为每个变换设计单独子系统的情况下表达上述内容。这种可组合函数变换的想法来源于 JAX 框架。

阅读更多 ¶


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源