• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • 修复批量归一化
快捷键

修复批量归一化 ¶

发生了什么? ¶

批标准化需要将运行均值和运行方差就地更新为与输入相同的大小。Functorch 不支持对接受批处理张量的常规张量进行就地更新(即不允许使用 regular.add_(batched) )。因此,当对单个模块的输入批次进行 vmapping 时,我们会遇到这个错误

如何修复

其中一种最好的支持方式是切换为 GroupNorm。选项 1 和 2 支持这一点

所有这些选项都假设您不需要运行统计信息。如果您使用的是模块,这意味着假设您不会在评估模式下使用批标准化。如果您有涉及在评估模式下使用 vmap 运行批标准化的用例,请提交问题

选项 1:更改批归一化(BatchNorm)

如果您想更改成组归一化(GroupNorm),在您有批归一化的任何地方,将其替换为:

BatchNorm2d(C, G, track_running_stats=False)

这里 C 与原始批归一化的 C 相同。 G 是将 C 分成多少组的数量。因此, C % G == 0 ,作为后备,您可以设置 C == G ,这意味着每个通道将被单独处理。

如果您必须使用批归一化并且您自己构建了模块,您可以更改模块以不使用运行统计信息。换句话说,在任何有批归一化模块的地方,将 track_running_stats 标志设置为 False

BatchNorm2d(64, track_running_stats=False)

选项 2:torchvision 参数 ¶

一些 torchvision 模型,如 resnet 和 regnet,可以接受一个 norm_layer 参数。这些参数通常默认为 BatchNorm2d,如果它们已经被默认设置。

而你可以将其设置为 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

这里,再次, c % g == 0 因此,作为一个后备方案,设置 g = c

如果你依赖于 BatchNorm,请确保使用不使用运行统计信息的版本

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

选项 3:functorch 的补丁功能

functorch 增加了一些功能,允许快速就地修补模块以不使用运行统计信息。更改归一化层较为脆弱,所以我们没有提供该功能。如果你有一个希望 BatchNorm 不使用运行统计信息的网络,你可以运行 replace_all_batch_norm_modules_ 来就地更新模块以不使用运行统计信息

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

选项 4:评估模式

在评估模式下运行时,running_mean 和 running_var 不会被更新。因此,vmap 可以支持此模式

model.eval()
vmap(model)(x)
model.train()

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源