torch.autograd.Function.vmap
- static Function.vmap(info, in_dims, *args)[source][source]
定义在
torch.vmap()
下的 autograd.Function 的行为。为了支持
torch.autograd.Function()
,您必须重写此静态方法,或将generate_vmap_rule
设置为True
(您不能同时进行这两项操作)。如果您选择重写此静态方法:它必须接受
使用
info
对象作为第一个参数。info.batch_size
指定了被 vmapped 的维度大小,而info.randomness
是传递给torch.vmap()
的随机选项。使用
in_dims
元组作为第二个参数。对于args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是 Tensor 或者参数没有被 vmapped,则是None
,否则,它是一个整数,指定了 Tensor 被 vmapped 的维度。*args
,与forward()
的参数相同。
vmap 静态方法的返回值是一个包含
(output, out_dims)
的元组。类似于in_dims
,out_dims
应该与output
具有相同的结构,并且每个输出包含一个out_dim
,指定输出是否有 vmapped 维度以及它在其中的索引。请参阅《使用 autograd.Function 扩展 torch.func 的更多详细信息。》