SWALR¶
- class torch.optim.swa_utils.SWALR(optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=- 1)[source][source]¶
将每个参数组的学习率逐步调整到固定值。
此学习率调度器旨在与随机权重平均(SWA)方法(参见 torch.optim.swa_utils.AveragedModel)一起使用。
- 参数:
优化器(torch.optim.Optimizer)- 包装的优化器
swa_lrs(浮点数或列表)- 所有参数组一起或分别针对每个组的学习率值。
annealing_epochs(整数)- 退火阶段的 epoch 数量(默认:10)
annealing_strategy (str) – “余弦”或“线性”;指定退火策略:“余弦”表示余弦退火,“线性”表示线性退火(默认:“余弦”)
last_epoch (int) – 最后一个 epoch 的索引(默认:-1)
可以使用
SWALR
调度器与其他调度器一起使用,在训练后期切换到恒定学习率,如下例所示。示例
>>> loader, optimizer, model = ... >>> lr_lambda = lambda epoch: 0.9 >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, >>> lr_lambda=lr_lambda) >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) >>> swa_start = 160 >>> for i in range(300): >>> for input, target in loader: >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() >>> if i > swa_start: >>> swa_scheduler.step() >>> else: >>> scheduler.step()