• 文档 >
  • TorchScript
快捷键

TorchScript

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化的模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在没有 Python 依赖的进程中加载。

我们提供工具,可以将模型从纯 Python 程序逐步转换为可以独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中运行。这使得使用 Python 中的熟悉工具在 PyTorch 中训练模型成为可能,然后将模型通过 TorchScript 导出至可能因性能和线程多方面原因而不适合 Python 程序的生产环境中。

想要了解 TorchScript 的入门级介绍,请参阅《TorchScript 简介》教程。

想要查看将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的端到端示例,请参阅《在 C++ 中加载 PyTorch 模型》教程。

创建 TorchScript 代码

script

脚本化函数

trace

跟踪函数并返回一个可执行的或 ScriptFunction ,该代码将在即时编译中进行优化。

script_if_tracing

在跟踪期间首次调用时编译 fn

trace_module

跟踪一个模块并返回一个可执行的 ScriptModule ,该可执行文件将使用即时编译进行优化。

fork

创建一个异步任务,执行 func 并引用此执行的结果值。

wait

强制完成 torch.jit.Future[T]异步任务,返回任务的结果。

ScriptModule

C++ torch::jit::Module 的包装器,具有方法、属性和参数。

ScriptFunction

函数上等同于 ScriptModule ,但表示一个单一函数,没有属性或参数。

freeze

冻结 ScriptModule、内联子模块和属性为常量。

optimize_for_inference

执行一系列优化过程以优化模型用于推理的目的。

enable_onednn_fusion

根据参数启用情况启用或禁用 onednn JIT 融合。

onednn_fusion_enabled

返回是否启用了 ONEDNN JIT 融合。

set_fusion_strategy

设置融合过程中可能发生的特殊化类型和数量。

strict_fusion

如果推理过程中不是所有节点都进行了融合,或者在训练中符号性地进行了微分,则给出错误。

save

为此模块保存离线版本,以便在单独的进程中使用。

load

加载之前使用 torch.jit.save 保存的 ScriptModuleScriptFunction

ignore

此装饰器指示编译器忽略函数或方法,并将其保留为 Python 函数。

unused

此装饰器指示编译器忽略函数或方法,并用抛出异常来替换。

interface

用于注释不同类型的类或模块。

isinstance

在 TorchScript 中提供容器类型细化。

Attribute

此方法是一个透传函数,用于返回值,通常用于向 TorchScript 编译器指示左侧表达式是一个具有特定类型的类实例属性。

annotate

用于在 TorchScript 编译器中指定_the_value 的类型。

混合追踪和脚本化

在许多情况下,将模型转换为 TorchScript,追踪或脚本化都是一个更简单的方法。追踪和脚本化可以组合起来以满足模型某一部分的特定需求。

脚本函数可以调用追踪函数。这在需要使用控制流围绕一个简单的前馈模型时特别有用。例如,序列到序列模型的 beam search 通常是用脚本编写的,但可以调用使用追踪生成的编码器模块。

示例(在脚本中调用追踪函数):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

追踪函数可以调用脚本函数。当模型的大部分只是前馈网络,但模型的一小部分需要一些控制流时,这很有用。由追踪函数调用的脚本函数中的控制流被正确保留。

示例(在跟踪函数中调用脚本函数):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此组合也适用于 nn.Module s,其中可以使用跟踪生成子模块,并可以从脚本模块的方法中调用。

示例(使用跟踪模块):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript 语言 §

TorchScript 是 Python 的静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。请参阅完整的 TorchScript 语言参考以获取详细信息。

内置函数和模块

TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。请参阅 TorchScript 内置函数以获取支持的函数的完整参考。

PyTorch 函数和模块

TorchScript 支持 PyTorch 提供的子集张量和神经网络函数。Tensor 上的大多数方法以及 torch 命名空间中的函数, torch.nn.functional 中的所有函数以及 torch.nn 中的大多数模块都支持在 TorchScript 中使用。

请参阅“TorchScript 不支持的 PyTorch 构造”以获取不支持的 PyTorch 函数和模块的列表。

Python 函数和模块

许多 Python 的内置函数在 TorchScript 中得到支持。 math 模块也得到支持(有关详细信息,请参阅 math 模块),但不支持其他 Python 模块(内置或第三方)。

Python 语言参考比较 §

查看支持的 Python 特性的完整列表,请参阅 Python 语言参考覆盖范围。

调试 §

调试时禁用 JIT §

PYTORCH_JIT

设置环境变量 PYTORCH_JIT=0 将禁用所有脚本和跟踪注解。如果您的 TorchScript 模型中存在难以调试的错误,可以使用此标志强制所有内容使用原生 Python 运行。由于使用此标志禁用了 TorchScript(脚本和跟踪),您可以使用 pdb 之类的工具来调试模型代码。例如:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

使用 pdb 调试此脚本时,除了调用 @torch.jit.script 函数时外都正常。我们可以全局禁用 JIT,这样我们就可以将 @torch.jit.script 函数作为普通 Python 函数调用,而不是编译它。如果上述脚本被 disable_jit_example.py 调用,我们可以这样调用它:

$ PYTORCH_JIT=0 python disable_jit_example.py

我们将能够像普通 Python 函数一样进入 @torch.jit.script 函数。要禁用特定函数的 TorchScript 编译器,请参阅 @torch.jit.ignore

检查代码

TorchScript 为所有 ScriptModule 实例提供了一个代码美化打印器。这个美化打印器将脚本方法的代码解释为有效的 Python 语法。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

一个具有单个 forward 方法的 ScriptModule 将有一个名为 code 的属性,您可以使用它来检查 ScriptModule 的代码。如果 ScriptModule 有多个方法,您需要访问方法本身的 .code 而不是模块。我们可以通过访问 .foo.code 来检查名为 foo 的方法在 ScriptModule 上的代码。上面的示例生成了以下输出:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是 TorchScript 对 forward 方法的代码编译。您可以使用它来确保 TorchScript(跟踪或脚本)已正确捕获您的模型代码。

解读图表

TorchScript 还在代码美化器以下有表示,以 IR 图表的形式。

TorchScript 使用静态单赋值(SSA)中间表示(IR)来表示计算。这种格式的指令包括 ATen(PyTorch 的 C++ 后端)运算符和其他原始运算符,包括循环和条件语句的控制流运算符。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph 遵循“检查代码”部分中描述的 forward 方法查找的相同规则。

上面的示例脚本生成了以下图表:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 指令为例。

  • %rv.1 : Tensor 表示我们将输出分配给一个名为 rv.1 的唯一值,该值的数据类型为 Tensor ,其具体形状未知。

  • aten::zeros 是操作符(相当于 torch.zeros ),输入列表 (%4, %6, %6, %10, %12) 指定了作用域中哪些值应作为输入。内置函数如 aten::zeros 的方案可以在“内置函数”部分找到。

  • # test.py:9:10 是生成此指令的原源文件中的位置。在这种情况下,它是一个名为 test.py 的文件,在第 9 行,第 10 个字符处。

注意,运算符也可以有相关的 blocks ,即 prim::Loopprim::If 运算符。在图形打印输出中,这些运算符被格式化为反映它们的等效源代码形式,以方便调试。

可以像下面这样检查图形,以确认由 ScriptModule 描述的计算是正确的,无论是自动的还是手动的,如下所述。

跟踪器 ¶

追踪边缘情况

存在一些边缘情况,在这些情况下,给定 Python 函数/模块的追踪结果可能无法代表底层代码。这些情况可能包括:

  • 依赖于输入的控制流追踪(例如张量的形状)

  • 张量视图的就地操作追踪(例如赋值左侧的索引)

注意这些情况在未来实际上可能是可追溯的。

自动追踪检查

自动捕捉追踪中的许多错误的一种方法是在 torch.jit.trace() API 上使用 check_inputscheck_inputs 接受一个输入元组的列表,这些输入将用于重新追踪计算并验证结果。例如:

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

给出以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

这条消息向我们表明,我们在首次追踪和用 check_inputs 追踪时,计算结果存在差异。实际上, loop_in_traced_fn 体内的循环依赖于输入 x 的形状,因此当我们尝试具有不同形状的另一个 x 时,追踪结果会有所不同。

在这种情况下,可以使用 torch.jit.script() 来捕获这种数据相关的控制流:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

产生结果如下:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

跟踪器警告 ¶

追踪器会对追踪计算中的几个问题模式产生警告。例如,考虑一个包含对 Tensor 切片(视图)就地赋值的函数的追踪:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

产生多个警告和一个简单地返回输入的图:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码,不使用就地更新,而是使用 torch.cat 就地构建结果张量来修复这个问题:

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题解答 ¶

Q: 我想在 GPU 上训练模型并在 CPU 上进行推理。有哪些最佳实践?

首先将您的模型从 GPU 转换为 CPU,然后保存,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

这是因为追踪器可能会看到在特定设备上的张量创建,所以将已加载的模型进行转换可能会有意外效果。在保存模型之前转换模型可以确保追踪器有正确的设备信息。

Q: 我如何在 ScriptModule 上存储属性?

假设我们有一个这样的模型:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果实例化了 Model ,将会导致编译错误,因为编译器不知道 x 。有 4 种方式让编译器知道 ScriptModule 上的属性:

1. nn.Parameter - 用 nn.Parameter 包裹的值将像在 nn.Module 上那样工作

2. register_buffer - 用 register_buffer 包裹的值将像在 nn.Module 上那样工作。这相当于类型为 Tensor 的属性(参见 4)。

3. 常量 - 将类成员标注为 Final (或将其添加到类定义级别的列表 __constants__ 中)将标记包含的名称为常量。常量将直接保存在模型的代码中。有关详细信息,请参阅内置常量。

4. 属性 - 可以添加支持类型的值作为可变属性。大多数类型可以推断,但某些类型可能需要指定,有关模块属性详细信息,请参阅模块属性。

Q: 我想追踪模块的方法,但我一直得到这个错误:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

此错误通常意味着您正在追踪的方法使用了模块的参数,您传递的是模块的方法而不是模块实例(例如 my_module_instance.forwardmy_module_instance )。

  • 使用模块的方法调用 trace 可以将模块参数(可能需要梯度)捕获为常量。

  • 另一方面,使用模块实例(例如 my_module )调用 trace 将创建一个新的模块,并正确地将参数复制到新模块中,以便在需要时累积梯度。

要追踪模块上的特定方法,请参阅 torch.jit.trace_module

已知问题 §

如果你使用 TorchScript,某些子模块的输入可能会被错误地推断为,即使它们已经被正确注释。标准的解决方案是继承并重新声明输入类型。

附录

迁移到 PyTorch 1.2 递归脚本 API

本节详细介绍了 PyTorch 1.2 中 TorchScript 的变化。如果你是 TorchScript 的新手,可以跳过本节。PyTorch 1.2 对 TorchScript API 有两个主要的变化。

1. torch.jit.script 将尝试递归编译遇到的函数、方法和类。一旦调用 torch.jit.script ,编译将变为“默认”,而不是“可选”。

2. torch.jit.script(nn_module_instance) 现在是创建 ScriptModule 的首选方式,而不是从 torch.jit.ScriptModule 继承。这些更改结合在一起,为将您的 nn.Module 转换为 ScriptModule 提供了一个更简单、更易于使用的 API,以便在非 Python 环境中进行优化和执行。

3. 新的使用方法如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 4. 该模块的 forward 默认编译。从 forward 调用的方法按其在 forward 中使用的顺序懒加载编译。

  • 编译一个未被从 forward 调用的非 forward 方法,请添加 @torch.jit.export

  • 停止编译器编译方法,请添加 @torch.jit.ignore@torch.jit.unused@ignore 保留原样。

  • 方法作为一个对 Python 的调用, @unused 用异常替换它。 @ignored 无法导出; @unused 可以。

  • 大多数属性类型可以推断,因此 torch.jit.Attribute 不是必需的。对于空容器类型,使用 PEP 526 风格的类注解来标注它们的类型。

  • 常量可以用 Final 类注解标记,而不是将成员名称添加到 __constants__ 中。

  • Python 3 类型提示可以用作 torch.jit.annotate 的替代。

由于这些更改,以下项目被认为是过时的,不应出现在新代码中:
  • @torch.jit.script_method 装饰器

  • 继承自 torch.jit.ScriptModule 的类

  • torch.jit.Attribute 包装类

  • __constants__ 数组

  • torch.jit.annotate 函数

模块

警告

在 PyTorch 1.2 中, @torch.jit.ignore 注释的行为发生了变化。在 PyTorch 1.2 之前,使用 @ignore 装饰器可以使函数或方法从导出的代码中可调用。要恢复此功能,请使用 @torch.jit.unused() 。现在 @torch.jit.ignore 等同于 @torch.jit.ignore(drop=False) 。有关详细信息,请参阅 @torch.jit.ignore@torch.jit.unused

当传递给 torch.jit.script 函数时, torch.nn.Module 的数据被复制到 ScriptModule 中,并且 TorchScript 编译器编译该模块。默认情况下编译模块的 forward 。从 forward 调用的方法按其在 forward 中使用的顺序懒加载编译,以及任何 @torch.jit.export 方法。

torch.jit.export(fn)[source][source]

此装饰器表示在 nn.Module 上的方法用作 ScriptModule 的入口点,并且应该被编译。

forward 默认被认为是入口点,因此不需要此装饰器。从 forward 调用的函数和方法按照编译器看到的进行编译,因此也不需要此装饰器。

示例(在方法上使用 @torch.jit.export ):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

函数 ¶

函数变化不大,如果需要,可以装饰为 @torch.jit.ignoretorch.jit.unused

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

警告

TorchScript 类支持目前处于实验阶段。目前它最适合简单的记录类型(例如带有方法的 NamedTuple )。

用户定义的 TorchScript 类中的所有内容默认导出,如果需要,函数可以装饰为 @torch.jit.ignore

属性

TorchScript 编译器需要知道模块属性的类型。大多数类型可以从成员的值中推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格的类注解来指定其类型。如果无法推断类型且未显式注解,则不会将其添加为结果 ScriptModule 的属性

旧 API:

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新 API:

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常量 ¶

类型构造函数 Final 可以用来标记成员为常量。如果成员没有被标记为常量,它们将被复制到结果 ScriptModule 作为属性。使用 Final 如果已知值是固定的,可以提供优化机会,并增加额外的类型安全性。

旧 API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新 API:

from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

变量 §

容器默认具有类型 Tensor 且为非可选(更多信息请参阅默认类型)。之前, torch.jit.annotate 用于告知 TorchScript 编译器类型应该是什么。现在支持 Python 3 风格的类型提示。

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

融合后端

可用于优化 TorchScript 执行的融合后端有几个。CPU 上的默认融合器是 NNC,它可以执行 CPU 和 GPU 的融合。GPU 上的默认融合器是 NVFuser,它支持更广泛的算子,并已证明生成的内核具有更高的吞吐量。有关使用和调试的详细信息,请参阅 NVFuser 文档。

参考


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源