快捷键

torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[source][source]

跟踪一个模块并返回一个可执行的 ScriptModule ,该可执行文件将使用即时编译进行优化。

当将模块传递给 torch.jit.trace 时,仅运行并跟踪 forward 方法。使用 trace_module ,您可以指定一个包含方法名称到示例输入的字典以进行跟踪(见下面的 inputs 参数)。

查看更多关于跟踪的信息 torch.jit.trace

参数:
  • mod (torch.nn.Module) – 一个包含指定在 inputs 中的方法名称的 torch.nn.Module 。给定的方法将被编译为一个单一的 ScriptModule 的一部分。

  • inputs (dict) – 一个字典,包含按 mod 中的方法名称索引的样本输入。在跟踪过程中,这些输入将被传递到与输入键对应的方法。 { 'forward' : example_forward_input, 'method2': example_method2_input}

关键字参数:
  • check_trace ( bool , 可选) – 检查相同的输入在经过跟踪的代码运行后是否产生相同的输出。默认: True 。例如,如果你的网络中包含非确定性操作,或者你确信网络即使在检查器失败的情况下也是正确的,你可能想禁用此功能。

  • check_inputs (列表中的字典,可选) – 一系列字典形式的输入参数列表,用于将跟踪结果与预期结果进行比较。每个元组相当于在 inputs 中指定的一组输入参数。为了获得最佳结果,请传入一组代表性的检查输入,这些输入应代表网络预期看到的形状和类型的空间。如果没有指定,则使用原始的 inputs 进行检查。

  • check_tolerance (浮点数,可选) – 在检查程序中使用的浮点数比较容差。这可以在结果因已知原因(如操作融合)而数值上偏离时放松检查器的严格性。

  • example_inputs_is_kwarg ( bool ,可选) – 此参数指示示例输入是否为关键字参数的打包。默认: False

返回值:

一个包含单个方法的对象,该方法包含追踪的代码。当 functorch.nn.Module 时,返回的 ScriptModule 将具有与 func 相同的子模块和参数集。

示例(追踪具有多个方法的模块):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
module = torch.jit.trace_module(n, inputs)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源