• 教程 >
  • PyTorch 菜谱 >
  • torch.export 流程演示,常见挑战及解决方案
快捷键

torch.export 流程演示,常见挑战及解决方案 ¶

作者:Ankith Gunapal,Jordi Ramon,Marcos Carranza

在《torch.export 教程简介》中,我们学习了如何使用 torch.export。本教程在上一教程的基础上进行了扩展,探讨了使用代码导出流行模型的过程,并针对可能出现的常见挑战进行了说明。

在本教程中,您将学习如何导出以下用例的模型:

  • 视频分类器(MViT)

  • 自动语音识别(OpenAI Whisper-Tiny)

  • 图像描述(BLIP)

  • 可提示图像分割(SAM2)

四个模型均被选中以展示 torch.export 的独特功能,以及实现过程中的一些实际考虑和问题。

前提条件 _

  • PyTorch 2.4 或更高版本

  • torch.export 和 PyTorch 即时推理的基本理解。

关键要求: torch.export 无图断句

torch.compile 通过使用 JIT 将 PyTorch 代码编译成优化的内核来加速 PyTorch 代码。它使用 TorchDynamo 优化给定的模型,创建优化的图,然后使用 API 中指定的后端将其降低到硬件。当 TorchDynamo 遇到不支持的 Python 特性时,会中断计算图,让默认的 Python 解释器处理不支持的代码,然后继续捕获图。这种计算图的中断称为图断。

torch.exporttorch.compile 之间的一个关键区别是 torch.export 不支持图断点,这意味着您要导出的整个模型或模型的一部分必须是一个单独的图。这是因为处理图断点需要用默认的 Python 评估来解释不支持的操作,这与 torch.export 的设计不兼容。您可以通过此链接了解各种 PyTorch 框架之间的差异详情。

您可以使用以下命令在程序中识别图断点:

TORCH_LOGS="graph_breaks" python <file_name>.py

您需要修改程序以消除图断点。一旦解决,您就可以准备导出模型了。PyTorch 对流行的 HuggingFace 和 TIMM 模型进行 torch.compile 的夜间基准测试。这些模型中的大多数都没有图断点。

本配方中的模型没有图断点,但使用 torch.export 时会失败。

视频分类

MViT 是一种基于多尺度视觉变换器的模型类别。该模型使用 Kinetics-400 数据集进行了视频分类训练。使用相关数据集的该模型可以用于游戏中的动作识别。

下面的代码通过 batch_size=2 跟踪导出 MViT,然后检查导出的程序是否可以通过 batch_size=4 运行。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

错误:静态批量大小

    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

默认情况下,导出流程将跟踪程序,假设所有输入形状都是静态的,因此如果您使用与跟踪时不同的输入形状运行程序,您将遇到错误。

解决方案

为了解决这个错误,我们指定输入的第一个维度( batch_size )为动态,指定 batch_size 的预期范围。在下面的修正示例中,我们指定预期的 batch_size 可以从 1 到 16。需要注意的是, min=2 不是一个错误,它在《0/1 特殊化问题》中有解释。关于 torch.export 的动态形状的详细描述可以在导出教程中找到。下面的代码演示了如何导出具有动态批处理大小的 mViT:

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb


model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)

# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

自动语音识别

自动语音识别(ASR)是使用机器学习将语音转录成文本的技术。Whisper 是来自 OpenAI 的基于 Transformer 的编码器-解码器模型,该模型在 680k 小时的标记数据上进行了训练,用于 ASR 和语音翻译。下面的代码尝试导出用于 ASR 的 whisper-tiny 模型。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

错误:使用 TorchDynamo 的严格跟踪

torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'

默认情况下, torch.export 使用 TorchDynamo 进行代码追踪,TorchDynamo 是一个字节码分析引擎,它符号性地分析您的代码并构建一个图。这种分析提供了更强大的安全性保证,但并非所有 Python 代码都受支持。当我们使用默认的严格模式导出 whisper-tiny 模型时,通常会在 Dynamo 中返回一个错误,因为存在不受支持的功能。要了解为什么在 Dynamo 中出错,您可以参考这个 GitHub 问题。

解决方案

为了解决上述错误, torch.export 支持使用 non_strict 模式进行程序追踪,该模式使用 Python 解释器进行追踪,这与 PyTorch 的即时执行类似。唯一的区别是所有 Tensor 对象将被替换为 ProxyTensors ,这将记录它们的所有操作到图中。通过使用 strict=False ,我们能够导出程序。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

图像描述

图像描述是将图像内容用文字表达的任务。在游戏领域,图像描述可以通过动态生成场景中各种游戏对象的文本描述来增强游戏体验,从而为玩家提供更多细节。BLIP 是由 SalesForce Research 发布的一个流行的图像描述模型。下面的代码尝试使用 batch_size=1 导出 BLIP。

import torch
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)

错误:无法修改具有冻结存储的张量

在导出模型时可能会失败,因为模型实现可能包含某些尚未由 torch.export 支持的 Python 操作。其中一些失败可能有解决方案。BLIP 是一个例子,原始模型出错,可以通过在代码中做小的修改来解决。 torch.export 列出了 ExportDB 中支持的和不支持的常见操作,并展示了如何修改您的代码以使其导出兼容。

File "/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

解决方案

在导出失败的地方克隆张量。

text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id

备注

在 PyTorch 2.7 的 nightly 版本中,这个限制已经被放宽。这应该在 PyTorch 2.7 中直接工作。

可提示的图像分割

图像分割是一种计算机视觉技术,它根据像素的特征将数字图像分割成不同的像素组或区域。Segment Anything Model (SAM)引入了可提示图像分割,它根据指示所需对象的提示预测对象掩码。SAM 2 是第一个用于跨图像和视频分割对象的统一模型。SAM2ImagePredictor 类提供了一个简单的接口来提示模型。该模型可以接受点提示和框提示,以及来自前一次预测迭代的掩码作为输入。由于 SAM2 提供了强大的零样本性能,因此它可以用于场景中游戏对象的跟踪。

SAM2ImagePredictor 的 predict 方法中的张量操作发生在_predict 方法中。因此,我们尝试这样导出。

ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)

错误:模型类型不是 torch.nn.Module

torch.export 期望模块类型为 torch.nn.Module 。然而,我们尝试导出的模块是一个类方法。因此出现错误。

Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

解决方案

我们编写了一个辅助类,该类继承自 torch.nn.Module 并在类的 forward 方法中调用 _predict method 。完整的代码可以在以下链接找到。

class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

 model_to_export = ExportHelper()
 ep = torch.export.export(
      model_to_export,
      args=(unnorm_coords, labels, unnorm_box, mask_input,  multimask_output),
      kwargs={"return_logits": return_logits},
      strict=False,
      )

结论 ¶

在本教程中,我们学习了如何通过正确的配置和简单的代码修改,使用 torch.export 来为流行的用例导出模型。一旦能够导出模型,您可以使用 AOTInductor 在服务器上和 ExecuTorch 在边缘设备上降低 ExportedProgram 。要了解更多关于 AOTInductor (AOTI)的信息,请参阅 AOTI 教程。要了解更多关于 ExecuTorch 的信息,请参阅 ExecuTorch 教程。


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源