• 教程 >
  • PyTorch 菜谱 >
  • (beta) 使用 torch.compile 编译优化器
快捷键

(beta) 使用 torch.compile 编译优化器 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:迈克尔·拉佐斯

优化器是训练任何深度学习模型的关键算法。由于它负责更新每个模型参数,因此它往往成为大型模型训练性能的瓶颈。在本例中,我们将对优化器应用 torch.compile 以观察 GPU 性能提升。

备注

本教程需要 PyTorch 2.2.0 或更高版本。

模型设置 ¶

在本例中,我们将使用一个简单的线性层序列。由于我们只对优化器进行基准测试,因此模型的选择并不重要,因为优化器性能是参数数量的函数。

根据您使用的机器不同,您得到的结果可能会有所差异。

import torch

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
output.sum().backward()

设置并运行优化器基准测试 ¶

在本例中,我们将使用 Adam 优化器并创建一个辅助函数来包装 torch.compile()

备注

torch.compile 仅支持计算能力 >= 7.0 的 cuda 设备。

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)


opt = torch.optim.Adam(model.parameters(), lr=0.01)


@torch.compile(fullgraph=False)
def fn():
    opt.step()


# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


# Warmup runs to compile the function
for _ in range(5):
    fn()

eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
compiled_runtime = benchmark_torch_function_in_microseconds(fn)

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

样本结果:

  • 期望运行时间:747.2437149845064 微秒

  • 编译运行时间:392.07384741178 微秒

参见 ¶

  • 深入技术概述请参阅

使用 PT2 编译优化器


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源