由瓦西利斯·弗里尼奥蒂斯

TorchVision 推出了一个向后兼容的 API,用于构建支持多权重的模型。新的 API 允许在相同的模型变体上加载不同的预训练权重,跟踪重要的元数据,如分类标签,并包括使用模型所需的预处理转换。在本篇博客文章中,我们计划回顾原型 API,展示其功能,并突出与现有 API 的关键差异。

我们希望能在最终确定 API 之前得到您的意见。为了收集您的反馈,我们创建了一个 GitHub 问题,您可以在那里发表您的想法、问题和评论。

当前 API 的局限性

TorchVision 目前提供了预训练模型,可以作为迁移学习的起点,或直接用于计算机视觉应用。实例化预训练模型并进行预测的典型方式是:

import torch

from PIL import Image
from torchvision import models as M
from torchvision.transforms import transforms as T


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model
model = M.resnet50(pretrained=True)
model.eval()

# Step 2: Define and initialize the inference transforms
preprocess = T.Compose([
    T.Resize([256, ]),
    T.CenterCrop(224),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)

# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
    category_name = categories[class_id]
print(f"{category_name}: {100 * score}%")

上述方法存在一些局限性:

  1. 无法支持多个预训练权重:由于 pretrained 变量是布尔值,我们只能提供一组权重。当我们显著提高现有模型的准确性并希望将这些改进提供给社区时,这会带来严重的限制。这也阻止了我们为同一模型变体在不同数据集上提供预训练权重。
  2. 缺少推理/预处理转换:用户被迫在使用模型之前定义必要的转换。推理转换通常与训练过程和用于估计权重的数据集相关。这些转换中的任何微小差异(如插值值、调整/裁剪大小等)都可能导致准确性大幅下降或模型无法使用。
  3. 缺乏元数据:与权重相关的关键信息对用户不可用。例如,需要查阅外部来源和文档才能找到类别标签、训练配方、准确性指标等信息。

新的 API 解决了上述限制,并减少了执行标准任务所需的样板代码量。

原型 API 概述

让我们看看如何使用新的 API 实现上述完全相同的结果:

from PIL import Image
from torchvision.prototype import models as PM


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)

# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}*%*")

如我们所见,新的 API 消除了上述限制。让我们详细探索新功能。

多权重支持

新 API 的核心功能之一是能够为同一模型变体定义多个不同的权重。每个模型构建方法(例如 resnet50 )都有一个关联的枚举类(例如 ResNet50_Weights ),该枚举类有与可用的预训练权重数量一样多的条目。此外,每个枚举类都有一个 DEFAULT 别名,该别名指向特定模型的最佳可用权重。这使得希望始终使用最佳可用权重的用户能够这样做,而无需修改他们的代码。

下面是一个使用不同权重的模型初始化示例:

from torchvision.prototype.models import resnet50, ResNet50_Weights

# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# No weights - random initialization
model = resnet50(weights=None)

关联的元数据和预处理转换

每个模型的权重都与元数据相关联。我们存储的信息类型取决于模型的任务(分类、检测、分割等)。典型信息包括训练食谱的链接、插值模式、有关类别和验证指标等信息。这些值可以通过 meta 属性以编程方式访问:

from torchvision.prototype.models import ResNet50_Weights

# Accessing a single record
size = ResNet50_Weights.IMAGENET1K_V2.meta["size"]

# Iterating the items of the meta-data dictionary
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
    print(k, v)

此外,每个权重条目都与必要的预处理转换相关联。所有当前预处理转换都是 JIT-scriptable,可以通过 transforms 属性访问。在使用数据之前,需要初始化/构建这些转换。这种延迟初始化方案是为了确保解决方案的内存效率。转换的输入可以是使用 torchvision.io 读取的 PIL.ImageTensor

from torchvision.prototype.models import ResNet50_Weights

# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()

# Initializing preprocessing at 400x400 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400)

# Once initialized the callable can accept the image data:
# img_preprocessed = preprocess(img)

将权重与其元数据和预处理相关联将提高透明度,改善可重复性,并使记录一组权重是如何产生的变得更加容易。

通过名称获取权重

将权重与其属性(元数据、预处理函数等)直接关联的能力是我们实现使用枚举而不是字符串的原因。尽管如此,对于仅提供权重名称的情况,我们提供了一种将权重名称与其枚举关联的方法:

from torchvision.prototype.models import get_weight

# Weights can be retrieved by name:
assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1
assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2

# Including using the DEFAULT alias:
assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2

弃用

在新的 API 中,之前用于将权重加载到完整模型或其骨干网络的布尔参数 pretrainedpretrained_backbone 已被弃用。当前实现完全向后兼容,因为它无缝地将旧参数映射到新参数。使用旧参数调用新构建器会发出以下弃用警告:

>>> model = torchvision.prototype.models.resnet50(pretrained=True)
 UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning:
Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.

此外,构建器方法需要使用关键字参数。使用位置参数已被弃用,使用它们会发出以下警告:

>>> model = torchvision.prototype.models.resnet50(None)
UserWarning:
Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.

测试新的 API

迁移到新的 API 非常简单。以下两个 API 之间的所有方法调用都是等效的:

# Using pretrained weights:
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
torchvision.models.resnet50(pretrained=True)
torchvision.models.resnet50(True)

# Using no weights:
torchvision.prototype.models.resnet50(weights=None)
torchvision.models.resnet50(pretrained=False)
torchvision.models.resnet50(False)

注意,原型功能仅在 TorchVision 的夜间版本中可用,因此要使用它,您需要按照以下方式安装:

conda install torchvision -c pytorch-nightly

对于安装夜间的替代方法,请查看 PyTorch 下载页面。您还可以从最新的主分支中从源代码安装 TorchVision;有关更多信息,请查看我们的仓库。

使用新 API 访问最先进的模型权重

如果您仍然对尝试新 API 持怀疑态度,这里还有一个理由。我们最近更新了训练配方,并从许多模型中实现了 SOTA 精度。改进的权重可以通过新 API 轻松访问。以下是模型改进的快速概述:

Model 旧 Acc@1 新 Acc@1
EfficientNet B1 78.642 79.838
MobileNetV3 大型 74.042 75.274
量化 ResNet50 75.92 80.282
量化 ResNeXt101 32x8d 78.986 82.574
RegNet X 400mf 72.834 74.864
RegNet X 800mf 75.212 77.522
RegNet X 1.6gf 77.04 79.668
RegNet X 3.2gf 78.364 81.198
RegNet X 8gf 79.344 81.682
RegNet X 16gf 80.058 82.72
RegNet X 32gf 80.622 83.018
RegNet Y 400mf 74.046 75.806
RegNet Y 800mf 76.42 78.838
RegNet Y 1.6gf 77.95 80.882
RegNet Y 3.2gf 78.948 81.984
RegNet Y 8gf 80.032 82.828
RegNet Y 16gf 80.424 82.89
RegNet Y 32gf 80.878 83.366
ResNet50 76.13 80.858
ResNet101 77.374 81.886
ResNet152 78.312 82.284
ResNeXt50 32x4d 77.618 81.198
ResNeXt101 32x8d 79.312 82.834
Wide ResNet50 2 78.468 81.602
Wide ResNet101 2 78.848 82.51

请抽出几分钟时间对新 API 提供您的反馈,这对将其从原型阶段毕业并包含在下一个版本中至关重要。您可以在专门的 Github Issue 上进行操作。我们期待阅读您的评论!