由 PyTorch 团队

可复现性是许多研究领域,包括基于机器学习技术的领域的基本要求。然而,许多机器学习出版物要么不可复现,要么难以复现。随着研究出版物数量的持续增长,包括现在 arXiv 上托管的上万篇论文以及会议提交数量达到历史最高水平,研究可复现性比以往任何时候都更重要。虽然许多这些出版物都附带了代码以及训练好的模型,这很有帮助,但仍需要用户自己解决许多步骤。

我们很高兴地宣布 PyTorch Hub 的可用性,这是一个简单的 API 和工作流程,为提高机器学习研究可复现性提供了基本构建块。PyTorch Hub 包含一个专门为促进研究可复现性和启用新研究而设计的预训练模型存储库。它还内置了对 Colab 的支持,与 Papers With Code 集成,并包含广泛的模型,包括分类和分割、生成、Transformer 等。

[所有者] 发布模型

PyTorch Hub 支持通过添加简单的 hubconf.py 文件将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。这提供了要支持哪些模型以及运行模型所需的依赖项的枚举。示例可以在 torchvision、huggingface-bert 和 gan-model-zoo 仓库中找到。

让我们看看最简单的情况: torchvisionhubconf.py :

# Optional list of dependencies required by the package
dependencies = ['torch']

from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2

torchvision 中,模型具有以下属性:

  • 每个模型文件都可以独立运行和执行
  • 他们不需要除了 PyTorch(编码为 hubconf.py 作为 dependencies['torch'] )之外的任何包
  • 因为模型创建时即可无缝工作,所以他们不需要单独的入口点

最小化包依赖性可以减少用户加载您的模型进行即时实验的摩擦

一个更复杂的例子是 HuggingFace 的 BERT 模型。以下是它们的 hubconf.py

dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']

from hubconfs.bert_hubconf import (
    bertTokenizer,
    bertModel,
    bertForNextSentencePrediction,
    bertForPreTraining,
    bertForMaskedLM,
    bertForSequenceClassification,
    bertForMultipleChoice,
    bertForQuestionAnswering,
    bertForTokenClassification
)

每个模型都需要创建一个入口点。以下是一个指定 bertForMaskedLM 模型入口点的代码片段,该入口点返回预训练模型权重。

def bertForMaskedLM(*args, **kwargs):
    """
    BertForMaskedLM includes the BertModel Transformer followed by the
    pre-trained masked language modeling head.
    Example:
      ...
    """
    model = BertForMaskedLM.from_pretrained(*args, **kwargs)
    return model

这些入口点可以作为复杂模型工厂的包装器。它们可以提供干净一致的帮助文档字符串,具有支持下载预训练权重(例如通过 pretrained=True )的逻辑,或者具有可视化等特定于 hub 的功能。

在设置 hubconf.py 之后,您可以根据此处提供的模板提交一个 pull request。我们的目标是整理高质量、易于复制的、最大限度地有益的研究模型。因此,我们可能会与您合作完善您的 pull request,在某些情况下拒绝一些低质量的模型发布。一旦我们接受您的 pull request,您的模型将很快出现在 Pytorch hub 网页上,供所有用户探索。

[用户] 工作流程

作为用户,PyTorch Hub 允许您遵循几个简单的步骤来完成一些事情,例如:1)探索可用的模型;2)加载模型;以及 3)了解任何给定模型可用的方法。让我们通过一些示例来了解每个步骤。

探索可用的入口点。

用户可以使用 torch.hub.list() API 列出存储库中所有可用的入口点。

>>> torch.hub.list('pytorch/vision')
>>>
['alexnet',
'deeplabv3_resnet101',
'densenet121',
...
'vgg16',
'vgg16_bn',
'vgg19',
 'vgg19_bn']

注意,PyTorch Hub 还允许辅助入口点(除了预训练模型之外),例如 bertTokenizer 用于 BERT 模型的预处理,以使用户的工作流程更加顺畅。

加载模型

现在我们知道了 Hub 中可用的模型,用户可以使用 torch.hub.load() API 加载模型入口点。这只需要一条命令,无需安装 wheel。此外, torch.hub.help() API 还可以提供有关如何实例化模型的有用信息。

print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)

同样,仓库所有者通常会希望持续添加错误修复或性能改进。PyTorch Hub 让用户通过调用以下命令轻松获取最新更新:

model = torch.hub.load(..., force_reload=True)

我们相信这将有助于减轻仓库所有者重复发布包的负担,并使他们能够更多地专注于研究。这也确保了作为用户,您将获得最新的可用模型。

相反,稳定性对用户来说很重要。因此,一些模型所有者从特定的分支或标签为用户提供服务,而不是 master 分支,以确保代码的稳定性。例如, pytorch_GAN_zoohub 分支为用户提供服务:

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)

注意,传递给 hub.load()*args**kwargs 用于实例化一个模型。在上面的例子中, pretrained=TrueuseGPU=False 被赋予模型的入口点。

探索已加载的模型

一旦您从 PyTorch Hub 加载了模型,您可以使用以下工作流程来查找支持的方法以及更好地了解运行它所需的参数。

查看模型所有可用方法的命令。让我们看看 bertForMaskedLM 的可用方法。

>>> dir(model)
>>>
['forward'
...
'to'
'state_dict',
]

它提供了查看加载的模型运行所需的参数的视图。

>>> help(model.forward)
>>>
Help on method forward in module pytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...

仔细查看 BERT 和 DeepLabV3 页面,您可以在那里看到加载后如何使用这些模型。

其他探索方式

PyTorch Hub 中可用的模型也支持 Colab,并且可以直接在 Papers With Code 上链接,您只需一键即可开始使用。以下是一个良好的入门示例(如下所示)。

其他资源:

  • PyTorch Hub API 文档可在此处找到。
  • 请在此处提交模型以在 PyTorch Hub 上发布。
  • 前往 https://maskerprc.github.io/hub 了解更多可用的模型。
  • 在 paperswithcode.com 寻找即将推出的更多模型。

非常感谢 HuggingFace、PapersWithCode 团队、fast.ai 和 Nvidia 以及 Morgane Riviere(FAIR Paris)以及许多其他人为帮助启动这项工作做出的贡献!!

喝彩!

PyTorch 团队

常见问题解答:

Q: 如果我们想贡献一个已经在 Hub 中但可能我的模型精度更高的模型,我应该仍然贡献吗?

A: 当然!!Hub 的下一步是实施点赞/踩系统,以展示最好的模型。

Q: 谁托管 PyTorch Hub 中的模型权重?

A: 作为贡献者,您负责托管模型权重。您可以在您喜欢的云存储中托管您的模型,或者如果它符合限制,可以在 GitHub 上托管。如果您无力托管权重,请通过在 Hub 仓库中打开一个问题与我们联系。

问:如果我的模型是在私有数据上训练的,我是否还应贡献这个模型?

答:不!PyTorch Hub 以开放研究为中心,这包括使用开放数据集来训练这些模型。如果提交的拉取请求是针对专有模型的,我们将礼貌地要求您重新提交一个在公开和可用内容上训练的模型。

问:我下载的模型保存在哪里?

答:我们遵循 XDG 基础目录规范,并遵守关于缓存文件和目录的通用标准。

位置的使用顺序为:

  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)
  • 如果设置了环境变量 TORCH_HOME
  • 如果设置了环境变量 XDG_CACHE_HOME
  • ~/.cache/torch/hub