torch.cuda.amp.grad_scaler 源代码
从 typing_extensions 导入 deprecated,导入 torch # 我们需要保留这个未使用的导入以保持向后兼容性,从 torch.amp.grad_scaler 导入 OptState # noqa: F401 __all__ = ["GradScaler"]
[文档]class GradScaler(torch.amp.GradScaler):
r"""
查看 :class:`torch.amp.GradScaler`.
``torch.cuda.amp.GradScaler(args...)`` 已弃用。请使用 ``torch.amp.GradScaler("cuda", args...)`` 代替。
"""
@deprecated(
"torch.cuda.amp.GradScaler(args...) 已弃用。"
"请使用 `torch.amp.GradScaler('cuda', args...)` 代替。"
category=FutureWarning,
)
def __init__(
self,
init_scale: 浮点数 = 2.0**16,
growth_factor: 浮点数 = 2.0,
backoff_factor: 浮点数 = 0.5,
growth_interval: 整数 = 2000,
enabled: 布尔 = True,
) -> 无,
super().__init__(
"cuda",
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)