• 文档 >
  • torch.optim
快捷键

torch.optim

torch.optim 是一个实现各种优化算法的包。

最常用的方法已经得到支持,并且接口足够通用,因此未来也可以轻松集成更复杂的方法。

如何使用优化器

要使用 torch.optim ,您需要构建一个优化器对象,该对象将保存当前状态并根据计算出的梯度更新参数。

构建它

构建一个 Optimizer ,您需要提供一个包含参数的可迭代对象(所有参数应为 Parameter )或命名参数((str, Parameter )的元组)以进行优化。然后,您可以指定优化器特定的选项,如学习率、权重衰减等。

示例:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

命名参数示例:

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)

每个参数选项

Optimizer 也支持指定每个参数的选项。为此,而不是传递一个 Variable 的可迭代对象,传递一个 dict 的可迭代对象。每个都将定义一个单独的参数组,并且应该包含一个 params 键,其中包含属于该组的参数列表。其他键应与优化器接受的键匹配,并将用作此组的优化选项。

例如,当想要指定每层的学习率时,这非常有用:

optim.SGD([
                {'params': model.base.parameters(), 'lr': 1e-2},
                {'params': model.classifier.parameters()}
            ], lr=1e-3, momentum=0.9)

optim.SGD([
                {'params': model.base.named_parameters(), 'lr': 1e-2},
                {'params': model.classifier.named_parameters()}
            ], lr=1e-3, momentum=0.9)

这意味着 model.base 的参数将使用 1e-2 的学习率,而 model.classifier 的参数将坚持使用 1e-3 的默认学习率。最后,所有参数将使用 0.9 的动量。

注意

您仍然可以通过关键字参数传递选项。它们将被用作默认值,在未覆盖它们的组中。当您只想更改单个选项,同时保持所有其他参数组之间的一致性时,这很有用。

还请考虑以下与参数不同惩罚相关的示例。请记住, parameters() 返回一个包含所有可学习参数的可迭代对象,包括偏差和其他可能需要不同惩罚的参数。为了解决这个问题,可以为每个参数组指定单独的惩罚权重:

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]

optim.SGD([
                {'params': others},
                {'params': bias_params, 'weight_decay': 0}
            ], weight_decay=1e-2, lr=1e-2)

以这种方式,偏差项与非偏差项被隔离,并专门为偏差项设置 weight_decay0 ,以避免对该组进行任何惩罚。

进行优化步骤

所有优化器都实现了一个 step() 方法,用于更新参数。它可以有两种使用方式:

optimizer.step()

这是一个大多数优化器都支持的简化版本。在计算梯度后,例如使用 backward() ,可以调用此函数。

示例:

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)

一些优化算法,如共轭梯度法和 LBFGS,需要多次重新评估函数,因此您必须传递一个闭包,以便它们重新计算您的模型。该闭包应清除梯度、计算损失并返回它。

示例:

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

基类 ¶

类 torch.optim.Optimizer(params, defaults)[source][source] ¶

所有优化器的基类。

警告

参数需要指定为具有确定性顺序且在运行之间一致的集合。不满足这些属性的示例对象包括集合和遍历字典值的迭代器。

参数:
  • params (可迭代对象) – 一个可迭代的 torch.Tensor s 或 dict s。指定了哪些张量应该被优化。

  • defaults (字典[str, Any]) – (字典):包含优化选项默认值的字典(当参数组未指定时使用)。

Optimizer.add_param_group

Optimizer s 的 param_groups 中添加一个参数组。

Optimizer.load_state_dict

加载优化器状态。

Optimizer.register_load_state_dict_pre_hook

注册一个在调用 load_state_dict() 之前被调用的 load_state_dict 预钩子。它应该具有以下签名::.

Optimizer.register_load_state_dict_post_hook

注册一个在调用 load_state_dict() 之后被调用的 load_state_dict 后钩子。它应该具有以下签名::.

Optimizer.state_dict

返回优化器的状态作为 dict

Optimizer.register_state_dict_pre_hook

注册一个在调用 state_dict() 之前被调用的状态字典预钩子。

Optimizer.register_state_dict_post_hook

注册一个状态字典后钩子,该钩子将在调用 state_dict() 之后被调用。

Optimizer.step

执行单个优化步骤以更新参数。

Optimizer.register_step_pre_hook

注册一个优化步骤前钩子,该钩子将在优化步骤之前被调用。

Optimizer.register_step_post_hook

注册一个优化步骤后钩子,该钩子将在优化步骤之后被调用。

Optimizer.zero_grad

重置所有优化 torch.Tensor 的梯度。

算法 ¶

Adadelta

实现 Adadelta 算法。

Adafactor

实现 Adafactor 算法。

Adagrad

实现 Adagrad 算法。

Adam

实现 Adam 算法。

AdamW

实现 AdamW 算法,其中权重衰减不会累积在动量中,也不会累积在方差中。

SparseAdam

SparseAdam 实现了一个适用于稀疏梯度的 Adam 算法的掩码版本。

Adamax

实现 Adamax 算法(基于无穷范数的 Adam 算法变体)。

ASGD

实现平均随机梯度下降法。

LBFGS

实现 L-BFGS 算法。

NAdam

实现 NAdam 算法。

RAdam

实现了 RAdam 算法。

RMSprop

实现了 RMSprop 算法。

Rprop

实现了鲁棒反向传播算法。

SGD

实现了随机梯度下降(可选带动量)。

我们的大多数算法都有针对性能、可读性和/或通用性优化的各种实现,因此如果没有用户指定特定的实现,我们将默认使用当前设备上通常最快的实现。

我们有三种主要的实现类别:for 循环、foreach(多张量)和融合。最直接的实施方式是使用 for 循环遍历参数,进行大量计算。for 循环通常比我们的 foreach 实现慢,因为 foreach 将参数组合成多张量,一次性运行大量计算,从而节省了许多顺序内核调用。我们的一些优化器甚至有更快的融合实现,它们将大量计算融合成一个内核。我们可以将 foreach 实现视为水平融合,将融合实现视为在此基础上垂直融合。

通常情况下,三种实现的性能排序为融合 > foreach > for-loop。因此,在适用的情况下,我们默认使用 foreach 而不是 for-loop。适用意味着 foreach 实现可用,用户未指定任何特定实现的 kwargs(例如,融合、foreach、可微分),并且所有张量都是原生的。请注意,虽然融合应该比 foreach 更快,但这些实现较新,我们希望在全局范围内切换之前给予它们更多的测试时间。下表总结了每个实现的稳定性状态,欢迎您尝试使用它们!

下表显示了每个算法的可用和默认实现:

算法

默认

有 foreach 吗?

有融合吗?

Adadelta

foreach

是的

Adafactor

for 循环

Adagrad

foreach

是的

是的(仅 CPU)

Adam

foreach

是的

是的

AdamW

foreach

是的

是的

SparseAdam

for 循环

Adamax

foreach

是的

ASGD

foreach

是的

LBFGS

for 循环

NAdam

foreach

是的

RAdam

foreach

是的

RMSprop

foreach

是的

Rprop

foreach

是的

SGD

foreach

是的

是的

以下表格显示了融合实现的稳定性状态:

算法

CPU

CUDA

MPS

Adadelta

不支持

不支持

不支持

Adafactor

不支持

不支持

不支持

Adagrad

测试版

不支持

不支持

Adam

测试版

稳定版

测试版

AdamW

测试版

稳定

测试版

SparseAdam

不支持

不支持

不支持

Adamax

不支持

不支持

不支持

ASGD

不支持

不支持

不支持

LBFGS

不支持

不支持

不支持

NAdam

不支持

不支持

不支持

RAdam

不支持

不支持

不支持

RMSprop

不支持

不支持

不支持

Rprop

不支持

不支持

不支持

SGD

测试版

测试版

测试版

如何调整学习率 ¶

torch.optim.lr_scheduler.LRScheduler 提供了多种方法根据训练轮数调整学习率。 torch.optim.lr_scheduler.ReduceLROnPlateau 允许根据某些验证度量动态降低学习率。

应在学习率调度器更新后应用学习率调度;例如,你应该这样编写你的代码:

示例:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

大多数学习率调度器可以连续调用(也称为调度器链式)。结果是每个调度器都依次应用于前一个调度器获得的学习率。

示例:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

在文档的许多地方,我们将使用以下模板来引用调度器算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

警告

在 PyTorch 1.1.0 之前,学习率调度器应该在优化器更新之前调用;1.1.0 的更改以 BC 兼容性破坏的方式改变了这种行为。如果您在优化器更新(调用 optimizer.step() )之前使用学习率调度器(调用 scheduler.step() ),这将跳过学习率调度表的第一项。如果您在升级到 PyTorch 1.1.0 后无法重现结果,请检查您是否在错误的时间调用了 scheduler.step()

lr_scheduler.LRScheduler

在优化过程中调整学习率。

lr_scheduler.LambdaLR

设置初始学习率。

lr_scheduler.MultiplicativeLR

将每个参数组的学习率乘以指定函数中给出的因子。

lr_scheduler.StepLR

每经过 step_size 个 epoch,将每个参数组的学习率衰减为 gamma。

lr_scheduler.MultiStepLR

当 epoch 数达到里程碑之一时,将每个参数组的学习率衰减为 gamma。

lr_scheduler.ConstantLR

将每个参数组的学习率乘以一个小的常数因子。

lr_scheduler.LinearLR

通过线性改变小的乘性因子来衰减每个参数组的学习率。

lr_scheduler.ExponentialLR

每个 epoch 衰减每个参数组的学习率,衰减因子为伽马。

lr_scheduler.PolynomialLR

使用给定总迭代次数中的多项式函数衰减每个参数组的学习率。

lr_scheduler.CosineAnnealingLR

使用余弦退火计划设置每个参数组的学习率。

lr_scheduler.ChainedScheduler

连接一系列学习率调度器。

lr_scheduler.SequentialLR

包含在优化过程中按顺序调用的调度器列表。

lr_scheduler.ReduceLROnPlateau

当指标停止改进时降低学习率。

lr_scheduler.CyclicLR

根据周期性学习率策略(CLR)设置每个参数组的 学习率。

lr_scheduler.OneCycleLR

根据 1 周期学习率策略设置每个参数组的学习率。

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火计划设置每个参数组的学习率。

如何使用命名参数来加载优化器状态字典。

函数 load_state_dict() 存储从加载的状态字典中可选的 param_names 内容(如果存在)。然而,加载优化器状态的过程不受影响,因为参数的顺序很重要,以保持兼容性(以防顺序不同)。为了利用从加载的状态字典中加载的参数名称,需要根据期望的行为实现自定义的 register_load_state_dict_pre_hook

这在模型架构发生变化,但权重和优化器状态需要保持不变的情况下很有用。以下示例演示了如何实现这种自定义。

示例:

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

假设 model 实现了一个专家(MoE),我们想要复制它并为两个专家恢复训练,这两个专家的初始化方式与 fc 层相同。对于以下 model2 ,我们创建了两个与 fc 相同的层,并通过从 model 加载模型权重和优化器状态到 fc1fc2 中(并相应调整它们)来恢复训练:

class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

要加载 optimizer2 的状态字典,使其与先前优化器状态的状态字典相同,以便 fc1fc2 都将使用 fc 优化器状态的副本进行初始化(以从 fc 恢复每个层的训练),我们可以使用以下钩子:

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

这确保了在模型加载过程中将使用适应了正确状态的 model2 层的 state_dict。请注意,此代码专门为此示例设计(例如,假设单个参数组),其他情况可能需要不同的适应。

以下示例展示了在模型结构发生变化时如何处理加载的 state dict 中缺失的参数。 Model_bypass 添加了一个新的 bypass 层,该层在原始的 Model1 中不存在。为了继续训练,使用自定义的 adapt_state_dict_missing_param 钩子来适应优化器的 state_dict ,确保现有参数正确映射,而缺失的参数(如旁路层)保持不变(如本例中初始化的那样)。这种方法即使在模型发生变化的情况下也能实现优化器状态的平滑加载和恢复。新的旁路层将从零开始训练:

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x


model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

作为第三个示例,而不是根据参数的顺序(默认方法)加载状态,此钩子可以用来根据参数的名称加载:

def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

权重平均(SWA 和 EMA)

实现了随机权重平均(SWA)和指数移动平均(EMA),①实现了 SWA 学习率调度器,②是一个在训练结束时更新 SWA/EMA 批归一化统计信息的实用函数。

随机权重平均(SWA)在《平均权重可导致更宽的优化范围和更好的泛化》一文中被提出。

指数移动平均(EMA)是一种广为人知的通过减少所需的权重更新次数来减少训练时间的技巧。它是 Polyak 平均的一种变体,但使用指数权重而不是迭代中的等权重。

构建平均模型

AveragedModel 类用于计算 SWA 或 EMA 模型的权重。

您可以通过以下命令创建 SWA 平均模型:

>>> averaged_model = AveragedModel(model)

EMA 模型通过指定 multi_avg_fn 参数来构建:

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

衰减参数介于 0 和 1 之间,用于控制平均参数的衰减速度。如果不提供 torch.optim.swa_utils.get_ema_multi_avg_fn() ,则默认为 0.999。衰减值应接近 1.0,因为较小的值可能导致优化收敛问题。

返回一个函数,该函数应用以下 EMA 方程来调整权重:

Wt+1EMA=αWtEMA+(1α)WtmodelW^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

其中 alpha 是 EMA 衰减系数。

在这里,模型 model 可以是任意 torch.nn.Module 对象。 averaged_model 将跟踪 model 参数的运行平均值。要更新这些平均值,应在 optimizer.step()之后使用 update_parameters() 函数:

>>> averaged_model.update_parameters(model)

对于 SWA 和 EMA,这个调用通常在 optimizer step() 之后进行。在 SWA 的情况下,通常在训练开始的前几个步骤中跳过这一步。

自定义平均策略

默认情况下, torch.optim.swa_utils.AveragedModel 计算您提供的参数的运行等权平均,但您也可以使用自定义平均函数通过 avg_fnmulti_avg_fn 参数:

  • avg_fn 允许定义一个函数,该函数作用于每个参数元组(平均参数,模型参数),并应返回新的平均参数。

  • multi_avg_fn 允许定义更高效的运算,作用于参数列表元组(平均参数列表,模型参数列表),同时进行,例如使用 torch._foreach* 函数。此函数必须就地更新平均参数。

在以下示例中 ema_model 使用 avg_fn 参数计算指数移动平均:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在以下示例中 ema_model 使用更高效的 multi_avg_fn 参数计算指数移动平均:

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))

SWA 学习率调度方案

通常,在 SWA 中,学习率被设置为高常数值。 SWALR 是一个将学习率逐渐调整到固定值并保持恒定的学习率调度器。例如,以下代码创建了一个调度器,该调度器在每个参数组中将学习率从初始值线性调整到 0.05,需要 5 个 epoch:

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您也可以通过设置 anneal_strategy="cos" 来使用余弦退火而不是线性退火。

注意批量归一化(Batch Normalization)。

update_bn() 是一个实用函数,允许在给定数据加载器 loader 的末尾计算 SWA 模型的批量归一化统计信息:

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model 应用于数据加载器中的每个元素,并计算模型中每个批归一化层的激活统计信息。

警告

update_bn() 假设数据加载器 loader 中的每个批次要么是张量,要么是张量的列表,其中第一个元素是网络 swa_model 应该应用到的张量。如果您的数据加载器结构不同,您可以通过在每个数据集元素上执行 swa_model 的前向传递来更新 swa_model 的批量归一化统计信息。

将所有内容整合:SWA ¶

在下面的示例中, swa_model 是累积权重平均值的 SWA 模型。我们对该模型进行总共 300 个周期的训练,并在第 160 个周期切换到 SWA 学习率计划,并开始收集 SWA 参数平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

将所有内容整合:EMA ¶

在下面的示例中, ema_model 是累积以 0.999 的衰减率指数衰减的权重平均值的 EMA 模型。我们对该模型进行总共 300 个周期的训练,并立即开始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

swa_utils.AveragedModel

实现了随机权重平均(SWA)和指数移动平均(EMA)的平均模型。

swa_utils.SWALR

将每个参数组的学习率调整到固定值。

torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[源代码][源代码]

获取应用于多个参数的指数移动平均(EMA)函数。

torch.optim.swa_utils.update_bn(loader, model, device=None)[source][source]

更新模型中的 BatchNorm 运行均值、运行方差缓冲区。

它在 loader 中的数据上执行一次遍历,以估计模型中 BatchNorm 层的激活统计信息。

参数:
  • loader (torch.utils.data.DataLoader) – 用于计算激活统计数据的数据集加载器。每个数据批次应为一个张量,或者是一个列表/元组,其中第一个元素是一个包含数据的张量。

  • model (torch.nn.Module) – 我们要更新 BatchNorm 统计信息的模型。

  • device (torch.device, 可选) – 如果设置,数据将在传递给 model 之前被转移到 device

示例

>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)

注意

The update_bn 工具假定 loader 中的每个数据批次都是一个张量,或者是一个列表或元组中的张量;在后一种情况下,假定应对数据批次对应的列表或元组的第一个元素调用 model.forward()


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源