现代神经网络的训练通常依赖于使用低精度数据类型。在 A100 GPU 上,峰值 float16 矩阵乘法和卷积性能比峰值 float32 性能快 16 倍。由于 float16 和 bfloat16 数据类型仅为 float32 的一半大小,它们可以将带宽受限内核的性能翻倍,并减少训练网络所需的内存,从而允许使用更大的模型、更大的批次或更大的输入。使用 torch.amp(代表“自动混合精度”)模块可以轻松获得低精度数据类型的速度和内存使用优势,同时保持收敛行为。
加快速度和减少内存使用总是有利的——深度学习从业者可以测试更多的模型架构和超参数,并且可以训练更大、更强大的模型。没有使用混合精度,像 Narayanan 等人以及 Brown 等人描述的非常大的模型(即使使用专家手工优化的情况下,也需要成千上万的 GPU 数月时间来训练)是不可行的。
我们之前已经讨论过混合精度技术(在这里、这里和这里),这篇博客文章是对这些技术的总结,也是对混合精度的新手的介绍。
混合精度训练实践
混合精度训练技术——在 float32 数据类型的同时使用低精度 float16 或 bfloat16 数据类型——具有广泛的应用性和有效性。请参见图 1,了解使用混合精度成功训练的模型样本,以及图 2 和图 3 中 torch.amp 的使用示例速度提升。
图 1:使用 float16 成功训练的深度学习工作负载示例(来源)。
图 2:在 NVIDIA 8xV100 上使用 torch.amp 进行混合精度训练的性能与 8xV100 GPU 上使用 float32 训练的性能对比。柱状图表示 torch.amp 相对于 float32 的速度提升因子(数值越高越好)。(来源)。
图 3:在 NVIDIA 8xA100 上使用 torch.amp 进行混合精度训练的性能与 8xV100 GPU 的性能对比。柱状图表示 A100 相对于 V100 的速度提升因子(数值越高越好)。(来源)。
查看 NVIDIA 深度学习示例仓库以获取更多混合精度工作负载示例。
类似性能图表可以在 3D 医学图像分析、注视估计、视频合成、条件 GAN 和卷积 LSTM 中看到。黄等人表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 倍到 5.5 倍,在 A100 GPU 上快 1.3 倍到 2.5 倍,在各种网络上表现尤为明显。在非常大的网络上,混合精度需求更为明显。Narayanan 等人报告称,在 1024 个 A100 GPU 上(批大小为 1536)训练 GPT-3 175B 需要 34 天,但估计使用 float32 需要超过一年!
使用 torch.amp 入门混合精度
torch.amp 是 PyTorch 1.6 中引入的,它使得使用 float16 或 bfloat16 数据类型进行混合精度训练变得简单。请参阅这篇博客文章、教程和文档以获取更多详细信息。图 4 展示了将 AMP 与梯度缩放应用于网络的示例。
import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()
for data, label in data_iter:
optimizer.zero_grad()
# Casts operations to mixed precision
with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
loss = model(data)
# Scales the loss, and calls backward()
# to create scaled gradients
scaler.scale(loss).backward()
# Unscales gradients and calls
# or skips optimizer.step()
scaler.step(optimizer)
# Updates the scale for next iteration
scaler.update()
图 4:AMP 配方
选择正确的方法
使用 float16 或 bfloat16 进行混合精度训练,可以有效地加速许多深度学习模型的收敛,但某些模型可能需要更仔细的数值精度管理。以下是一些选项:
- 使用全 float32 精度。在 PyTorch 中,默认情况下,浮点张量和模块以 float32 精度创建,但这是一种历史遗留问题,并不能代表大多数现代深度学习网络的训练。网络通常不需要这么高的数值精度。
- 启用 TensorFloat32(TF32)模式。在 Ampere 及以后的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32(TF32)模式进行更快但略微不精确的计算。有关详细信息,请参阅 NVIDIA TF32 Tensor Cores 博客文章《通过 NVIDIA TF32 Tensor Cores 加速 AI 训练》。默认情况下,PyTorch 为卷积启用 TF32 模式,但不是矩阵乘法。除非网络需要全 float32 精度,否则我们建议也为此设置启用此设置(有关如何操作的文档请在此处查看)。这可以显著加快计算速度,通常不会导致数值精度的显著损失。
- 使用 torch.amp 与 bfloat16 或 float16。这两种低精度浮点数据类型通常速度相当快,但某些网络可能只对其中一种收敛。如果一个网络需要更高的精度,它可能需要使用 float16,而如果一个网络需要更大的动态范围,它可能需要使用 bfloat16,其动态范围与 float32 相同。如果观察到溢出,例如,我们建议尝试使用 bfloat16。
甚至还有比这里介绍的高级选项,比如仅对模型的部分使用 torch.amp 的自动转换,或直接管理混合精度。这些主题大多超出了本博客文章的范围,但请参阅下面的“最佳实践”部分。
最佳实践
我们强烈建议在训练网络时尽可能使用 torch.amp 的混合精度或 TF32 模式(在 Ampere 及以后的 CUDA 设备上)。然而,如果上述方法中的一种不起作用,我们建议以下方法:
- 高性能计算(HPC)应用、回归任务和生成网络可能只需要完整的 float32 IEEE 精度才能达到预期的收敛。
- 尝试选择性应用 torch.amp。特别是我们建议首先在执行 torch.linalg 模块操作或进行预处理和后处理的部分禁用它。这些操作通常特别敏感。请注意,TF32 模式是全局开关,不能在网络的部分区域选择性使用。首先启用 TF32 以检查网络操作是否对模式敏感,否则禁用它。
- 如果在使用 torch.amp 时遇到类型不匹配,我们不建议一开始就插入手动类型转换。这个错误表明网络可能存在问题,通常值得首先进行调查。
- 通过实验找出您的网络是否对格式范围和/或精度敏感。例如,在 float16 中对 bfloat16 预训练模型进行微调可能会因为从 bfloat16 训练中可能的大范围而导致 float16 中的范围问题,因此如果模型是在 bfloat16 中训练的,用户应坚持使用 bfloat16 微调。
- 混合精度训练的性能提升可能取决于多个因素(例如计算密集型与内存密集型问题),用户应使用调优指南来消除训练脚本中的其他瓶颈。尽管在理论上具有相似的性能优势,但 BF16 和 FP16 在实际应用中可能会有不同的速度。建议尝试所提到的格式,并使用速度最快且保持所需数值行为的一种。
更多详细信息,请参阅 AMP 教程、使用 Tensor Cores 训练神经网络,以及 PyTorch Dev 讨论区上的“浮点精度更深入细节”帖子。
结论
混合精度训练是训练现代硬件上的深度学习模型的重要工具,随着新硬件上低精度操作与 float32 之间的性能差距继续扩大,它将变得更加重要,如图 5 所示。
图 5:Volta 和 Ampere GPU 上 float16(FP16)与 float32 矩阵乘法的相对峰值吞吐量。在 Ampere 上还显示了 TensorFloat32(TF32)模式和 bfloat16 矩阵乘法的相对峰值吞吐量。随着新硬件的发布,预计低精度数据类型如 float16 和 bfloat16 与 float32 矩阵乘法的相对峰值吞吐量将增长。
PyTorch 的 torch.amp 模块使得开始使用混合精度变得简单,我们强烈建议使用它来加速训练并减少内存使用。torch.amp 支持 float16 和 bfloat16 混合精度。
仍然有一些网络在混合精度下训练比较困难,对于这些网络,我们建议尝试在 Ampere 和后续 CUDA 硬件上使用 TF32 加速矩阵乘法。网络很少对精度如此敏感,以至于每个操作都需要完整的 float32 精度。
如果您对 torch.amp 或 PyTorch 中的混合精度支持有任何疑问或建议,请通过在 PyTorch 论坛的混合精度分类下发帖或在 PyTorch GitHub 页面上提交问题来告诉我们。