备注
点击此处下载完整示例代码
(beta) 运行编译后的优化器与学习率调度器 ¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:迈克尔·拉佐斯
优化器是训练任何深度学习模型的关键算法。在本例中,我们将展示如何将使用 torch.compile
编译的优化器与 LR 调度器配对,以加速训练收敛。
备注
本教程需要 PyTorch 2.3.0 或更高版本。
模型设置 ¶
在本例中,我们将使用一个简单的线性层序列。
import torch
# Create simple model
model = torch.nn.Sequential(
*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
# run forward pass
output = model(input)
# run backward to populate the grads for our optimizer below
output.sum().backward()
设置并运行编译后的优化器与学习率调度器 ¶
在本节中,我们将使用 Adam 优化器与线性学习率调度器,并创建一个辅助函数来包装对每个的 step()
调用和 torch.compile()
。
备注
仅支持具有 7.0 或更高计算能力的 CUDA 设备。
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# the optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
opt.step()
sched.step()
# Warmup runs to compile the function
for _ in range(5):
fn()
print(opt.param_groups[0]["lr"])
扩展:非张量 LR 会发生什么?
对于好奇者,我们将展示在不将 LR 包装在张量中时,如何窥视 torch.compile
会发生什么。
# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
opt.step()
sched.step()
# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
fn()
通过这个例子,我们可以看到由于 param_groups[0]
中的 lr
守卫失败,我们多次重新编译优化器。
结论 ¶
在本教程中,我们展示了如何将编译的优化器与 LR 调度器配对以加速训练收敛。我们使用了一个由简单线性层组成的模型,并配以 Adam 优化器和 LinearLR 调度器,以展示迭代过程中学习率的改变。
参见:
编译优化器教程 - 编译优化器的简介。
使用 PT2 编译优化器 - 关于编译优化器的更深入技术细节。
脚本总运行时间:(0 分钟 0.000 秒)