PyTorch 2.0 NNModule 支持 ¶
作者:Will Constable
torch.compile 对 torch.nn.Module 对象有特殊处理,与对任意 Python 类的处理方式不同,目的是通过做出关于结构的假设来生成更快的代码。
本文档描述了由于这种专业化而产生的某些权衡或边缘情况。
NN 模块钩子支持 ¶
以前,torch.compile 不支持 nn.Modules 上的钩子,如果注册了钩子,它们在编译程序中将被简单地忽略。实际上,许多用户根本不使用 nn.Module 钩子,或者只用于调试工作流程,但将 nn.Module 钩子与 torch.compile 组合是有效的用例。
通过 nn.Module.__call__实现编排的钩子包括_forward_pre_hooks、forward_hooks、_backward_pre_hooks 和_backward_hooks,以下称为“调用钩子”。这些钩子部分由 torch.compile 支持,以下为限制说明。
另一类钩子包括_state_dict_hooks 及其预和 load 变体,目前仍不支持由 torch.compile。
nn.Module.__call__ 钩子使用和限制
默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行 forward/pre-forward 钩子。如果您在调用 torch.compile 之前安装钩子,然后没有删除或修改钩子,您的用例应该默认得到支持。
Backward/Pre-backward 钩子通常也得到支持,但有类似的注意事项:目前 dynamo 在访问 backward_hooks 字典时发生 graph-breaks,这可能在一些工作后可行。graph-breaks 也会影响 backward 钩子的触发时间,因为 graph-segments 作为 autograd 函数运行,它们会在同一时间产生所有梯度。假设 dynamo 能够在存在 backward-hooks 的情况下不发生 graph-break,我们仍然预计一系列模块的 backward 钩子将在整个编译图 backward 运行完毕后一起触发。
在‘允许的模块’上使用 torch.compile 将 torch.conv 等常见模块以及难以追踪的模块视为特殊模块,允许它们在 dynamo 图中被隐式调用,而不是被 dynamo 追踪。对于此类模块,当前钩子会触发图断开,使得受影响的模块在 dynamo 之外运行。根据模型的不同,这可能会引入显著的性能下降,并且需要额外的工作来改进这一支持。
默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,意味着不会在每个 nn.Module 钩子字典上安装守卫,通过减少守卫执行时间来提高运行时性能,但代价是编译后无法察觉任何钩子字典是否被更改。
如果您希望在编译后能够删除或修改钩子,并让 torch.compile 做出适当的反应(通过重新编译),则需要将 skip_nnmodule_hook_guards 设置为 False,并预期因增加守卫而导致的运行时性能下降。
TODO:确认反向/前向反向钩子是否正常工作,并相应地记录文档。
状态字典钩子
在 torch.compile 中尚未支持状态字典钩子。
TODO:如果钩子破坏图,则警告一次。如果存在钩子,则将警告一次指向此文档。