torch.optim.Optimizer.state_dict()
- Optimizer.state_dict()[source][source]()
返回优化器的状态作为
dict
。它包含两个条目:
-
state
:一个字典,包含当前的优化状态。其内容
优化器类之间有所不同,但有一些共同特征。例如,状态是按参数保存的,而参数本身不保存。state
是一个字典,将参数 ID 映射到与每个参数对应的状态的字典。
-
- 包含所有参数组的列表
参数组是一个字典。每个参数组包含特定于优化器的元数据,例如学习率和权重衰减,以及组中参数的 ID 列表。如果参数组使用named_parameters()
初始化,则名称内容也将保存在状态字典中。
注意:参数 ID 看起来像索引,但它们只是将状态与 param_group 关联的 ID。在从 state_dict 加载时,优化器将 param_group
params
(整数 ID)和优化器param_groups
(实际的nn.Parameter
s)进行 zip,以便匹配状态,无需额外验证。返回的状态字典可能看起来像这样:
{ 'state': { 0: {'momentum_buffer': tensor(...), ...}, 1: {'momentum_buffer': tensor(...), ...}, 2: {'momentum_buffer': tensor(...), ...}, 3: {'momentum_buffer': tensor(...), ...} }, 'param_groups': [ { 'lr': 0.01, 'weight_decay': 0, ... 'params': [0] 'param_names' ['param0'] (optional) }, { 'lr': 0.001, 'weight_decay': 0.5, ... 'params': [1, 2, 3] 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional) } ] }
- 返回类型:
dict[str, Any]