torch.jit.freeze¶
- torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[source][source]¶
冻结 ScriptModule,内嵌子模块和属性为常量。
冻结一个
ScriptModule
将会克隆它,并尝试将克隆模块的子模块、参数和属性作为常量内联到 TorchScript IR 图。默认情况下,将保留 forward 以及保存在 preserved_attrs 中指定的属性和方法。此外,任何在保留方法中修改的属性也将被保留。冻结目前仅接受处于 eval 模式的 ScriptModules。
冻结会应用通用的优化,这将加快您的模型运行速度,无论在哪种机器上。为了进一步优化,使用特定服务器的设置,请在冻结后运行 optimize_for_inference。
- 参数:
mod (
ScriptModule
) – 要冻结的模块preserved_attrs (Optional[List[str]]) – 额外保留的属性列表(可选),除了前向方法外。在保留方法中修改的属性也将被保留。
optimize_numerics (bool) – 如果
True
,将运行一组不严格保留数值的优化过程。优化详情请参阅 torch.jit.run_frozen_optimizations。
- 返回值:
Frozen
ScriptModule
。
示例(冻结一个简单的模块带有参数):
def forward(self, input): output = self.weight.mm(input) output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3).eval()) frozen_module = torch.jit.freeze(scripted_module) # parameters have been removed and inlined into the Graph as constants assert len(list(frozen_module.named_parameters())) == 0 # See the compiled graph as Python code print(frozen_module.code)
示例(冻结模块并保留属性)
def forward(self, input): self.modified_tensor += 1 return input + self.modified_tensor scripted_module = torch.jit.script(MyModule2().eval()) frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) # we've manually preserved `version`, so it still exists on the frozen module and can be modified assert frozen_module.version == 1 frozen_module.version = 2 # `modified_tensor` is detected as being mutated in the forward, so freezing preserves # it to retain model semantics assert frozen_module(torch.tensor(1)) == torch.tensor(12) # now that we've run it once, the next result will be incremented by one assert frozen_module(torch.tensor(1)) == torch.tensor(13)
注意
支持冻结子模块属性:frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=[“submodule.version”])
注意
如果您不确定为什么某个属性没有被内联为常量,您可以在 frozen_module.forward.graph 上运行 dump_alias_db 来查看是否冻结检测到该属性正在被修改。
注意
由于冻结将权重设置为常量并移除了模块层次结构,因此 to 和其他 nn.Module 方法来操作设备或数据类型不再起作用。作为解决方案,您可以在 torch.jit.load 中指定 map_location 来重新映射设备,但是模型中可能已经包含了特定于设备的逻辑。