torch.hub ¬
PyTorch Hub 是一个预训练模型库,旨在促进研究的可重复性。
发布模型
PyTorch Hub 支持通过添加简单的 hubconf.py
文件将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库;
hubconf.py
可以有多个入口点。每个入口点被定义为 Python 函数(例如:您想要发布的预训练模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现入口点?¶
下面是一个代码片段,指定了在 pytorch/vision/hubconf.py
中扩展实现时的 resnet18
模型的入口点。在大多数情况下,在 hubconf.py
中导入正确的函数就足够了。这里我们只是想使用扩展版本作为示例来展示其工作原理。您可以在 pytorch/vision 仓库中查看完整的脚本。
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
变量是加载模型所需的包名列表。注意这可能与训练模型所需的依赖项略有不同。args
和kwargs
将传递给实际的调用函数。函数的文档字符串充当帮助信息。它解释了模型做什么以及允许的位置/关键字参数。强烈建议在此处添加一些示例。
入口函数可以返回模型(nn.module),或者辅助工具以使用户工作流程更顺畅,例如分词器。
以下划线开头的方法被视为辅助函数,不会显示在
torch.hub.list()
中。预训练权重可以存储在 GitHub 仓库中,或者通过
torch.hub.load_state_dict_from_url()
加载。如果小于 2GB,建议将其附加到项目版本中并使用版本中的 URL。在上面的示例中,torchvision.models.resnet.resnet18
处理pretrained
,或者您也可以在入口点定义中放置以下逻辑。
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知 ¶
发布的模型至少应该在分支/标签中。不能是随机提交。
从 Hub 加载模型
Pytorch Hub 提供了方便的 API 来通过 torch.hub.list()
探索 hub 中所有可用的模型,通过 torch.hub.help()
显示文档字符串和示例,并通过 torch.hub.load()
加载预训练模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source][source]¶
列出由
github
指定的 repo 中所有可调用的入口点。- 参数:
github (str) – 格式为“repo_owner/repo_name[:ref]”的字符串,其中 ref 为可选的标签或分支。如果未指定
ref
,则默认分支为main
(如果存在),否则为master
。示例:‘pytorch/vision:0.10’force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。默认为
False
。skip_validation (bool, optional) – 如果
False
,torchhub 将检查由github
参数指定的分支或提交是否正确属于 repo 所有者。这将向 GitHub API 发起请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认的 GitHub 令牌。默认为False
。trust_repo (bool, str or None) –
"check"
,True
,False
或None
。此参数自 v1.12 版本引入,有助于确保用户仅运行他们信任的仓库中的代码。如果
False
,则会提示用户是否信任该仓库。如果
True
,则该仓库将被添加到信任列表中,并无需明确确认即可加载。如果
"check"
,则该仓库将与缓存中信任仓库列表进行核对。如果不在该列表中,行为将回退到trust_repo=False
选项。如果
None
:这将引发警告,提示用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 版本中删除。
默认为
None
,最终将在 v2.0 版本中更改为"check"
。verbose(布尔值,可选)- 如果
False
,则静音关于命中本地缓存的提示。请注意,关于首次下载的提示无法静音。默认为True
。
- 返回:
可用的可调用入口点
- 返回类型:
示例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source][source]¶
显示入口点的文档字符串
model
。- 参数:
github (str) – 格式为 的字符串,其中 ref 为可选的引用(标签或分支)。如果未指定
ref
,则默认分支为main
(如果存在),否则为master
。示例:‘pytorch/vision:0.10’model (str) – 在 repo 的
hubconf.py
中定义的入口点名称的字符串force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认为
False
。skip_validation (bool, 可选) – 如果
False
,torchhub 将检查通过github
参数指定的 ref 是否正确属于仓库所有者。这将向 GitHub API 发起请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认的 GitHub 令牌。默认为False
。trust_repo (bool, str 或 None) –
"check"
,True
,False
或None
。此参数自 v1.12 版本引入,有助于确保用户仅运行来自他们信任的仓库的代码。如果
False
,将提示用户是否信任该仓库。如果
True
,将把该仓库添加到信任列表中,并加载而不需要明确确认。如果
"check"
,将检查该仓库是否在缓存中信任的仓库列表中。如果不在该列表中,行为将回退到trust_repo=False
选项。如果
None
:这将引发警告,邀请用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 版本中删除。
默认为
None
,最终将在 v2.0 版本中更改为"check"
。
示例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source][source]¶
从 github 仓库或本地目录加载模型。
注意:加载模型是典型用例,但此方法也可用于加载其他对象,如分词器、损失函数等。
如果
source
是 ‘github’,则repo_or_dir
应该是形如repo_owner/repo_name[:ref]
的格式,可带有可选的引用(标签或分支)。如果
source
是 ‘local’,则repo_or_dir
应该是本地目录的路径。- 参数:
repo_or_dir (str) – 如果
source
是 ‘github’,则应对应于格式为repo_owner/repo_name[:ref]
的 github 仓库,可带有可选的引用(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。如果source
是 ‘local’,则应为本地目录的路径。model (str) – 仓库/目录中定义的可调用(入口点)的名称。
*args(可选)- 对应于可调用对象
model
的参数。source(str,可选)- 'github'或'local'。指定如何解释
repo_or_dir
。默认为'github'。trust_repo(bool,str 或 None)
"check"
,True
,False
或None
。此参数自 v1.12 版本引入,有助于确保用户仅运行他们信任的仓库中的代码。如果
False
,将提示用户是否信任该仓库。如果
True
,将把该仓库添加到信任列表中,并加载而不需要明确确认。如果
"check"
,将检查该仓库是否在缓存中信任的仓库列表中。如果不在该列表中,行为将回退到trust_repo=False
选项。如果
None
:这将引发警告,邀请用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 版本中删除。
默认为
None
,最终将在 v2.0 版本中更改为"check"
。force_reload(布尔值,可选)- 是否无条件强制重新下载 github 仓库。如果
source = 'local'
,则没有影响。默认为False
。verbose(布尔值,可选)- 如果
False
,则静音关于命中本地缓存的提示。请注意,关于首次下载的提示无法静音。如果source = 'local'
,则没有影响。默认为True
。skip_validation(布尔值,可选)- 如果
False
,torchhub 将检查通过github
参数指定的分支或提交是否正确属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认的 GitHub 令牌。默认为False
。**kwargs(可选)- 对应于可调用对象
model
的 kwargs。
- 返回:
当使用给定的
*args
和**kwargs
调用model
可调用对象时,其输出。
示例
>>> # from a github repo >>> repo = "pytorch/vision" >>> model = torch.hub.load( ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" ... ) >>> # from a local directory >>> path = "/some/local/path/pytorch/vision" >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source][source]¶
将给定 URL 的对象下载到本地路径。
- 参数:
对象下载的 URL(字符串)- 要下载的对象的 URL
dst(字符串)- 对象保存的完整路径,例如
/tmp/temporary_file
hash_prefix(字符串,可选)- 如果不为 None,则下载的 SHA256 文件应以
hash_prefix
开头。默认:Noneprogress(布尔值,可选)- 是否在 stderr 上显示进度条。默认:True
示例
>>> torch.hub.download_url_to_file( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", ... "/tmp/temporary_file", ... )
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source][source]¶
从给定的 URL 加载 Torch 序列化对象。
如果下载的文件是 zip 文件,它将被自动解压缩。
如果对象已存在于 model_dir 中,它将被反序列化并返回。默认值
model_dir
是<hub_dir>/checkpoints
,其中hub_dir
是get_dir()
返回的目录。- 参数:
对象下载的 URL(字符串)
保存对象的目录(字符串,可选)
map_location(可选)- 一个函数或字典,指定如何重映射存储位置(参见 torch.load)
progress(布尔值,可选)- 是否在 stderr 上显示进度条。默认:True
check_hash (bool, 可选) – 如果为 True,URL 的文件名部分应遵循命名约定
filename-<sha256>.ext
,其中<sha256>
是文件内容的 SHA256 哈希的前八个或更多数字。该哈希用于确保名称唯一并验证文件内容。默认:Falsefile_name (str, 可选) – 下载文件的名称。如果没有设置,将使用
url
的文件名。weights_only (bool, 可选) – 如果为 True,则只加载权重,不加载复杂的 pickled 对象。建议用于不可信的来源。见
load()
了解更多详情。
- 返回类型:
dict[str, Any]
示例
>>> state_dict = torch.hub.load_state_dict_from_url( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" ... )
运行加载的模型:
注意在 torch.hub.load()
中使用 *args
和 **kwargs
来实例化一个模型。加载模型后,您如何了解模型可以做什么?建议的工作流程是
查看模型的所有可用方法。
检查运行时
model.foo
需要哪些参数
为了帮助用户在不频繁查阅文档的情况下探索,我们强烈建议仓库所有者使函数帮助信息清晰简洁。包含一个最小的工作示例也很有帮助。
我的下载模型保存在哪里?
位置的使用顺序为
调用
hub.set_dir(<PATH_TO_HUB_DIR>)
如果设置了环境变量
TORCH_HOME
如果设置了环境变量
XDG_CACHE_HOME
~/.cache/torch/hub
- torch.hub.get_dir()[source][source]
获取用于存储下载的模型和权重的 Torch Hub 缓存目录。
如果未调用
set_dir()
,则默认路径为$TORCH_HOME/hub
,其中环境变量$TORCH_HOME
的默认值是$XDG_CACHE_HOME/torch
。$XDG_CACHE_HOME
遵循 Linux 文件系统布局的 X 设计组规范,如果未设置环境变量,则默认值为~/.cache
。- 返回类型:
- torch.hub.set_dir(d)[source][source]
可选地设置用于保存下载的模型和权重的 Torch Hub 目录。
- 参数:
d (str) – 保存下载的模型和权重的本地文件夹路径。
缓存逻辑 ¶
默认情况下,我们不会在加载后清理文件。如果默认目录中已存在,Hub 会使用缓存。 get_dir()
.
用户可以通过调用 hub.load(..., force_reload=True)
强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重,重新初始化新的下载。当同一分支发布更新时,用户可以保持与最新版本的同步。
已知限制: ¶
火炬库通过将包导入就像它已安装一样来工作。在 Python 中导入会引入一些副作用。例如,您可以在 Python 缓存中看到新的条目 sys.modules
和 sys.path_importer_cache
,这是正常的 Python 行为。这也意味着,如果您从不同的仓库导入不同的模型,可能会出现导入错误,如果这些仓库有相同的子包名称(通常是 model
子包)。对于这类导入错误的一个解决方案是,从 sys.modules
字典中删除有问题的子包;更多详细信息可以在 GitHub 问题中找到。
这里值得提一下的一个已知限制:用户不能在同一个 Python 进程中加载同一个仓库的两个不同分支。这就像在 Python 中安装两个同名包一样,这并不好。如果真的尝试这样做,缓存可能会加入进来给你带来惊喜。当然,在单独的进程中加载它们是完全没问题的。