• 文档 >
  • 自动微分包 - torch.autograd >
  • torch.autograd.Function.vmap
快捷键

torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)[source][source]

定义在 torch.vmap() 下的 autograd.Function 的行为。

为了支持 torch.autograd.Function() ,您必须重写此静态方法,或将 generate_vmap_rule 设置为 True (您不能同时进行这两项操作)。

如果您选择重写此静态方法:它必须接受

  • 使用 info 对象作为第一个参数。 info.batch_size 指定了被 vmapped 的维度大小,而 info.randomness 是传递给 torch.vmap() 的随机选项。

  • 使用 in_dims 元组作为第二个参数。对于 args 中的每个参数, in_dims 都有一个对应的 Optional[int] 。如果参数不是 Tensor 或者参数没有被 vmapped,则是 None ,否则,它是一个整数,指定了 Tensor 被 vmapped 的维度。

  • *args ,与 forward() 的参数相同。

vmap 静态方法的返回值是一个包含 (output, out_dims) 的元组。类似于 in_dimsout_dims 应该与 output 具有相同的结构,并且每个输出包含一个 out_dim ,指定输出是否有 vmapped 维度以及它在其中的索引。

请参阅《使用 autograd.Function 扩展 torch.func 的更多详细信息。》


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源