• 文档 >
  • torch.onnx >
  • 基于 TorchDynamo 的 ONNX 导出器 >
  • 理解基于 TorchDynamo 的 ONNX 导出器内存使用情况
快捷键

理解基于 TorchDynamo 的 ONNX 导出器内存使用情况 ¶

之前的基于 TorchScript 的 ONNX 导出器会执行模型一次以跟踪其执行,如果模型的内存需求超过了可用的 GPU 内存,这可能会导致它耗尽内存。这个问题已经通过新的基于 TorchDynamo 的 ONNX 导出器得到了解决。

基于 TorchDynamo 的 ONNX 导出器利用 torch.export.export()函数利用 FakeTensorMode 来避免在导出过程中执行实际的张量计算。与基于 TorchScript 的 ONNX 导出器相比,这种方法可以显著降低内存使用。

下面是一个演示基于 TorchScript 和基于 TorchDynamo 的 ONNX 导出器之间内存使用差异的示例。在这个示例中,我们使用 MONAI 中的 HighResNet 模型。在继续之前,请从 PyPI 安装它:

pip install monai

PyTorch 提供了一个用于捕获和可视化内存使用跟踪的工具。我们将使用这个工具来记录两个导出器在导出过程中的内存使用情况,并比较结果。您可以在理解 CUDA 内存使用中找到更多关于这个工具的详细信息。

基于 TorchScript 的导出器

以下代码可以运行以生成快照文件,该文件记录导出过程中分配的 CUDA 内存的状态。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
        model,
        data,
        "torchscript_exporter_highresnet.onnx",
        dynamo=False,
    )

snapshot_name = "torchscript_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print("Export is done.")

打开 pytorch.org/memory_viz,将生成的快照文件拖放到可视化器中。内存使用情况如下:

_images/torch_script_exporter_memory_usage.png

从这张图中,我们可以看到内存使用峰值超过 2.8GB。

基于 TorchDynamo 的导出器

以下代码可以运行以生成快照文件,该文件记录导出过程中分配的 CUDA 内存的状态。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
                        model,
                        data,
                        "test_faketensor.onnx",
                        dynamo=True,
                    )

snapshot_name = f"torchdynamo_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")

打开 pytorch.org/memory_viz,将生成的 pickled 快照文件拖放到可视化器中。内存使用情况如下:

_images/torch_dynamo_exporter_memory_usage.png

从这张图中,我们可以看到内存使用峰值仅为约 45MB。与基于 TorchScript 的导出器的内存使用峰值相比,减少了 98% 的内存使用。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源