• 文档 >
  • 序列化语义
快捷键

序列化语义 ¶

本笔记描述了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++中加载。

保存和加载张量

torch.save()torch.load() 让您轻松保存和加载张量:

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch 文件通常使用‘.pt’或‘.pth’扩展名。

torch.save()torch.load() 默认使用 Python 的 pickle,因此您也可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分保存:

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果数据结构是 pickle-able 的,包含 PyTorch 张量的自定义数据结构也可以保存。

保存和加载张量会保留视图

保存张量会保留它们的视图关系:

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅张量视图。

当 PyTorch 保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个可能在未来发生变化的实现细节,但它通常可以节省空间,并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入‘tensors.pt’文件。

然而,在某些情况下,保存当前的存储对象可能是多余的,并创建过大而不切实际的文件。在下面的代码片段中,写入文件的存储比保存的张量大得多:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

将小张量中仅有的五个值保存到‘small.pt’中,而是将与其共享存储的大张量中的 999 个值保存和加载。

当保存元素数量少于其存储对象的张量时,可以通过首先克隆张量来减少保存的文件大小。克隆张量会产生一个新的张量,它包含一个新的存储对象,仅包含张量中的值:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

然而,由于克隆的张量彼此独立,它们没有原始张量所具有的任何视图关系。如果保存小于其存储对象大小的张量时文件大小和视图关系都很重要,那么在保存之前必须小心构建新的张量,以最大限度地减少其存储对象的大小,同时仍然具有所需的视图关系。

保存和加载 torch.nn.Modules

参见:教程:保存和加载模块

在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久性缓冲区:

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

为了兼容性原因,建议不要直接保存模块,而是只保存其状态字典。Python 模块甚至有一个函数, load_state_dict() ,可以从状态字典中恢复其状态:

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

注意,首先使用 torch.load() 从文件中加载状态字典,然后使用 load_state_dict() 恢复状态。

即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用此模式:

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

序列化文件格式为 torch.save

自 PyTorch 1.6.0 版本起, torch.save 默认返回一个未压缩的 ZIP64 存档,除非用户设置 _use_new_zipfile_serialization=False

在此存档中,文件按如下顺序排列

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
如下所示条目:
  • data.pkl 是传递给 torch.save 的对象经过序列化后的结果,排除了它包含的 torch.Storage 对象

  • byteorder 在保存时包含一个带有 sys.byteorder 的字符串(“小”或“大”)

  • data/ 包含对象中的所有存储,每个存储都是一个单独的文件

  • 保存时包含一个版本号,该版本号可以在加载时使用

当保存时,PyTorch 将确保每个文件的本地文件头填充到 64 字节的倍数偏移量,确保每个文件的偏移量是 64 字节对齐的。

注意

某些设备上的张量(如 XLA)被序列化为 pickle 的 numpy 数组。因此,它们的存储不会序列化。在这些情况下, data/ 可能不存在于检查点中。

torch.loadweights_only=True

从版本 2.6 开始,如果没有传递 pickle_module 参数, torch.load 将使用 weights_only=True

torch.load() 文档所述, weights_only=Truetorch.load 中使用的反序列化器限制为仅执行用于 state_dicts 的普通 torch.Tensors 以及一些其他原始类型的函数/构建类,此外,与 pickle 模块提供的默认 Unpickler 不同, weights_only Unpickler 不允许在反序列化过程中动态导入任何内容。

如上所述,当使用 torch.save 时,保存模块的 state_dict 是一个最佳实践。如果加载包含 nn.Module 的老旧检查点,我们建议 weights_only=False 。当加载包含张量子类的检查点时,可能会出现需要允许列表的函数/类,下面将提供更多详细信息。

如果 weights_only Unpickler 遇到在 pickle 文件中默认不允许列表的函数或类,您应该看到一个可操作的错误,如下所示。

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

请按照错误信息中的步骤操作,并且只允许信任的函数或类。

要获取检查点中尚未允许列表的所有全局(函数/类),可以使用 torch.serialization.get_unsafe_globals_in_checkpoint() ,它将返回形式为 {__module__}.{__name__} 的字符串列表。如果您信任这些函数/类,可以按照错误信息导入它们并允许列表,通过 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals

要访问用户允许列表的函数/类列表,可以使用 torch.serialization.get_safe_globals() ,要清除当前列表请查看 torch.serialization.clear_safe_globals()

故障排除 weights_only

获取不安全的全局变量

注意, torch.serialization.get_unsafe_globals_in_checkpoint() 会静态分析检查点,一些类型可能在反序列化过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。例如, dtypes 在 numpy 中就是这样。在 numpy < 1.25 允许列出 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的所有函数/类之后,你可能会看到如下错误

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 允许列出

numpy >=1.25 你会看到

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) . 允许列表。

环境变量 ¶

有两个环境变量会影响 torch.load 的行为。如果没有访问 torch.load 调用点的权限,这些变量可能很有用。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 将覆盖所有 torch.load 调用点以使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 将使 torch.load 调用站点仅在 weights_only=False 未作为参数传递时使用 weights_only

序列化 torch.nn.Modules 并在 C++ 中加载它们

参见:教程:在 C++ 中加载 TorchScript 模型

ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load() 加载。这种序列化编码了所有模块的方法、子模块、参数和属性,并允许序列化的程序在 C++ 中(即不使用 Python)加载。

torch.jit.save()torch.save() 之间的区别可能并不立即明显。 torch.save() 用于保存 Python 对象,通过 pickle 实现。这在原型设计、研究和训练中特别有用。另一方面, torch.jit.save() 将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。当保存和加载 C++ 模块,或者使用 C++ 运行在 Python 中训练的模块时,这非常有用,这是部署 PyTorch 模型时的常见做法。

在 Python 中脚本化、序列化和加载模块:

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟踪的模块也可以使用 torch.jit.save() 保存,但有一个前提,即只有跟踪的代码路径被序列化。以下示例演示了这一点:

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

如上模块中有一个不会被跟踪输入触发的 if 语句,因此它不属于跟踪模块,也没有与其一起序列化。然而,脚本化的模块包含这个 if 语句,并且与其一起序列化。有关脚本化和跟踪的更多信息,请参阅 TorchScript 文档。

最后,在 C++中加载模块:

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

请参阅 PyTorch C++ API 文档,了解如何在 C++中使用 PyTorch 模块的详细信息。

在 PyTorch 版本之间保存和加载 ScriptModules

PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。较旧版本的 PyTorch 可能不支持较新的模块,而较新版本可能已删除或修改了旧行为。这些更改已在 PyTorch 的发布说明中明确描述,依赖于已更改的功能的模块可能需要更新才能继续正常工作。在以下有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。

torch.div 执行整数除法

在 PyTorch 1.5 及之前版本中,当给定两个整数输入时, torch.div() 会执行向下取整除法:

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

然而,在 PyTorch 1.7 版本中, torch.div() 将始终对其输入执行真正的除法,就像 Python 3 中的除法一样:

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行为在序列化的 ScriptModules 中得以保留。也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules,即使在较新版本的 PyTorch 中加载,当给定两个整数输入时, torch.div() 仍会执行向下取整除法。然而,在 PyTorch 1.6 及以后版本序列化的 ScriptModules 使用 torch.div() ,在早期版本的 PyTorch 中无法加载,因为这些早期版本不理解新的行为。

torch.full 总是推断为 float 类型

在 PyTorch 1.5 及之前版本中, torch.full() 总是返回一个 float 矩阵,无论给定的填充值是什么:

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

然而,在 PyTorch 1.7 中, torch.full() 将从填充值推断返回矩阵的数据类型:

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行为在序列化的 ScriptModules 中得到保留。也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 仍然默认返回 float 矩阵,即使给定的填充值是 bool 或整数。然而,使用 torch.full() 并在 PyTorch 1.6 及以后版本序列化的 ScriptModules 在早期版本中无法加载,因为这些早期版本不理解新的行为。

实用函数 ¶

以下实用函数与序列化相关:

torch.serialization.register_package(priority, tagger, deserializer)[source][source]

注册用于标记和反序列化存储对象的调用函数,并关联优先级。标记在保存时将设备与存储对象关联,而在加载时将存储对象移动到适当的设备。 taggerdeserializer 将按照它们提供的 priority 的顺序执行,直到标记器/反序列化器返回一个非 None 的值。

要覆盖全局注册表中设备的反序列化行为,可以注册一个优先级高于现有标记器的标记器。

此功能还可以用于为新设备注册标记器和反序列化器。

参数:
  • 优先级(int)- 表示标记器和反序列化器关联的优先级,其中值越小表示优先级越高。

  • 标记器(Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 接收存储对象并返回其标记设备字符串或 None 的可调用对象。

  • 解析器(Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 接收存储对象和设备字符串的函数,返回适当设备上的存储对象或 None。

返回值:

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[source][source]

获取是否为每个记录计算并写入 crc32。

默认为 True

返回类型:

布尔型

torch.serialization.set_crc32_options(compute_crc32)[source][source]

设置是否计算并写入每条记录的 crc32。

注意

将此设置为 False 可能会导致解压缩 torch.save 输出失败或警告,因为 CRC32 已损坏。然而 torch.load 仍然能够加载文件。

参数:

compute_crc32 (bool) – 设置 crc32 计算标志

torch.serialization.get_default_load_endianness()[source][source]

获取加载文件的回退字节序

如果保存的检查点中没有字节序标记,则使用此字节序作为回退。默认为“本地”字节序。

返回值:

Optional[LoadEndianness]

返回类型:

默认加载字节序

torch.serialization.set_default_load_endianness(endianness)[source][source]

设置加载文件时的默认字节顺序

如果保存的检查点中没有字节顺序标记,则使用此字节顺序作为后备。默认情况下,它是“本地”字节顺序。

参数:

端序 - 新的回退字节序

torch.serialization.get_default_mmap_options()[source][source]

获取 torch.load() 的默认 mmap 选项 mmap=True

默认为 mmap.MAP_PRIVATE

返回值:

整型

返回类型:

默认的 mmap 选项

torch.serialization.set_default_mmap_options(flags)[来源][来源] ¶

上下文管理器或函数,用于设置 torch.load() 的默认 mmap 选项为 mmap=True 到 flags。

目前仅支持 mmap.MAP_PRIVATEmmap.MAP_SHARED 。如需添加其他选项,请在此处提交问题。

注意

此功能目前不支持 Windows 系统。

参数:

标志(整数)- mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source][source]

标记给定的全局变量为 weights_only 安全加载。例如,添加到此列表中的函数可以在反序列化期间调用,类可以实例化并设置状态。

列表中的每个项目可以是函数/类,也可以是形式为(函数/类,字符串)的元组,其中字符串是函数/类的完整路径。

在序列化格式中,每个函数都通过其完整路径 {__module__}.{__qualname__} 进行标识。当调用此 API 时,您可以提供此完整路径,它应与检查点中的路径匹配,否则将使用默认的 {fn.__module__}.{fn.__qualname__}

参数:

safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 标记为安全的全局变量列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source][source]

清除安全全局变量列表。

torch.serialization.get_safe_globals()[source][source]

返回用户添加的安全全局变量列表。

返回类型:

list[Union[Callable, tuple[Callable, str]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]

返回一个字符串列表,表示在 torch.save 对象中不安全的函数/类的列表 weights_only

对于给定的函数或类 f ,相应的字符串形式为 {f.__module__}.{f.__name__}

此函数将返回检查点中任何标记为安全的 weights_only (通过 add_safe_globals()safe_globals 上下文或默认情况下通过 torch 允许列表)之外的 GLOBALs。

注意

此函数将静态反汇编检查点中的 pickle 文件。这意味着在反序列化过程中动态推送到栈上的任何类将不会包含在输出中。

参数:

f (Union[str, PathLike[str], IO[bytes]]) – 文件对象或包含通过 torch.save 保存的检查点对象的字符串。

返回值:

检查点中 pickle GLOBALs 的字符串列表,这些 GLOBALs 未允许列表 weights_only

返回类型:

list[str]

class torch.serialization.safe_globals(safe_globals)[source][source]

上下文管理器,将某些全局变量添加为 weights_only 加载时的安全变量。

参数:

safe_globals (列表[Union[Callable, tuple[Callable, str]]]) – 仅加载权重时的全局变量列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]

跳过为 torch.save / torch.load 调用写入/读取存储字节的上下文管理器。

对于保存路径,存储仍然会被保存,但它们通常写入的空间将是空白空间。存储字节可以在单独的步骤中填充。

对于加载路径,张量将按照检查点进行加载,但它们的存储不会填充数据。

警告

skip_data 上下文管理器是一个早期原型,可能会发生变化。

参数:

materialize_fake_tensors (bool) – 是否在保存时实例化 FakeTensors。在加载路径中不执行任何操作。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

配置 ¶

torch.utils.serialization.config 提供一个全局配置,可以控制 torch.savetorch.load 的行为。

torch.utils.serialization.config.save 包含控制 torch.save 行为的选项。

  • 是否计算并写入 zip 文件的校验和(默认: True )。见 set_crc32_options()

  • use_pinned_memory_for_d2h : 当传递给 torch.save 时,对于位于加速器上的存储,是否将存储移动到 torch.save 内的固定内存或可分页内存(默认: False (即可分页))。

  • storage_alignment : 检查点期间存储的对齐方式,以字节为单位。(默认 64

torch.utils.serialization.config.load 包含控制 torch.load 行为的选项。

  • mmap : 请参阅 mmap 参数的文档。此配置将设置 torch.load() 的行为,如果它尚未显式传递给 torch.load 调用(默认值: False )。

  • endianness : 请参阅 set_default_load_endianness() 。(默认值: torch.serialization.LoadEndianness.NATIVE )

  • mmap_flags : 请参阅 set_default_mmap_options 。(默认值: MAP_PRIVATE )

  • calculate_storage_offsets : 如果此配置设置为 True ,则在使用 torch.load(mmap=True) 时将计算存储的偏移量,而不是通过随机读取读取。这可以最小化随机读取,当文件通过网络加载时可能很有帮助。(默认值: False


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源