推理模式 ¶
- 类 torch.autograd.grad_mode.inference_mode(mode=True)[source][source] ¶
上下文管理器,用于启用或禁用推理模式。
InferenceMode 是一个类似于
no_grad
的上下文管理器,用于当你确定你的操作将与 autograd 没有交互时(例如,模型训练)。在此模式下运行的代码通过禁用视图跟踪和版本计数器提升获得更好的性能。请注意,与一些其他机制不同,进入 inference_mode 不仅禁用 grad,还禁用了前向模式的 AD。此上下文管理器是线程局部化的;它不会影响其他线程的计算。
同时也是一个装饰器。
注意
推理模式是几个可以局部启用或禁用梯度的机制之一,有关它们如何比较的更多信息,请参阅“局部禁用梯度计算”。
- 参数:
模式(布尔值或函数)- 是否启用或禁用推理模式的布尔标志或用于启用推理模式的 Python 装饰函数
- 示例::
>>> import torch >>> x = torch.ones(1, 2, 3, requires_grad=True) >>> with torch.inference_mode(): ... y = x * x >>> y.requires_grad False >>> y._version Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Inference tensors do not track version counter. >>> @torch.inference_mode() ... def func(x): ... return x * x >>> out = func(x) >>> out.requires_grad False >>> @torch.inference_mode() ... def doubler(x): ... return x * 2 >>> out = doubler(x) >>> out.requires_grad False