从 functorch 迁移到 torch.func ¶
torch.func,之前被称为“functorch”,是 PyTorch 的类似 JAX 的可组合函数转换。
functorch 最初作为 PyTorch/functorch 仓库中的一个树外库启动。我们的目标始终是将 functorch 直接集成到 PyTorch 中,并提供作为一个核心 PyTorch 库。
作为集成的最后一步,我们决定从顶级包( functorch
)迁移到 PyTorch 的一部分,以反映函数转换直接集成到 PyTorch 核心的方式。从 PyTorch 2.0 开始,我们将弃用 import functorch
,并要求用户迁移到最新的 API,我们将继续维护。 import functorch
将保留以维护几个版本的向后兼容性。
函数转换
以下 API 是以下 functorch API 的直接替换。它们完全向后兼容。
functorch API |
PyTorch API(自 PyTorch 2.0 版本起) |
---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您正在使用 torch.autograd.functional API,请尝试使用 torch.func
等价函数。 torch.func
函数转换在许多情况下更易于组合且性能更优。
torch.autograd.functional API |
torch.func API(自 PyTorch 2.0 起) |
---|---|
|
|
|
|
NN 模块工具
我们已将 API 更改为在 NN 模块上应用函数转换,以更好地适应 PyTorch 设计理念。新的 API 有所不同,请仔细阅读本节。
functorch.make_functional
torch.func.functional_call()
是 functorch.make_functional 和 functorch.make_functional_with_buffers 的替代品。但是,它并不是直接替代。
如果您急于求成,可以使用此 gist 中的辅助函数来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call()
,因为它是一个更明确和灵活的 API。
具体来说,functorch.make_functional 返回一个功能模块和参数。该功能模块接受参数和模型输入作为参数。 torch.func.functional_call()
允许使用新的参数、缓冲区和输入调用现有模块的前向传递。
下面是一个使用 functorch 与 torch.func
计算模型参数梯度的示例:
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
下面是一个计算模型参数雅可比的示例:
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
注意,为了内存消耗,你应该只携带参数的单个副本。 model.named_parameters()
不会复制参数。如果在你的模型训练中你原地更新模型的参数,那么你的 nn.Module
模型就拥有参数的单个副本,一切正常。
然而,如果您想在字典中携带参数并在原地更新它们,那么参数有两个副本:字典中的一个和 model
中的一个。在这种情况下,您应该将 model
转换为通过 model.to('meta')
的元设备来不占用内存。
functorch.combine_state_for_ensemble
请使用 torch.func.stack_module_state()
代替 functorch.combine_state_for_ensemble torch.func.stack_module_state()
返回两个字典,一个是堆叠的参数,另一个是堆叠的缓冲区,然后可以使用 torch.vmap()
和 torch.func.functional_call()
进行集成。
例如,以下是一个如何在非常简单的模型上集成的示例:
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile
我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 编译的前端;我们已经将 AOTAutograd 集成到 PyTorch 的编译故事中。如果您是用户,请使用 torch.compile()
代替。