快捷键

火炬脚本语言参考

火炬脚本是一种静态类型的 Python 子集,可以直接编写(使用 @torch.jit.script 装饰器)或通过跟踪从 Python 代码自动生成。在跟踪时,代码会自动转换为这个 Python 子集,通过仅记录张量上的实际操作符并简单地执行和丢弃其他周围的 Python 代码。

直接使用 @torch.jit.script 装饰器编写火炬脚本时,程序员必须只使用火炬脚本支持的 Python 子集。本节将文档化火炬脚本中支持的内容,就像它是独立语言的语言参考一样。本参考中没有提到的 Python 功能不是火炬脚本的一部分。有关可用的 PyTorch 张量方法、模块和函数的完整参考,请参阅内置函数。

作为 Python 的一个子集,任何有效的 TorchScript 函数也都是有效的 Python 函数。这使得可以禁用 TorchScript 并使用标准的 Python 工具(如 pdb )来调试函数。反之则不然:有许多有效的 Python 程序并不是有效的 TorchScript 程序。相反,TorchScript 专注于 Python 中用于表示 PyTorch 神经网络模型所需的功能。

类型

与完整的 Python 语言相比,TorchScript 最大的不同之处在于它只支持一小部分用于表达神经网络模型所需的类型。具体来说,TorchScript 支持:

类型

描述

Tensor

任何 dtype、维度或后端的 PyTorch 张量

Tuple[T0, T1, ..., TN]

包含子类型 T0T1 等(例如 Tuple[Tensor, Tensor]

bool

一个布尔值

int

一个标量整数

float

一个标量浮点数

str

一个字符串

List[T]

所有成员都是类型 T 的列表

Optional[T]

一个值,要么是 None,要么是类型 T

Dict[K, V]

键类型为 K ,值类型为 V 的字典。仅允许 strintfloat 作为键类型。

T

TorchScript 类

E

TorchScript 枚举

NamedTuple[T0, T1, ...]

collections.namedtuple 元组类型

Union[T0, T1, ...]

T0T1 等子类型之一

与 Python 不同,TorchScript 函数中的每个变量都必须具有单个静态类型。这使得优化 TorchScript 函数变得更加容易。

示例(类型不匹配)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

不支持的类型构造

TorchScript 不支持 typing 模块的所有功能和类型。其中一些是更基础的东西,未来不太可能添加,而另一些如果用户需求足够强烈,可能会被优先考虑添加。

来自 typing 模块的这些类型和功能在 TorchScript 中不可用。

项目

描述

typing.Any

typing.Any 目前处于开发中,但尚未发布

typing.NoReturn

未实现

typing.Sequence

未实现

typing.Callable

未实现

typing.Literal

未实现

typing.ClassVar

未实现

typing.Final

支持模块属性类属性注解,但不支持函数

typing.AnyStr

TorchScript 不支持 bytes ,因此此类型未使用

typing.overload

typing.overload 目前处于开发中,但尚未发布

类型别名

未实现

名义子类型与结构子类型

名义类型正在开发中,但结构类型尚未

新类型

很可能不会被实现

泛型

很可能不会被实现

本文档未明确列出的 typing 模块中的任何其他功能均不受支持。

默认类型 ¶

默认情况下,所有 TorchScript 函数的参数都被假定为 Tensor。要指定 TorchScript 函数的参数是其他类型,可以使用上面列出的 types 模块中的 MyPy 风格类型注解。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

还可以使用来自 typing 模块的 Python 3 类型提示来注解类型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

空列表被认为是 List[Tensor] ,空字典被认为是 Dict[str, Tensor] 。要实例化其他类型的空列表或字典,请使用 Python 3 类型提示。

Python 3 的类型注解示例:

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

可选类型细化

当在 if 语句的条件或通过 assert 检查时,TorchScript 会细化类型为 Optional[T] 的变量的类型。编译器可以推理出与 andornot 结合的多个 None 检查。对于没有明确书写的 if 语句的 else 块,也会进行类型细化。

None 检查必须在 if 语句的条件内;将 None 检查赋值给变量并在 if 语句的条件中使用,不会细化检查中变量的类型。只有局部变量会被细化,例如 self.x 这样的属性则不会,并且必须将其赋值给局部变量才能进行细化。

示例(在参数和局部变量上精炼类型):

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 类

警告

TorchScript 类支持目前处于实验阶段。目前它最适合简单的记录类型(例如带有方法的 NamedTuple )。

如果 Python 类使用 @torch.jit.script 进行注解,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

此子集受限制:

  • 所有函数必须是有效的 TorchScript 函数(包括 __init__() )。

  • 类必须是新式类,因为我们使用 __new__() 和 pybind11 来构造它们。

  • TorchScript 类是静态类型的。成员只能通过在 __init__() 方法中给 self 赋值来声明。

    例如,在 __init__() 方法外部对 self 进行赋值:

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    将导致:

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • 类体中不允许有除方法定义之外的表达式。

  • 不支持继承或其他任何多态策略,除非从 object 继承以指定新式类。

定义一个类之后,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互相使用:

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 枚举

Python 枚举可以在 TorchScript 中使用,无需任何额外的注释或代码:

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举之后,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互相使用。枚举值的类型必须是 intfloatstr 。所有值必须具有相同的类型;枚举值不支持异构类型。

命名元组

collections.namedtuple 产生的类型可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

迭代器

一些函数(例如, zipenumerate )只能操作可迭代类型。TorchScript 中的可迭代类型包括 Tensor 、列表、元组、字典、字符串、 torch.nn.ModuleListtorch.nn.ModuleDict

表达式 ¶

支持以下 Python 表达式。

文字 ¶

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表构造 ¶

空列表假定具有类型 List[Tensor] 。其他列表字面量的类型由成员的类型推导而来。有关详细信息,请参阅默认类型。

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组构造

(3, 4)
(3,)

字典构造

空字典假定具有类型 Dict[str, Tensor] 。其他字典字面量的类型由成员的类型推导而来。有关详细信息,请参阅默认类型。

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量 ¶

请参阅变量解析以了解变量如何解析。

my_variable_name

算术运算符 ¶

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符 ¶

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符

a and b
a or b
not b

下标和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用

内置函数调用

torch.rand(3, dtype=torch.int)

调用其他脚本函数:

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法调用 ¶

调用内置类型(如 tensor)的方法: x.mm(y)

在模块中,方法必须在调用之前编译。TorchScript 编译器在编译其他方法时会递归地编译它看到的方法。默认情况下,编译从 forward 方法开始。任何被 forward 调用的方法都将被编译,以及这些方法调用的方法,依此类推。要从一个不是 forward 的方法开始编译,请使用 @torch.jit.export 装饰器( forward 隐式标记为 @torch.jit.export )。

直接调用子模块(例如 self.resnet(input) )等同于调用其 forward 方法(例如 self.resnet.forward(input) )。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式 ¶

x if x > y else y

类型转换 ¶

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数 ¶

self.my_parameter
self.my_submodule.my_parameter

陈述句 §

TorchScript 支持以下类型的语句:

简单赋值 §

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配赋值 §

a, b = tuple_or_list
a, b, *c = a_tuple

多重赋值

a = b, c = tup

如果语句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布尔值外,浮点数、整数和张量也可以用于条件判断,并且会隐式转换为布尔值。

当循环

a = 0
while a < 4:
    print(a)
    a += 1

带 range 的 for 循环

x = 0
for i in range(10):
    x *= i

遍历元组的 for 循环

这些循环展开,为元组的每个成员生成一个主体。主体的类型检查必须对每个成员都正确。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

对常量 nn.ModuleList 进行循环

要在编译方法中使用 nn.ModuleList ,必须将该属性标记为常量,并将该属性名称添加到 __constants__ 列表中,对于 nn.ModuleList 的循环,将在编译时展开循环体,每个常量模块列表的成员都会展开。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

Break 和 Continue

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

返回

return a, b

变量解析 ¶

TorchScript 支持 Python 变量解析(即作用域)规则的一个子集。局部变量在 Python 中的行为相同,但有一个限制,即变量在函数的所有路径上必须具有相同的类型。如果一个变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它将是一个错误。

同样,如果一个变量仅在函数的一些路径上定义,则不允许使用它。

示例:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

非局部变量在函数定义时编译时解析为 Python 值。然后使用《使用 Python 值》中描述的规则将这些值转换为 TorchScript 值。

Python 值的使用 ¶

为了使编写 TorchScript 更加方便,我们允许脚本代码在周围作用域中引用 Python 值。例如,每当有对 torch 的引用时,当函数声明时,TorchScript 编译器实际上会将它解析为 torch Python 模块。这些 Python 值不是 TorchScript 的第一类组成部分。相反,它们在编译时被转换为 TorchScript 支持的原始类型。这取决于在编译时引用的 Python 值的动态类型。本节描述了在 TorchScript 中访问 Python 值时使用的规则。

函数 ¶

TorchScript 可以调用 Python 函数。当逐步将模型转换为 TorchScript 时,这个功能非常有用。可以将模型逐个函数地移动到 TorchScript 中,同时保留对 Python 函数的调用。这样,您可以逐步检查模型在转换过程中的正确性。

torch.jit.is_scripting()[source][source]

当处于编译状态时返回 True,否则返回 False。这对于与 @unused 装饰器一起使用特别有用,可以在您的模型中留下尚未与 TorchScript 兼容的代码。.. testcode:

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回类型:

布尔型

torch.jit.is_tracing()[source][source]

返回一个布尔值。

返回在跟踪(如果函数在带有 torch.jit.trace 的代码跟踪期间被调用)时为 True ,否则为 False

Python 模块属性查找 §

TorchScript 可以查找模块的属性。内置函数如 torch.add 就是通过这种方式访问的。这允许 TorchScript 调用在其他模块中定义的函数。

Python 定义的常量 §

TorchScript 还提供了一种使用在 Python 中定义的常量的方法。这些常量可以用来将超参数硬编码到函数中,或者定义通用常量。有两种方式指定 Python 值应该被视为常量。

  1. 查询模块属性时,假定其值为常量:

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. 可以通过注解来标记 ScriptModule 的属性为常量: Final[T]

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的常量 Python 类型有

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • torch.nn.ModuleList 可用于 TorchScript 循环中

模块属性

torch.nn.Parameter 包装器和 register_buffer 可用于将张量分配给模块。编译后的模块中,如果可以推断出其他值的数据类型,则这些值将被添加到编译后的模块中。TorchScript 中所有可用的数据类型都可以用作模块属性。张量属性与缓冲区在语义上是相同的。空列表、字典和 None 值的类型无法推断,必须通过 PEP 526 风格的类注解来指定。如果无法推断类型且未显式注解,则不会将其添加为结果的 ScriptModule 属性。

示例:

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源