备注
点击此处下载完整示例代码
torch.export
AOTInductor 教程(Python 运行时版)(Beta)¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:Ankith Gunapal, Bin Bao, Angela Yi
警告
torch._inductor.aoti_compile_and_package
和 torch._inductor.aoti_load_package
处于 Beta 状态,可能存在向后兼容性破坏的变更。本教程提供了一个使用 Python 运行时部署模型的 API 示例。
之前已经展示了如何使用 AOTInductor 通过创建一个可以在非 Python 环境中运行的工件来对 PyTorch 导出的模型进行提前编译。在本教程中,您将学习如何使用 AOTInductor 进行 Python 运行时的端到端示例。
目录
前提条件 ¶
PyTorch 2.6 或更高版本
理解
torch.export
和 AOTInductor 基础完成 AOTInductor:为 Torch.Export 导出模型的前端编译教程
你将学习的内容 ¶
如何在 Python 运行时使用 AOTInductor
如何使用
torch._inductor.aoti_compile_and_package()
和torch.export.export()
生成编译后的工件如何使用
torch._export.aot_load()
在 Python 运行时加载和运行工件何时在 Python 运行时使用 AOTInductor
模型编译 ¶
我们将以 TorchVision 预训练的 ResNet18
模型为例。
第一步是将模型导出为图表示形式,使用 torch.export.export()
。要了解更多关于使用此功能的信息,您可以查看文档或教程。
一旦我们导出了 PyTorch 模型并获得了 ExportedProgram
,我们就可以将 AOTInductor 应用于编译程序到指定的设备,并将生成的内容保存到“pt2”文件中。
备注
此 API 支持与 torch.compile()
相同的可用选项,例如 mode
和 max_autotune
(对于想要启用 CUDA 图并利用基于 Triton 的矩阵乘法和卷积的用户)
import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
with torch.inference_mode():
inductor_configs = {}
if torch.cuda.is_available():
device = "cuda"
inductor_configs["max_autotune"] = True
else:
device = "cpu"
model = model.to(device=device)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
exported_program = torch.export.export(
model,
example_inputs,
)
path = torch._inductor.aoti_compile_and_package(
exported_program,
package_path=os.path.join(os.getcwd(), "resnet18.pt2"),
inductor_configs=inductor_configs
)
aoti_compile_and_package()
的结果是一个名为“resnet18.pt2”的工件,可以在 Python 和 C++中加载和执行。
该工件本身包含由 AOTInductor 生成的代码,例如生成的 C++运行器文件、从 C++文件编译的共享库以及 CUDA 二进制文件,即 cubin 文件(如果针对 CUDA 进行优化)。
结构上,该工件是一个结构化的 .zip
文件,其规格如下:
我们可以使用以下命令来检查工件内容:
$ unzip -l resnet18.pt2
Archive: resnet18.pt2
Length Date Time Name
--------- ---------- ----- ----
1 01-08-2025 16:40 version
3 01-08-2025 16:40 archive_format
10088 01-08-2025 16:40 data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin
17160 01-08-2025 16:40 data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin
16616 01-08-2025 16:40 data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin
17776 01-08-2025 16:40 data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin
10856 01-08-2025 16:40 data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin
14608 01-08-2025 16:40 data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin
11376 01-08-2025 16:40 data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin
10984 01-08-2025 16:40 data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin
14736 01-08-2025 16:40 data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin
11376 01-08-2025 16:40 data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin
11624 01-08-2025 16:40 data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin
15632 01-08-2025 16:40 data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin
25472 01-08-2025 16:40 data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin
139389 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp
27 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json
47195424 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so
--------- -------
47523148 18 files
Python 中的模型推理 ¶
要在 Python 中加载和运行该工件,我们可以使用 torch._inductor.aoti_load_package()
。
import os
import torch
import torch._inductor
model_path = os.path.join(os.getcwd(), "resnet18.pt2")
compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
with torch.inference_mode():
output = compiled_model(example_inputs)
何时使用与 Python 运行时一起的 AOTInductor ¶
使用 AOTInductor 与 Python 运行时的主要原因主要有两个:
torch._inductor.aoti_compile_and_package
生成一个单一的序列化工件。这对于模型版本控制和跟踪模型性能随时间变化非常有用。由于
torch.compile()
是一个即时编译器,第一次编译会有一定的预热成本。您的部署需要考虑第一次推理所花费的编译时间。使用 AOTInductor,编译是在使用torch.export.export
和torch._inductor.aoti_compile_and_package
之前完成的。在部署时,在加载模型后运行推理不会产生任何额外的成本。
下文部分展示了使用 AOTInductor 在第一次推理中实现的加速效果。
我们定义了一个效用函数 timed
来衡量推理所需的时间。
import time
def timed(fn):
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for accurate
# measurement on CUDA enabled devices.
if torch.cuda.is_available():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
else:
start = time.time()
result = fn()
if torch.cuda.is_available():
end.record()
torch.cuda.synchronize()
else:
end = time.time()
# Measure time taken to execute the function in miliseconds
if torch.cuda.is_available():
duration = start.elapsed_time(end)
else:
duration = (end - start) * 1000
return result, duration
让我们来测量使用 AOTInductor 进行首次推理所需的时间
torch._dynamo.reset()
model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
with torch.inference_mode():
_, time_taken = timed(lambda: model(example_inputs))
print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
让我们来测量使用 torch.compile
进行首次推理所需的时间
torch._dynamo.reset()
model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()
model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)
with torch.inference_mode():
_, time_taken = timed(lambda: model(example_inputs))
print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
我们看到,与 torch.compile
相比,使用 AOTInductor 在首次推理时间上有了显著加速
结论 §
在本教程中,我们学习了如何通过编译和加载预训练的 ResNet18
模型来有效地使用 AOTInductor 进行 Python 运行时。这个过程展示了在 Python 环境中生成编译工件并运行的实用应用。我们还探讨了在使用 AOTInductor 进行模型部署时的优势,尤其是在首次推理时间上的加速。
脚本总运行时间:(0 分钟 0.000 秒)