torch.jit.trace¶
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[source][source]¶
跟踪函数并返回一个可执行的或
ScriptFunction
,该代码将在即时编译中进行优化。追踪非常适合仅操作于
Tensor
\s 和列表、字典以及元组的Tensor
\s 的代码。使用 torch.jit.trace 和 torch.jit.trace_module,可以将现有的模块或 Python 函数转换为 TorchScript
ScriptFunction
或ScriptModule
。您必须提供示例输入,然后我们运行该函数,记录所有张量上执行的操作。独立函数生成的结果为 ScriptFunction。
nn.Module.forward 或 nn.Module 生成的结果为 ScriptModule。
此模块还包含原始模块的所有参数。
警告
仅跟踪正确记录不依赖于数据的函数和模块(例如,在张量中没有条件语句)以及没有任何未跟踪的外部依赖(例如,执行输入/输出或访问全局变量)。跟踪仅记录在给定函数在给定张量上运行时进行的操作。因此,返回的 ScriptModule 将在任何输入上始终运行相同的跟踪图。当您的模块需要根据输入和/或模块状态运行不同的操作集时,这有一些重要的含义。例如,
跟踪不会记录任何控制流,如 if 语句或循环。当这种控制流在您的模块中是常数时,这是可以的,并且通常会内联控制流决策。但有时控制流实际上是模型本身的一部分。例如,循环神经网络是对输入序列长度(可能是动态的)的循环。
在返回的
ScriptModule
中,具有不同行为的training
和eval
模式的操作将始终表现为在跟踪期间的模式,无论 ScriptModule 处于何种模式。
在这种情况下,跟踪可能不合适,而
scripting
是更好的选择。如果您跟踪此类模型,可能会在模型后续调用时静默地得到错误的结果。跟踪器会在执行可能导致错误跟踪的操作时尝试发出警告。- 参数:
func (callable or torch.nn.Module) – 一个 Python 函数或 torch.nn.Module,它将使用 example_inputs 运行。func 的参数和返回值必须是张量或(可能嵌套的)包含张量的元组。当传递 torch.jit.trace 模块时,仅运行并跟踪
forward
方法(有关详细信息,请参阅torch.jit.trace
)。- 关键字参数:
example_inputs (tuple or torch.Tensor or None, optional) – 一个示例输入元组,将在跟踪函数时传递给该函数。默认:
None
。此参数或example_kwarg_inputs
应指定。假设跟踪的操作支持这些类型和形状,则结果跟踪可以与不同类型和形状的输入运行。example_inputs 也可以是一个单独的张量,在这种情况下,它将自动包装在元组中。当值为 None 时,应指定example_kwarg_inputs
。check_trace (
bool
,可选) – 检查通过跟踪代码运行的相同输入是否产生相同的输出。默认:True
。如果您,例如,您的网络包含非确定性操作或您确信网络即使在检查器失败的情况下也是正确的,则可能希望禁用此功能。check_inputs (元组列表,可选) – 应用于检查跟踪与预期结果的元组列表。每个元组相当于在
example_inputs
中指定的输入参数集。为了获得最佳结果,请传递一组代表性检查输入,这些输入代表您期望网络看到的形状和类型输入空间。如果未指定,则使用原始的example_inputs
进行检查。check_tolerance (浮点数,可选) – 在检查程序中使用的浮点数比较容差。这可以在结果因已知原因(例如操作融合)而数值上偏离时放松检查器的严格性。
严格模式(
bool
,可选)- 以严格模式运行跟踪器或否(默认:True
)。只有在你想要跟踪器记录你的可变容器类型(目前为list
/dict
)并且你确定你在问题中使用的容器是一个constant
结构并且不会用作控制流(if,for)条件时,才关闭此选项。example_kwarg_inputs(dict,可选)- 此参数是一组示例输入的关键字参数的打包,在跟踪函数时将传递给该参数。默认:
None
。此参数或example_inputs
必须指定一个。字典将通过跟踪函数的参数名称进行解包。如果字典的键与跟踪函数的参数名称不匹配,将引发运行时异常。
- 返回值:
如果 func 是 nn.Module 或 nn.Module 的
forward
,则跟踪返回一个包含单个forward
方法的ScriptModule
对象,该方法包含跟踪的代码。返回的 ScriptModule 将具有与原始nn.Module
相同的子模块和参数集。如果func
是一个独立函数,则trace
返回 ScriptFunction。
示例(跟踪一个函数):
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
沉浸式翻译(追踪现有模块示例):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)