快捷键

torch.distributed.checkpoint.state_dict 的源代码

# mypy: 允许未类型化定义
导入 contextlib
导入 functools
导入 gc
导入 警告
来自 collections.abc 导入 生成器, 迭代器
来自 dataclasses 导入 asdict, 数据类, 字段
来自 itertools 导入 连接
来自 打字 导入 任何, 可调用, 角色, 不进行类型检查, 可选, 联合

导入 火炬
导入 torch.distributed  dist
导入 torch.nn  神经网络
来自 torch.distributed._shard.sharded_tensor 导入 分片张量
来自 torch.distributed._state_dict_utils 导入 (
    _广播状态字典,
    _分发状态字典,
    _展平状态字典,
    _收集状态字典,
    _将状态字典卸载到 CPU,
    _展开状态字典,
)
来自 torch.distributed.algorithms._checkpoint.checkpoint_wrapper 导入 (
    CHECKPOINT_PREFIX,
)
来自 torch.distributed.fsdp 导入 (
    全优化状态字典配置,
    全状态字典配置,
    全局分片数据并行  FSDP,
    优化状态字典配置,
    分片优化状态字典配置,
    分片状态字典配置,
    状态字典配置,
    状态字典类型,
)
来自 torch.distributed.fsdp._common_utils 导入 (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP 包装模块,
)
来自 torch.distributed.tensor 导入 DTensor
来自 torch.nn.modules 下的 module 模块 导入 不兼容键
来自 torch.nn.parallel 导入 分布式数据并行  DDP
来自 torch.utils._pytree 导入 树图模式


全部 = [
    "FQNS_T",
    基本类型,
    值类型,
    "字典值类型",
    "列表字典值类型",
    "优化器状态类型",
    "状态字典选项",
    "获取模型状态字典",
    "获取优化器状态字典",
    "获取状态字典",
    "设置模型状态字典",
    设置优化器状态字典,
    设置状态字典,
]


_FLAT_PARAM = _flat_param
_PG = 参数组
_PARAMS = 参数
状态 = state

FQNS_T = 集合[字符串]
基本类型 = 联盟[DTensor, 分片张量, 火炬.张量, 整数, float, 字符串]
值类型 = 联盟[
    基本类型, 列表[基本类型], 元组[基本类型], 字典[字符串, 值类型]
]
字典值类型 = 字典[字符串, 值类型]
列字典值类型 = 列表[字典值类型]
优化器状态类型 = 字典[字符串, 联盟[字典值类型, 列字典值类型]]


_patched_state_dict: 集合[可调用] = 集合()


@contextlib.contextmanager
定义 _gc_context():
    is_enabled = 垃圾回收.isenabled()
    垃圾回收.禁用()
    try:
        产生
    最后:
        如果 启用:
            垃圾回收.启用()


[文档]@dataclass 类 StateDictOptions: """ 该数据类指定了 get_state_dict/set_state_dict 将如何工作。 - ``full_state_dict``: 如果设置为 True,则所有张量都将被保存 返回的状态字典将被收集。没有 ShardedTensor 和 DTensor 将返回的状态字典中。 - ``cpu_offload``: 将所有张量卸载到 CPU。为防止 CPU OOM,如果 ``full_state_dict`` 也为真时,只有 rank0 将获得 state_dict,其他所有 rank 将获得空的 state_dict。 state_dict,而其他所有 rank 将获得空的 state_dict。 - ``ignore_frozen_params``: 如果值为 True,返回的 state_dict 将不包含任何冻结参数 -- ``requires_grad`` 为 False。 返回的 state_dict 将不包含任何冻结参数 -- ``requires_grad`` 为 False。 默认值为 False。 - ``keep_submodule_prefixes``(已弃用):当``submodules``不为 None 时,此选项 表示是否从 state_dict 键中保留子模块前缀。 例如,如果子模块是``module.pretrain``并且完整的 FQN 为 参数是 `pretrain.layer1.weight` 的 param。当此选项为 True 时, 返回的状态字典中参数的键将是 `pretrain.layer1.weight`。如果选项为 False, 键将是 `layer1.weight`。 。 注意,如果 `keep_submodule_prefixes` 为 False,则可能存在冲突 因此,`submodules` 中应该只有一个子模块 - `strict`:当调用 `set_state_dict` 时 `strict` 选项 模型加载状态字典 `load_state_dict()` 时的 `strict` 选项 - ``broadcast_from_rank0``: 当选项为 True 时,rank0 应接收 完整的状态字典并将广播状态字典中的张量 逐个将 optim_state_dict 发送到其他进程。其他进程将接收 张量根据模型中的本地分片进行分片 当使用此选项时,必须将 `full_state_dict` 设置为 True。 当前此选项仅支持 DTensor,不支持传统的 ShardedTensor。 """ full_state_dict: 布尔型 = False cpu_offload: 布尔值 = False ignore_frozen_params: 布尔值 = False keep_submodule_prefixes: 布尔值 = True strict: 布尔值 = True broadcast_from_rank0: 布尔值 = False flatten_optimizer_state_dict: 布尔值 = False dsd_fqn_modifiers: 字符串 = "_fqn_modifiers"
@dataclass
状态字典信息(状态字典选项): 完整限定名参数映射: 字典[ 联盟[字符串, 火炬.张量], 联盟[FQNS_T, 火炬.张量], ] = 字段(默认工厂=字典) 共享参数映射: 字典[ 联盟[字符串, 火炬.张量], 联盟[FQNS_T, 火炬.张量], ] = 字段(默认工厂=字典) 子模块前缀: 集合[字符串] = 字段(默认工厂=集合) 处理模型: 布尔类型 = 真实 处理优化: 布尔类型 = 真实 fsdp 上下文: 可调用 = contextlib.空上下文 fsdp 模块: 列表[神经网络.模块] = 字段(默认工厂=列表) @functools.缓存 定义 获取完全限定名( 模型: 神经网络.模块, 名称: 字符串, dsd_fqn_modifiers: 字符串 = _fqn_modifiers, 跳过 ddp 前缀: 布尔类型 = True, 跳过编译器前缀: 布尔类型 = True, ) -> FQNS_T: "" 此 API 用于将参数名称转换为完全限定名称(FQNs)。对于 FSDP 没有 `use_orig_params`,FlatParameter 的名称可以被映射到 多个原始参数。因此,此函数的返回类型 是 `set[str]` 类型。 参数: 模块 (nn.Module):根模型。 名称 (str):名称。 跳过 DDP 的 `module` 前缀 (bool):是否跳过 DDP 的 `module` 前缀。 返回: 基于模型遍历的规范 FQNs。 "文档" 如果存在,则删除检查点前缀。 名称 = 名称.替换(_CHECKPOINT_PREFIX, 输入文本翻译为简体中文为:"") 如果 "." 名称: 返回 {名称} obj_names = 名称.分割(“点”) 完全限定对象名称 = 输入文本为空,请提供需要翻译的文本 当前对象 = 模型 i, 当前对象名称 列举(对象名称列表): 如果 isinstance(当前对象, DDP): 断言 当前对象名称 == "模块" 当前对象 = 当前对象.模块 如果 跳过 ddp 前缀: 完全限定对象名称.追加(当前对象名称) elif isinstance(当前对象, FSDP): 如果 i < len(对象名称) - 1 对象名称[i + 1] == 平坦参数: 前缀 = “点”.连接(完整对象名称) 平坦参数 = getattr(当前对象, _平面参数) 如果 前缀: 前缀 = f"{前缀} 返回 {f"{前缀}{完全限定名}" 完全限定名 平坦参数.完全限定名} 当前对象 = getattr(当前对象, FSDP 包装模块) 如果 当前对象名称 != FSDP 包装模块: 完全限定对象名称.追加(当前对象名称) 当前对象 = getattr(当前对象, 当前对象名称) elif isinstance(当前对象, 火炬._dynamo.评估框架.优化模块): 断言 当前对象名称 == _原模 当前对象 = 当前对象._原模 如果 跳过编译器前缀: 完全限定对象名称.追加(当前对象名称) else: 在某些模块中,_fqn_modifiers 不会显示在 state_dict 键中 # 跳过它们在 fqn 中的使用以确保成功加载统计字典。 如果 有属性(当前对象, dsd_fqn 修饰符): 如果 已删除的 fqn := getattr(当前对象, dsd_fqn 修饰符)().获取( 当前对象名称 ): 如果 有属性(当前对象, 删除的完全限定名): 当前对象 = getattr(当前对象, 删除的完全限定名) 完全限定对象名称.追加(当前对象名称) 如果 当前对象名称 == 神经网络.模块.模块._额外状态键后缀: 如果 i != len(对象名称) - 1: raise 运行时错误(预期 `_extra_state` 是最后一个对象名称) else: 当前对象 = getattr(当前对象, 当前对象名称) 返回 {“点”.连接(完全限定对象名称).替换(_CHECKPOINT_PREFIX, 输入文本翻译为简体中文为:"")} _EXTRA_STATE: 通过 定义 迭代有效模型状态(模型, dsd_fqn_modifiers="_fqn_modifiers"): 访问过的模块: 集合[神经网络.模块] = 集合() 定义 递归(模块: 神经网络.模块, 当前全限定名: 字符串) -> 生成器: 已访问模块.添加(模块) 当前全限定名 = f"{当前全限定名} 如果 当前全限定名 否则 请提供需要翻译的文本 名称, 子模块 模块.命名子项(): 如果 子模块 已访问模块: continue 如果用户在模型中有 state_dict_hooks,他们可以添加 state_dict 键的变化 在 dsd_fqn_modifiers 输入中与 state_dict_hook 的功能对齐 如果 ( 有属性(模块, dsd_fqn_modifiers) 名称 getattr(模块, dsd_fqn_modifiers)().() ): # 跳过_fqn_modifiers 这里因此移除最后添加的`.` 新_fqn = 当前_fqn[-1] else: 新_fqn = f"{当前全限定名}{名称}" yield from 递归(子模块, 新全限定名) 名称, 对象 chain( 模块.命名缓冲区(递归=错误), 模块.命名参数。(递归=错误) ): 如果 名称 模块._非持久缓冲区集合: continue 新的全限定名 = f"{当前全限定名}{名称}" 产生 新全名, 对象 如果 ( getattr(模块., 获取额外状态, 神经网络.模块.获取额外状态) != 神经网络.模块.获取额外状态 ): 新全名 = f"{当前全限定名}{神经网络.模块.模块._额外状态键后缀}" 产生 新全限定名, _额外状态() yield from 递归(模型, 输入文本翻译为简体中文为:"") 定义 验证选项( 模型: 神经网络.模块, 优化: 元组[火炬.优化.优化器, ...], 仅优化: 布尔, *, 子模块: 可选[集合[神经网络.模块]] = , 选项: 可选[状态字典选项] = , ) -> _状态字典信息: "" 验证用户传入的模型和选项,并生成 _状态字典信息。 "文档" 如果 子模块: warnings.warn( "仅获取子模块的模型/优化状态字典已被弃用," "将在 2.5 版本中移除。此功能可以通过手动" "过滤从 get_state_dict 返回的状态字典来实现。", 未来警告, ) 如果 仅优化 优化器: raise 运行时错误( 优化器未传入,但仅优化设置为 True。 ) 选项 = 选项 或者 状态字典选项() 完整限定名参数映射: 字典[ 联盟[字符串, 火炬.张量], 联盟[集合[字符串], 火炬.张量] ] = {} 共享参数映射: 字典[ 联盟[字符串, 火炬.张量], 联盟[集合[字符串], 火炬.张量] ] = {} 名称, 参数 遍历有效模型状态(模型): 如果 isinstance(参数, _EXTRA_STATE): continue fqns = _get_fqns(模型, 名称) 完全限定名 = fqn_param_mapping.获取(参数, ) 如果 完全限定名 : 角色(集合[字符串], fqn_param_mapping[参数]).更新(fqns) 共享参数映射[参数] = 完整限定名参数映射[参数] else: # 我们需要复制,因为_get_fqns 是 lru 缓存的 完整限定名参数映射[参数] = fqns.复制() 完全限定名 fqns: 如果 isinstance(参数, _EXTRA_STATE): 完全状态[完全限定名] = 参数 param_, fqns_ 列表(shared_params_mapping.项目()): 完全限定名 fqns_: 共享参数映射[完全限定名] = 角色(火炬.张量, 参数_) 子模块前缀: 集合[字符串] = 集合() 如果 子模块: 子模块 = 集合(子模块) 名称, 模块 模型.命名模块(): 如果 模块 子模块: continue fqns = _get_fqns(模型, 名称) 断言 len(fqns) == 1, 子模块全限定名应只有一个实例 子模块前缀.更新(f"{完全限定名} 完全限定名 fqns) 如果 选项.从 rank0 广播 选项.完整状态字典: raise ValueError( 当 broadcast_from_rank0 为 True 时,full_state_dict 必须为 True。 ) fsdp 模块 = FSDP.fsdp 模块(模型) 状态字典配置: StateDictConfig 优化状态字典配置: OptimStateDictConfig fsdp 上下文: 可调用 如果 fsdp 模块: # FSDP API 仅在至少存在一个 FSDP 实例时才工作。 如果 选项.全状态字典: 状态字典配置 = 全状态字典配置( 转移至 CPU=选项.CPU 卸载, 仅 rank0=选项.CPU 卸载 ) 优化状态字典配置 = 全优化状态字典配置( 卸载到 CPU=选项.CPU 卸载, rank0 独占=(选项.CPU 卸载 或者 选项.从 rank0 广播), ) 状态字典类型 = 状态字典类型.全状态字典 else: 状态字典配置 = 分片状态字典配置( 转移到 CPU=选项.CPU 卸载, ) 优化状态字典配置 = 分片优化状态字典配置( 转移至 CPU=选项.CPU 卸载, ) 状态字典类型 = 状态字典类型.碎片化状态字典 @contextlib.contextmanager 定义 无警告的 fsdp 状态字典类型( 模块, 状态字典类型, 状态字典配置, 优化状态字典配置, ): warnings.捕获警告(): warnings.过滤警告( "忽略", 消息=FSDP.state_dict_type, 分类=未来警告 ) FSDP.state_dict_type( 模块=模块, 状态字典类型=状态字典类型, 状态字典配置=状态字典配置, 优化状态字典配置=优化状态字典配置, ): 产生 全局状态字典上下文 = functools.偏函数( 无警告的全局状态字典类型, 模块=模型, 状态字典类型=状态字典类型, 状态字典配置=状态字典配置, 优化状态字典配置=优化状态字典配置, ) else: 全局状态字典上下文 = contextlib.空上下文 返回 状态字典信息( **asdict(选项), 完全限定名称映射=完全限定名称映射, 共享参数映射=共享参数映射, 子模块前缀=子模块前缀, fsdp 上下文=fsdp 上下文, fsdp 模块=角色(列表[神经网络.模块], fsdp 模块), 处理模型= 仅优化, 处理优化=(len(优化) > 0), ) 定义 _验证状态字典( 模型状态字典: 字典[字符串, 值类型], 优化状态字典: 优化器状态类型, 信息: _状态字典信息, ) -> : 模块 信息.fsdp 模块: fsdp 状态 = _get_module_fsdp_state_if_fully_sharded_module(模块) 断言 fsdp 状态 , "期望有一个带有 fsdp 模块的 fsdp_state。" # 验证 model_state_dict 和 optim_state_dict 是否有效。此 API # 应该为用户提供一个明确的错误消息以供调试或报告。 如果 ( 信息.处理模型 模型状态字典 信息.子模块前缀 信息.忽略冻结参数 (信息.CPU 卸载 信息.full_state_dict) 信息.严格的 信息.从 rank0 广播 ): raise 运行时错误( "选项表示需要保存或加载模型的状态字典,但模型的状态字典为空。" "或加载,但模型状态字典为空。" frank ={距离.获取排名()=} ) 如果 信息.handle_optim: 如果 ( optim_state_dict (信息.cpu_offload 信息.full_state_dict) ( 信息.broadcast_from_rank0) ): raise 运行时错误( "选项表示需要保存模型的状态字典," f"或者加载,但优化状态字典为空。"{优化状态字典}" ) key 模型状态字典.(): 如果 _FLAT_PARAM : raise 运行时错误( f"{}包含{_FLAT_PARAM}这可能发生,如果模型“ "不是根模块。" ) 定义 _state_dict_fn(对象: 联盟[神经网络.模块, 火炬.优化.优化器], api: 字符串) -> 可调用: 调用 = getattr(对象, api) 如果 呼叫 _修复状态字典: 呼叫 = functools.偏函数(getattr(对象., api), =对象) 返回 呼叫 定义 可能是完整或 CPU 状态字典( state_dict: 字典[字符串, 任何], 信息: 状态字典信息 ) -> 字典[字符串, 任何] 如果 信息.完整状态字典: 仅排名 = ( () 如果 ( 信息.CPU 卸载 或者 火炬.分布式.已初始化()) 否则 (0,) ) 返回 _获取状态字典( state_dict, cpu 卸载=信息.CPU 卸载, 仅排名=仅排名 ) elif 信息.CPU 卸载: 返回 将状态字典卸载到 CPU(state_dict) else: 返回 状态字典 @torch.不梯度() 定义 获取模型状态字典( 模型: 神经网络.模块, 信息: 状态字典信息 ) -> 字典[字符串, 值类型] 如果 信息.处理模型: 返回 {} 信息.fsdp 上下文(): 状态字典 = _状态字典函数(模型, "state_dict")() key 列表(state_dict.()): fqns = 获取完全限定名称(模型, ) 断言 len(fqns) == 1, (, fqns) 完全限定名 = 下一(迭代(fqns)) 如果 完全限定名 != : 仅支持 FSDP、DDP 和 TP,所以只有以下情况 基于封装器的 DDP 和编译器。验证假设 # 是正确的。 定义 核实(, 完全限定名) -> 布尔: 如果 len(完全限定名) >= len(): 返回 fqn_split = 完全限定名.分割(“点”) 关键词分割 = .分割(“点”) 完全限定名索引 = 0 关键词索引, 关键词名称 列举(关键分割): 如果 关键名称 == 完全限定名分割[完全限定名索引] fqn_idx += 1 如果 fqn_idx == len(完全限定名分割): 返回 关键索引 == len(关键分割) - 1 elif 关键名称 (模块, "_orig_mod"): continue else: 返回 返回 真实 如果 核实(, 完全限定名): raise 运行时错误(f"一个意外的键,{}存在。全称是{完全限定名}") state_dict[完全限定名] = state_dict.流行() 如果 信息.子模块前缀: 新状态字典: 字典[字符串, 值类型] = {} # TODO: 使其更快。 完全限定名 state_dict.(): 前缀 信息.子模块前缀: 如果 完全限定名.以...开头(前缀): continue 如果 信息.保留子模块前缀: 新的状态字典[完全限定名] = state_dict[完全限定名] else: 新的全限定名 = 完全限定名[len(前缀) ] 新状态字典[新全称] = state_dict[完全限定名] 状态字典 = 新状态字典 如果 信息.忽略冻结参数: , 参数 模型.命名参数。(): 如果 参数.需要梯度: continue fqns = _获取完全限定名(模型, ) 完全限定名 fqns: state_dict.流行(完全限定名) , p 列表(state_dict.项目()): 如果 火炬.is_tensor(p) p.是否是元数据: state_dict.流行() 返回 可能完整或 cpu 状态字典(state_dict, 信息) @torch.不梯度() 定义 加载模型状态字典( 模型: 神经网络.模块, state_dict: 字典[字符串, 值类型], 信息: _状态字典信息, ) -> _不兼容键: 如果 信息.处理模型 或者 ( 状态字典 信息.rank0 广播): 返回 不兼容键({}, {}) 本地状态字典 = {} , value 遍历有效模型状态(模型, 信息.dsd_fqn_modifiers): fqns = 获取完全限定名(模型, , 信息.dsd_fqn_modifiers) 带前缀的完全限定名 = 获取完全限定名( 模型, , 信息.dsd_fqn_modifiers, 跳过 ddp 前缀=错误, 跳过编译器前缀=错误, ) 完全限定名, 带前缀的完全限定名 zip(fqns, 带前缀的完全限定名集合): 如果 ( 信息.从 rank0 广播 或者 距离.获取排名() == 0 ) 完全限定名 != 带前缀的全限定名称: 加载值 = state_dict.流行(完全限定名, ) 如果 加载值 : 如果 信息.严格的: raise 运行时错误(f"缺少键:"{完全限定名}.") else: state_dict[带前缀的完全限定名] = 加载值 本地状态字典[带前缀的全限定名称] = value 分配 = 如果 信息.rank0 广播 或者 信息.完整状态字典: 设备 = 集合() , value 本地状态字典.项目(): 如果 火炬.is_tensor() .暗淡() > 0: 设备.添加(.设备) 在 lora 状态字典中,可能会有多个设备,其中包含 meta 设备。 将广播/分配中的其他设备取出来,并将 assign 设置为 True。 如果 火炬.设备(元数据) 设备: 设备.删除(火炬.设备(元数据)) 分配 = 真实 如果 len(设备) == 0: 设备.添加(距离.分布式_c10d._get_pg_default_device()) elif len(设备) > 1: raise ValueError("发现多个设备") 如果 信息.从 rank0 广播: _广播状态字典( state_dict, 本地状态字典, 设备=设备.流行(), 严格的=信息.严格的, CPU 卸载=信息.CPU 卸载, ) elif 信息.全状态字典: _分发状态字典(state_dict, 本地状态字典, 设备=设备.流行()) 完全限定名, 本地状态 本地状态字典.项目(): state_dict[完全限定名] = 本地状态 信息.fsdp 上下文(): 返回 角色( _不兼容键, _状态字典函数(模型, "load_state_dict")( state_dict=state_dict, 严格的=信息.严格的, 分配=分配 ), ) 定义 初始化优化状态(优化: 火炬.优化.优化器) -> : "" 初始化优化状态,通过调用带有零梯度的 step()函数。 "文档" 如果 优化.状态: 优化器状态已初始化。 返回 # 一些优化器是无状态的,例如 SGD。这些优化器将 # 不在上述条件下返回。因此,如果存在梯度,我们也应该 如果梯度不存在,则应返回。如果梯度不存在,则不应干扰 SGD,因为梯度和 lr 都为零。 不应干扰 SGD,因为梯度和学习率 lr 都为零。 参数组 优化.参数组: 参数 参数组[_参数] 如果 参数.梯度 : 返回 参数组 优化.参数组: 参数 参数组[_参数] 如果 参数.需要梯度: 参数.梯度 = 火炬.与...相同形状的零(参数) 一些优化器会由于学习率(lr)而更新参数,即使梯度(grads)为零。 在调用 `step()` 时将学习率(lr)设为零。 学习率(lrs) = 输入文本为空,请提供需要翻译的文本 参数组 优化.参数组: 如果 "学习率(lr)" 参数组: lrs.追加(参数组["lr"]\) 参数组["lr"] = ( 火炬.张量(0.0) 如果 isinstance(参数组["lr"], 火炬.张量) 否则 0.0 ) 优化.步长(闭包=) 是否恢复“lr”并不重要,因为我们稍后会恢复检查点。 稍后恢复检查点。 参数组 优化.参数组: 如果 "lr" 参数组: 参数组["lr"] = lrs.流行(0) 优化.零梯度(设置为 None=True) 定义 _flatten_optim_state_dict(state_dict: 优化器状态类型) -> 字典[字符串, 值类型] "" 此 API 将优化器状态字典扁平化,以支持优化器重分片,例如,流水线并行(MPMD)。 没有此 API,原始优化器状态字典看起来如下: 无此 API,原始优化器状态字典将呈现如下: { "state": { "layer1.weight": { "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor }, "layer2.weight": { "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor" }, }, "param_group": [ { "lr": 0.0, "betas": (0.9, 0.95), ... "params": ["layer1.weight", "layer2.weight"] } ] } 使用此 API,优化器状态字典看起来如下: { "state.layer1.weight.step": 10, "state.layer2.weight.step": 10, "state.layer1.weight.exp_avg": SomeTensor "state.layer2.weight.exp_avg": SomeTensor "state.layer1.weight.exp_avg_sq": SomeTensor "state.layer2.weight.exp_avg_sq": SomeTensor "param_group.layer1.weight.lr" : 0.1, "param_group.layer2.weight.lr" : 0.1, "param_group.layer1.weight.betas" : (0.9, 0.95), "param_group.layer2.weight.betas" : (0.9, 0.95), } 注意,如果任何值是容器,例如示例中的 beta,则此 API 不会将其展平。 此 API 不会将容器扁平化。 "文档" 定义 _如果类型不受支持则抛出异常(v): 如果 isinstance(v, (火炬.张量, 整数, float)): raise 不支持的操作异常( 仅支持“Flattening optimizer state_dict” "张量、整型、浮点状态现在。" f"类型是"{类型(v)} ) 返回: 字典[字符串, 值类型] = {} 完全限定名, 状态 角色(字典值类型, state_dict[_状态]).项目(): k, v 角色(字典值类型, 状态).项目(): 如果类型不受支持则抛出异常(v) 返回[f"{状态}.{完全限定名}.{k}"] = v 参数组 角色(列字典值类型, state_dict[_PG)] fqns = 参数组.流行(_PARAMS) 完全限定名 角色(列表[字符串], fqns): k, v 参数组.项目(): 返回[f"{_PG}.{完全限定名}.{k}"] = v 返回 返回 定义 _unflatten_optim_state_dict( 优化: 火炬.优化.优化器, state_dict: 字典[字符串, 值类型], 信息: _状态字典信息, ) -> 优化器状态类型: "" 该 API 将由_flatten_optim_state_dict()生成的 state_dict 展开。 请参阅_flatten_optim_state_dict()的文档字符串以获取更多详细信息。 "文档" 状态: 字典值类型 = {} pg_state: 列表字典值类型 = 输入文本为空,请提供需要翻译的文本 返回_osd: 优化器状态类型 = {_状态: 状态, _页面组: pg 状态} 参数组 优化.参数组: pg 状态.追加({_参数: []}) 参数 参数组[_参数] 完全限定名 信息.完整限定名参数映射[参数] # 如果参数是共享的,则只会使用一个全限定名(FQN)。 # 因此我们需要验证这个 FQN 是否实际上在 # 状态字典中使用。 如果 完全限定名 信息.共享参数映射: 输入参数 = k 参数组.(): 如果 k == _参数: continue 展平键 = f"{_PG}.{完全限定名}.{k}" 如果 展平键 state_dict: 输入参数 = 真实 断开 else: 输入参数 = 真实 如果 输入参数: continue params = 程序状态[-1] [_参数] 断言 isinstance(参数, 列表) # 输入 参数.追加(完全限定名) 如果 参数.需要梯度: continue 状态[完全限定名] = {} 状态名称 优化.状态[参数].(): 角色(字典值类型, 状态[完全限定名])[状态名称] = state_dict[ f"{_状态}.{完全限定名}.{状态名称}" ] first_param_fqn = 角色(列表[字符串], pg 状态[-1] [参数])[0] k 参数组.(): 如果 k == 参数: continue value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] 如果 k pg_state[-1] pg_state[-1] [k] = value elif pg_state[-1] [k] != : raise 运行时错误( "同一参数组中的所有参数应具有 " f"相同的已保存 param_group 值。但{first_param_fqn}.{k} " f"是{}当其他(些)是{pg 状态[-1] [k]} ) 返回 返回 osd @torch.不梯度() 定义 获取优化状态字典( 模型: 神经网络.模块, 优化器: 元组[火炬.优化.优化器, ...], 信息: _状态字典信息, ) -> 优化器状态类型: 如果 信息.处理优化: 返回 {} 优化状态字典: 优化器状态类型 = {_状态: {}, _PG: []} 优化 优化器: 初始化优化状态(优化) osd(对象存储设备) = 状态字典函数(优化, "state_dict")() 如果 信息.fsdp 模块: 信息.fsdp 上下文(): osd = FSDP.优化状态字典(模型, 优化, osd) 我们需要特别处理 FlatParameter FSDP。 FlatParameter FSDP 会将 FQNs 进行转换。 没有简单的方法可以系统地完成这种转换。 我们只能使用字符串替换,而不进行正确性检查。 如果 操作系统守护进程: continue k 列表(操作系统守护进程[状态].()): 如果 "_原修改" k: osd[状态] [k.替换("_orig_mod.", 输入文本翻译为简体中文为:"")] = osd[_STATE].流行(k) g osd[_PG] params = [k.替换("_orig_mod.", 输入文本翻译为简体中文为:"") k g[_PARAMS]] g[_PARAMS] = params else: params = 列表(chain.from_iterable(g[参数] g 优化.参数组)) 全限定名 PID 映射 = 字典(zip(参数, 范围(len(参数)))) FQN PID 映射 = {} , 参数 模型.命名参数。(): fqns = 获取完全限定名称(模型, ) 断言 len(fqns) == 1 完全限定名 = 下一(迭代(fqns)) 如果 参数 参数 PID 映射: continue 进程 ID = 参数 PID 映射[参数] 完全限定名 PID 映射[完全限定名] = 进程 ID 完全限定名 PID 映射[进程 ID] = 完全限定名 key 列表(osd[_STATE].()): 完全限定名 = fqn_pid_mapping[] osd[状态] [完全限定名] = osd[状态].流行() osd[_PG] 群组[_参数] = [完全限定名称 PID 映射[进程 ID] 进程 ID 群组[参数]] 如果 osd: continue 角色(字典值类型, 优化状态字典[状态]).更新(osd[_状态]\) 角色(列字典值类型, 优化状态字典[_PG]).扩展(osd[_PG]\) 如果 信息.展平优化器状态字典: optim_state_dict = 角色( 优化器状态类型, _flatten_optim_state_dict(优化状态字典) ) 返回 可能完整或 cpu 状态字典(优化状态字典, 信息) 定义 _split_optim_state_dict( 模型: 神经网络.模块, 优化: 火炬.优化.优化器, 优化状态字典: 优化器状态类型, 信息: _状态字典信息, ) -> 优化器状态类型: "" 从 `optim_state_dict` 中提取对应的优化状态字典。 将结果优化状态字典返回给 `optim`。 参数: model (nn.Module):根模型。 optim (torch.optim.Optimizer):优化器。 optim_state_dict (Dict[str, ValueType]): 优化状态字典的子集 包含 ``optim`` 的优化状态字典。 info (_StateDictInfo): 状态字典信息。 返回: ``optim`` 的优化状态字典。 "文档" 状态: 字典值类型 = {} 状态: 列字典值类型 = 输入文本为空,请提供需要翻译的文本 返回 osd: 优化器状态类型 = {状态: 状态, _PG: pg 状态} pg 映射: 字典[整数, 整数] = {} 如果 所有( isinstance(k, 整数) k 角色(字典值类型, 优化状态字典[_状态]).() ): 返回 optim_state_dict 参数组 优化.参数组: pg_state.追加({_PARAMS: []}) 参数 参数组[参数] 完全限定名 信息.fqn_param_mapping[参数] 如果 完全限定名 信息.共享参数映射: 参数列表 = 加载的参数组 角色( 列字典值类型, 优化状态字典[_PG] ): 如果 完全限定名 角色(列表[字符串], 加载的参数组[_PARAMS)] 输入参数 = 真实 断开 else: 输入参数 = 真实 如果 输入参数: continue params = pg 状态[-1] [_参数] 断言 isinstance(参数, 列表) 参数.追加(完全限定名) 如果 参数.需要梯度: 状态[完全限定名] = 角色(字典值类型, 优化状态字典[_状态])[完全限定名] 加载的参数组 角色( 列表字典值类型, 优化状态字典[_PG] ): 如果 完全限定名 角色(列表[字符串], 加载的参数组[_PARAMS)] 索引映射[id(已加载的参数组)] = len(返回 OSD[_PG]\) - 1 如果 len(参数组[参数]\) == 0: # 参数组无参数。 返回 = 输入文本为空,请提供需要翻译的文本 已加载的参数组 角色(列字典值类型, 优化状态字典[_PG)] 如果 len(角色(列表[字符串], 加载的参数组[_PARAMS])) == 0: 返回.追加(加载的参数组) 如果 len(返回) != 1: raise ValueError( "存在参数组为零的情况。" "在这种情况下,DSD 仅支持恰好一个参数组。" "该参数组为零。" "但是加载的状态字典有零个或多个参数组。" "参数为零。" ) 如果 len(优化状态字典[_PG]\) != len(优化.参数组): raise ValueError( "当存在一个参数组且参数为零时," "不支持多个优化器。" ) 索引映射[id(已加载的参数组)] = len(返回 OSD[_PG]\) - 1 参数组 角色(列字典值类型, 优化状态字典[_PG)] pg_idx = pg_mapping.获取(id(参数组), -1) 如果 pg_idx == -1: continue , value 参数组.项目(): 如果 key == _参数: continue # TODO: 检查是否存在相同值。 pg_state[索引] [] = value 返回 返回 osd @torch.不梯度() 定义 加载优化状态字典( 模型: 神经网络.模块, 优化器: 元组[火炬.优化.优化器, ...], state_dict: 优化器状态类型, 信息: _状态字典信息, ) -> : 如果 信息.处理优化: 返回 优化 优化器: _初始化优化状态(优化) 如果 state_dict: 如果 _状态 state_dict: optim_state_dict = _split_optim_state_dict( 模型, 优化, state_dict, 信息 ) else: optim_state_dict = _unflatten_optim_state_dict( 优化, 角色(字典[字符串, 值类型], state_dict), 信息 ) else: optim_state_dict = {} 如果 信息.fsdp 模块: # 我们需要特别处理 FlatParameter FSDP # FlatParameter FSDP 会将 FQNs 转换 original_fqn, _ 模型.命名参数。(): fqns = _get_fqns(模型, 原始完全限定名称) 带编译器的完全限定名称 = _get_fqns( 模型, 原始完全限定名称, 跳过编译器前缀= ) 如果 fqns == 编译器中的完全限定名: continue 断言 len(fqns) == 1 完全限定名 = fqns.流行() 编译器中的限定名 = 编译器中的完全限定名.流行() g 优化状态字典[_PG] val = 角色(字典[字符串, 任何], g) params = [ .替换(完全限定名, 编译器中的全限定名称) key val[_PARAMS] ] val[_PARAMS] = params osd 状态 = 角色(字典值类型, 优化状态字典[_状态]\) k 列表(osd 状态.()): 如果 完全限定名 k: osd 状态[k.替换(完全限定名, 带编译器的 FQN)] = osd 状态.流行(k) 信息.fsdp 上下文(): 优化状态字典 = FSDP.加载优化器状态字典( 模型, 优化, 优化器状态字典 ) elif 信息.完整状态字典: 信息.全状态字典 = 本地状态字典 = _获取优化状态字典(模型, (优化,), 信息) 信息.全状态字典 = 真实 设备 = 定义 设备(t): 如果 t.暗淡() > 0: 非局部 设备 如果 设备 : 设备 = t.设备 elif 设备 != t.设备: raise ValueError(设备不匹配) 返回 t _ = 仅树映射(火炬.张量, 设备, 本地状态字典) 断言 设备 展平 osd, osd 映射 = _展平状态字典(优化状态字典) 展平本地 OSD, 本地 OSD 映射 = _展平状态字典(本地状态字典) 如果 信息.从 rank0 广播: _broadcast_state_dict(flatten_osd, flatten_local_osd, 设备=设备) else: _distribute_state_dict(平展 OSD, 展平本地 OSD, 设备=设备) 所列的修改旨在解决“optim”可能存在的问题 与 optim_state_dict 相比,参数不同。这是通过 将局部区域内的微分参数进行整合,可能会导致优化 最终带有额外参数。 优化键 展平 osd.(): 如果 优化键 展平本地 OSD: 断言 优化键 OSD 映射 展平本地 OSD[优化密钥] = 平滑 OSD[优化密钥] 本地 OSD 映射[优化键] = OSD 映射[优化键] 优化状态字典 = _unflatten_state_dict( flatten_local_osd, local_osd_mapping ) pg 优化状态字典[_PG] 如果 _参数 pg: 角色(字典[字符串, 值类型], pg)_参数] = 输入文本为空,请提供需要翻译的文本 # 注意,我们在此处无需将 FQN 转换回参数 ID # 在 optim.param_groups[idx][_参数] 中的顺序与在 # optim_state_dict[_PG][idx][_PARAMS] _状态字典函数(优化, "load_state_dict")(state_dict=优化状态字典)
[文档]def get_model_state_dict( model: nn.Module, *, submodules: Optional[set[nn.Module]] = None, options: 可选[StateDictOptions] = None, ) -> dict[str, ValueType]: """ 返回 ``model`` 的模型状态字典。 查看文档 ``get_state_dict`` 了解详细用法。 参数: model (nn.Module): 模型的 nn.Module。 submodules (已弃用): Optional[set[nn.Module]]: 仅返回模型参数 属于子模块的。 选项(StateDictOptions):控制如何的选项 模型状态字典和优化器状态字典应返回。查看 `StateDictOptions` 的详细信息。 返回值: ``model`` 的 state_dict。 返回类型: typing.Dict[str, ValueType] """ with _gc_context(): info = _verify_options( 模型, (), optim_only=False, submodules=submodules, options=options, ) model_state_dict = _get_model_state_dict(model, info) _verify_state_dict(model_state_dict, {}, info) return model_state_dict
[文档]def get_optimizer_state_dict( 模型: nn.Module, optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], *, submodules: Optional[set[nn.Module]] = None, options: Optional[StateDictOptions] = None, ) -> OptimizerStateType: """ 返回优化器的合并状态字典。 请参阅 `get_state_dict` 了解详细用法。 Args: model (nn.Module): 模型的 nn.Module。 优化器(Union[None, 优化器, Iterable[优化器]]): 用于优化 `model` 的优化器。 子模块(已弃用):Optional[set[nn.Module]]:仅返回模型参数 属于子模块的。 options (StateDictOptions):控制如何返回模型状态字典和优化器状态字典的选项。请参阅`StateDictOptions`获取详细信息。 应返回的状态字典和优化器状态字典的选项。请参阅`StateDictOptions`获取详细信息。 `StateDictOptions`获取详细信息。 返回值: ``optimizers`` 的 state_dict 返回类型: OptimizerStateType """ 在 _gc_context() 环境下: optimizers = ( (optimizers,) 如果 isinstance(optimizers, torch.optim.Optimizer) 否则 tuple(optimizers) ) info = _verify_options( 模型, 优化器, optim_only=True, submodules=submodules, options=options, ) optim_state_dict = _get_optim_state_dict(model, optimizers, info) _verify_state_dict({}, optim_state_dict, info) return optim_state_dict
[文档]定义 get_state_dict( 模型: 神经网络.模块, 优化器: 联盟[火炬.优化.优化器, 迭代器[火炬.优化.优化器]], *, 子模块: 可选[集合[神经网络.模块]] = , 选项: 可选[状态字典选项] = , ) -> 元组[字典[字符串, 值类型], 优化器状态类型] "" 返回模型的状态字典和优化器的状态字典。 `get_state_dict` 可以处理任何由 PyTorch 并行化的模块 FSDP/fully_shard,DDP/replicate,tensor_parallel/parallelize_module,以及任何 这些并行度的组合。`get_state_dict`的主要功能是: 包括:1.) 返回可以与不同数量训练师和/或不同并行度重新分配的模型和优化器状态字典。 2.) 隐藏特定于并行度的状态字典 API。用户无需调用 3.) 隐藏特定于并行度的状态字典 API。用户无需调用 这些 API。 3.) 对结果状态字典进行健全性检查。 结果状态字典的键是规范的全限定名(FQN)。规范的全限定名是指基于参数的 全限定名(FQN)。一个规范的全限定名是指基于参数的 在 nn.Module 层次结构中的位置。更具体地说,是一个规范的全限定名(FQN)。 参数是 ``module.named_parameters()`` 返回的 FQN 当模块未通过任何方式分发时,`module.named_buffers()` 并行性。由于优化器内部使用参数 ID 来表示 一个参数,将会有从参数 ID 到参数值的转换 调用此 API 时使用规范的全限定名称。 `get_state_dict` 也可以处理未并行化的模块。 此类情况下,`get_state_dict` 只执行一个功能——将状态字典转换 将优化器参数 ID 映射到规范的全限定名。 示例: >>> # xdoctest: +SKIP >>> 导入 torch >>> 从 torch.distributed.fsdp 导入 FullyShardedDataParallel 作为 FSDP >>> 从 torch.nn.parallel 导入 DistributedDataParallel 作为 DDP >>> 从 torch.distributed.checkpoint.state_dict 导入 get_state_dict >>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( ... fsdp_model, fsdp_optim ... ) >>> # 如果我们直接调用 ddp_model.state_dict() 和 fsdp_model.state_dict(), >>> 断言将失败。 >>> 断言 ddp_state_dict 等于 fsdp_state_dict >>> 断言 ddp_optim_state 等于 fsdp_optim_state_dict 参数: model (nn.Module): 模型的 nn.Module 优化器(Union[None, Optimizer, Iterable[Optimizer]]): 用于优化 ``模型`` 的优化器。 子模块(已弃用):Optional[set[nn.Module]]: 仅返回属于子模块的模型参数 options (StateDictOptions):控制模型状态字典和优化器状态字典返回方式的选项。 模型状态字典和优化器状态字典应如何返回。请参阅`StateDictOptions`获取详细信息。 包含模型状态字典和优化器状态字典的`Tuple`。 返回: 包含模型状态字典和优化器状态字典的元组。 typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] "文档" _gc_context(): 优化器 = ( (优化器,) 如果 isinstance(优化器, 火炬.优化.优化器) 否则 元组(优化器) ) 信息 = _验证选项( 模型, 优化器, 仅优化=错误, 子模块=子模块, 选项=选项, ) 模型状态字典 = 获取模型状态字典(模型, 信息) 优化器状态字典 = _获取优化器状态字典(模型, 优化器, 信息) 验证状态字典(模型状态字典, 优化状态字典, 信息) 返回 模型状态字典, 优化状态字典
定义 展平模型状态字典( 模型: 神经网络.模块, state_dict: 联盟[字典[神经网络.模块, 字典[字符串, 值类型]], 字典[字符串, 值类型]], ) -> 字典[字符串, 值类型] 如果 state_dict: 返回 {} 如果 isinstance(下一(迭代(state_dict.())), 神经网络.模块): warnings.warn( "将 model_state_dict 传递为 ``Dict[nn.Module, Dict[str, Any]]`` 已弃用,将在 2.5 版本中删除。如果您需要此功能,请预处理 model_state_dict 以实现相同的功能。" "此功能,请预处理 model_state_dict 以实现相同的功能。" "功能,请预处理 model_state_dict 以实现相同的功能。" "功能,请预处理 model_state_dict 以实现相同的功能。", 未来警告, ) 遍历状态字典 = 角色(字典[神经网络.模块, 字典[字符串, 值类型]], state_dict) 新状态字典: 字典[字符串, 值类型] = {} 子模块, 子状态字典 遍历状态字典.项目(): 名称, m 模型.命名模块(): 如果 m != 子模块: continue fqns = 获取完全限定名(模型, 名称) 断言 len(fqns) == 1, 子模块的 FQNs 应只有一个元素 前缀 = f"{下一(迭代(fqns))} 新状态字典.更新( {前缀 + 子 fqn: value 子 fqn, value 子状态字典.项目()} ) 返回 新状态字典 else: 返回 角色(字典[字符串, 值类型], state_dict)
[文档]def 设置模型状态字典( model: nn.Module, model_state_dict: dict[str, ValueType], *, options: Optional[StateDictOptions] = None, ) -> _IncompatibleKeys: """加载模型的状态字典。 ``get_model_state_dict`` 的对应操作,用于将状态字典设置到模型中。详见 ``set_state_dict`` 的详细用法。 模型的状态字典设置方法。请参阅 ``set_state_dict`` 了解详细使用方法。 Args: 模型 (nn.Module):模型的 nn.Module。 model_state_dict: (Dict[str, ValueType]): 加载的模型状态字典。如果 ``model_state_dict`` 的键为 是 nn.Module,键是 `model` 的子模块,值应该是 子模块的状态字典。当加载状态字典时 子模块的前缀将被附加到 state_dict 中。 选项(StateDictOptions):控制如何的选项 模型状态字典和优化器状态字典应加载。查看 `StateDictOptions` 的详细信息。 返回: `NamedTuple` 包含 `missing_keys` 和 `unexpected_keys` 字段: * **missing_keys** 是一个包含缺失键的字符串列表 * **unexpected_keys** 是一个包含意外键的字符串列表 type model_state_dict: typing.Dict[str, ValueType] """ model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( 模型, 模型状态字典 ) with _gc_context(): info = _verify_options(model, (), optim_only=False, options=options) _verify_state_dict(model_state_dict, {}, info) return _load_model_state_dict(model, model_state_dict, info)
[文档]def set_optimizer_state_dict( model: nn.Module, optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], optim_state_dict: OptimizerStateType, *, options: 可选[StateDictOptions] = None, ) -> None: """加载优化器的状态字典。 ``get_optimizer_state_dict`` 的对应操作,用于将状态字典设置到 优化器。请参阅 `set_state_dict` 了解详细用法。 警告:`set_optimizer_state_dict` 只能在调用 `backward()` 或在调用 `step()` 优化器之后调用。否则,优化器状态将无法正确初始化。 优化器调用 `step()` 之后,否则优化器状态将无法正确初始化。 初始化正确。 Args: model (nn.Module): 模型对应的 nn.Module。 optimizers (Union[Optimizer, Iterable[Optimizer]]): 优化 ``model`` 所使用的优化器。 优化状态字典:优化器状态类型 加载优化器的状态字典。 选项(StateDictOptions):控制如何的选项 模型状态字典和优化器状态字典应加载。查看 `StateDictOptions`的详细信息。 返回: 无 type optim_state_dict: typing.OptimizerStateType """ with _gc_context(): 优化器 = ( (优化器,) 如果 isinstance(optimizers, torch.optim.Optimizer) 否则 tuple(optimizers) ) info = _verify_options(model, optimizers, optim_only=True, options=options) _verify_state_dict({}, optim_state_dict, info) _load_optim_state_dict(model, optimizers, optim_state_dict, info)
[文档]定义 set_state_dict( 模型: 神经网络.模块, optimizers: 联盟[火炬.优化.优化器, 迭代器[火炬.优化.优化器]], *, 模型状态字典: 字典[字符串, 值类型], 优化状态字典: 优化器状态类型, 选项: 可选[状态字典选项] = , ) -> _不兼容键: 加载模型状态字典和优化器状态字典。 ``get_state_dict``的对应操作是将状态字典设置到模型中 优化器。给定的 `model_state_dict` 和 `optim_state_dict` 不 必须由 `get_state_dict` 返回,但必须满足以下 requirements: 1) 所有 FQN 必须是按照`get_state_dict`中定义的规范 FQN, 2) 如果一个张量被分片,它必须是 ShardedTensor 或 DTensor, 3) 优化器 state_dict 不能包含参数 ID;键应该是 规范 FQN。 警告:`set_state_dict` 只能在 `backward()` 之前或 `step()` 之后调用 被调用在优化器上。否则,优化器状态将不会被初始化 正确。 参数: 模型(nn.Module):模型中的 nn.Module。 优化器(Union[Optimizer, Iterable[Optimizer]]): 用于优化“model”的优化器。 模型状态字典: (Union[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]) 加载模型的 state_dict。如果``model_state_dict``的键 是 nn.Module,键是`model`的关键子模块,值应该是 子模块的状态字典。在加载状态字典时, 子模块的前缀将被添加到状态字典中。 optim_state_dict: 优化器状态类型: 要加载的优化器状态字典。 options (StateDictOptions):控制如何加载模型状态字典和优化器状态字典的选项。 请参阅`StateDictOptions`获取详细信息。 包含`missing_keys`和`unexpected_keys`字段的`NamedTuple`: 返回: ``NamedTuple``包含`missing_keys`和`unexpected_keys`字段: * **missing_keys** 是一个包含模型状态_dict 中缺失键的字符串列表。 * **unexpected_keys** 是一个包含模型状态_dict 中意外键的字符串列表。 type model_state_dict: typing.Dict[str, ValueType] type optim_state_dict: typing.OptimizerStateType "文档" 模型状态字典: 字典[字符串, 值类型] = _unflatten_model_state_dict( 模型, model_state_dict ) _gc_context(): optimizers = ( (优化器,) 如果 isinstance(优化器, 火炬.优化.优化器) 否则 元组(优化器) ) 信息 = _验证选项( 模型, 优化器, 仅优化= 模型状态字典, 选项=选项 ) _验证状态字典(模型状态字典, 优化状态字典, 信息) _加载优化状态字典(模型, 优化器, 优化状态字典, 信息) 返回 加载模型状态字典(模型, 模型状态字典, 信息)
# TODO: 修正 state_dict 函数签名。 # TODO: 此 API 尚未完全测试。请将其设置为私有。 @no_type_check 定义 _修补模型状态字典( 模型: 神经网络.模块, *, 选项: 可选[状态字典选项] = , ) -> : 修复 `model` 的 `state_dict` 和 `load_state_dict` 属性。 修复 `model` 的 `state_dict` 和 `load_state_dict` 属性,使其成为一个部分函数,用于调用 `get_state_dict` 和 `set_state_dict`。 修复 `model` 的 `state_dict` 和 `load_state_dict` 属性,使其成为一个部分函数,用于调用 `get_state_dict` 和 `set_state_dict`。 示例: 从 torch.distributed.fsdp 导入 FullyShardedDataParallel 作为 FSDP。 从 torch.distributed.checkpoint.state_dict 导入 patch_model_state_dict model = fsdp(model) patch_model_state_dict(model) 参数: model (nn.Module): 模型对应的 nn.Module options (StateDictOptions):控制如何加载模型状态字典和优化器状态字典的选项。 请参阅`StateDictOptions`获取详细信息。 _state_dict_call 返回: "文档" _state_dict_call = functools.偏函数( 获取模型状态字典, 模型=模型, 选项=选项, ) 定义 调用状态字典(): 返回 _调用状态字典() 模型.状态字典 = 调用状态字典 加载状态字典调用 = functools.偏函数( 设置模型状态字典, 模型=模型, 选项=选项, ) 定义 加载状态字典调用(state_dict: 字典[字符串, 任何)] 加载状态字典调用(模型状态字典=state_dict) 模型.加载状态字典 = 加载状态字典调用 _修复状态字典.添加(状态字典调用) _修复状态字典.添加(加载状态字典调用) # TODO: 修正 load_state_dict 函数签名。 # TODO: 此 API 尚未完全测试。将其设为私有。 @no_type_check 定义 _修补优化器状态字典( 模型: 神经网络.模块, *, 优化器: 元组[火炬.优化.优化器, ...], 选项: 可选[状态字典选项] = , ) -> : 修复 ``optimizers`` 的 ``state_dict`` 和 ``load_state_dict`` 属性。 修复 ``optimizers`` 的 ``state_dict`` 和 ``load_state_dict`` 属性。 调用 `get_state_dict` 和 `set_state_dict` 的部分函数。 注意,如果有多个优化器,所有优化器都将被修复。 因此,用户只需要调用其中一个 state_dict() 即可获取完整结果。 示例: 从 torch.distributed.fsdp 导入 FullyShardedDataParallel 作为 FSDP。 从 torch.distributed.checkpoint.state_dict 导入 patch_model_state_dict model = fsdp(model) patch_model_state_dict(model) 参数: model (nn.Module): 模型对应的 nn.Module options (StateDictOptions):控制如何加载模型状态字典和优化器状态字典的选项。 请参阅`StateDictOptions`获取详细信息。 _state_dict_call 返回: "文档" _state_dict_call = functools.偏函数( 获取优化器状态字典, 模型=模型, 优化器=优化器, 选项=选项, ) 定义 调用状态字典(): 返回 _state_dict_call() _load_state_dict_call = functools.偏函数( set_optimizer_state_dict, 模型=模型, optimizers=优化器, 选项=选项, ) 定义 加载状态字典调用(state_dict: 字典[字符串, 任何)] _加载状态字典调用(优化状态字典=state_dict) _修复状态字典.添加(状态字典调用) _修复状态字典.添加(加载状态字典调用) 优化器 = ( (优化器,) 如果 isinstance(优化器, 火炬.优化.优化器) 否则 元组(优化器) ) 优化 优化器: 优化.状态字典 = 调用状态字典 优化.加载状态字典 = 加载状态字典调用

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源