TorchScript 语言参考手册
本参考手册描述了 TorchScript 语言的语法和核心语义。TorchScript 是 Python 语言的一个静态类型子集。本文件解释了 TorchScript 中支持的 Python 特性以及该语言与常规 Python 的差异。本参考手册未提及的 Python 特性不属于 TorchScript。TorchScript 专注于 Python 中用于表示 PyTorch 神经网络模型的特性。
术语 §
本文档使用了以下术语:
模式 |
笔记 |
|---|---|
|
表示给定的符号被定义为。 |
|
代表语法中作为一部分的实关键词和分隔符。 |
|
表示 A 或 B。 |
|
表示分组 |
|
表示可选 |
|
表示至少重复一次的术语 A 的正则表达式 |
|
表示零次或多次重复的术语 A 的正则表达式 |
类型系统 ¶
TorchScript 是 Python 的静态类型子集。与完整的 Python 语言相比,TorchScript 的最大区别是它只支持用于表达神经网络模型所需的一小部分类型。
TorchScript 类型 ¶
TorchScript 类型系统由以下定义的 TSType 和 TSModuleType 组成。
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
表示大多数可组合且可用于 TorchScript 类型注解的 TorchScript 类型。 TSType 指的是以下任何一种:
元类型,例如,
Any基本类型,例如,
int,float,和str结构化类型,例如,
Optional[int]或List[MyClass]常量类型(Python 类),例如:
MyClass(用户定义的),torch.tensor(内置的)
TSModuleType 代表 torch.nn.Module 及其子类。由于其类型模式部分从对象实例和部分从类定义中推断,因此它与 TSType 处理方式不同。 TSModuleType 的实例可能不遵循相同的静态类型模式。 TSModuleType 不能用作 TorchScript 类型注解,也不能与 TSType 组合以考虑类型安全性。
元类型
元类型非常抽象,它们更像类型约束而不是具体类型。目前 TorchScript 定义了一个元类型 Any ,它代表任何 TorchScript 类型。
Any 类型 ¶
The Any 类型表示任何 TorchScript 类型。 Any 指定没有类型约束,因此没有在 Any 上进行类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如 int ,TorchScript tuple ,或任何未脚本化的任意 Python 类)。
TSMetaType ::= "Any"
其中:
Any是来自 typing 模块的 Python 类名。因此,要使用Any类型,您必须从typing导入它(例如from typing import Any)。由于
Any可以表示任何 TorchScript 类型,因此可以在Any上操作此类型值的操作符集合受到限制。
支持的 Any 类型操作符
将数据赋值给
Any类型。绑定到
Any类型的参数或返回值。x is,x is not为x类型。isinstance(x, Type),其中x为Any类型。Any类型的数据可打印。如果数据是相同类型
T的值列表且T支持比较运算符,则List[Any]类型的数据可能是可排序的。
与 Python 相比
在 TorchScript 类型系统中, Any 是约束最少的类型。从这个意义上讲,它与 Python 中的 Object 类非常相似。然而, Any 只支持 Object 支持的子集运算符和方法。
设计笔记
当我们为 PyTorch 模块编写脚本时,可能会遇到不参与脚本执行的数据。尽管如此,这些数据必须由类型模式进行描述。在脚本上下文中描述未使用的数据的静态类型不仅麻烦,还可能导致不必要的脚本失败。 Any 的引入是为了描述数据类型,在编译过程中不需要精确的静态类型。
示例 1
本示例说明了如何使用 Any 允许元组的第二个参数为任何类型。这是因为 x[1] 不参与任何需要知道其确切类型的计算。
import torch
from typing import Tuple
from typing import Any
@torch.jit.export
def inc_first_element(x: Tuple[int, Any]):
return (x[0]+1, x[1])
m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
上面的示例生成了以下输出:
(2, 2.0)
(2, (100, 200))
元组的第二个元素为 Any 类型,因此可以绑定到多种类型。例如, (1, 2.0) 将浮点类型绑定到 Any ,如 Tuple[int, Any] 所示,而在第二次调用中, (1, (100, 200)) 将元组绑定到 Any 。
示例 2
本示例说明了我们如何使用 isinstance 来动态检查标注为 Any 类型的数据类型:
import torch
from typing import Any
def f(a:Any):
print(a)
return (isinstance(a, torch.Tensor))
ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
上面的示例生成了以下输出:
1
1
[ CPUFloatType{2} ]
True
基本类型 -
原始的 TorchScript 类型是表示单一类型值的类型,并与其预定义的类型名称相对应。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构化类型
结构化类型是指没有用户定义名称的结构化定义的类型(与命名类型不同),例如 Future[int] 。结构化类型可以与任何 TSType 组合。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
位置:
Tuple,List,Optional,Union,Future,Dict代表在模块typing中定义的 Python 类型类名。要使用这些类型名称,您必须从typing导入它们(例如,from typing import Tuple)。namedtuple代表 Python 类collections.namedtuple或typing.NamedTuple。Future和RRef代表 Python 类torch.futures和torch.distributed.rpc。Await代表 Python 类torch._awaits._Await
与 Python 相比
除了与 TorchScript 类型可组合外,这些 TorchScript 结构类型通常支持其 Python 对应操作符和方法的一个公共子集。
示例 1
本示例使用 typing.NamedTuple 语法定义元组:
import torch
from typing import NamedTuple
from typing import Tuple
class MyTuple(NamedTuple):
first: int
second: int
def inc(x: MyTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
上述示例生成了以下输出:
TorchScript: (2, 3)
示例 2
本示例使用 collections.namedtuple 语法来定义元组:
import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])
def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
上述示例生成了以下输出:
(2, 3)
示例 3
本示例说明了注释结构类型时常见的错误,即没有从 typing 模块导入复合类型类:
import torch
# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
return (x[0]+1, x[1]+1)
m = torch.jit.script(inc)
print(m((1,2)))
运行上述代码会导致以下脚本错误:
File "test-tuple.py", line 5, in <module>
def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined
解决方法是向代码开头添加行 from typing import Tuple 。
名义类型 ¶
名义 TorchScript 类型是 Python 类。这些类型被称为名义类型,因为它们使用自定义名称声明,并且通过类名进行比较。名义类进一步分为以下几类:
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中, TSCustomClass 和 TSEnum 必须可编译为 TorchScript 中间表示(IR)。这是由类型检查器强制执行的。
内置类 ¶
内置名义类型是 Python 类,其语义被构建到 TorchScript 系统中(例如,张量类型)。TorchScript 定义了这些内置名义类型的语义,通常只支持其 Python 类定义中方法或属性的一个子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
"torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特殊说明
虽然 torch.nn.ModuleList 和 torch.nn.ModuleDict 在 Python 中被定义为列表和字典,但在 TorchScript 中它们的行为更像元组:
在 TorchScript 中,
torch.nn.ModuleList或torch.nn.ModuleDict的实例是不可变的。代码遍历
torch.nn.ModuleList或torch.nn.ModuleDict将被完全展开,以便torch.nn.ModuleList或torch.nn.ModuleDict的元素或键可以属于torch.nn.Module的不同子类。
示例
以下示例突出了几个内置 Torchscript 类的使用( torch.* ):
import torch
@torch.jit.script
class A:
def __init__(self):
self.x = torch.rand(3)
def f(self, y: torch.device):
return self.x.to(device=y)
def g():
a = A()
return a.f(torch.device("cpu"))
script_g = torch.jit.script(g)
print(script_g.graph)
自定义类 ¶
与内置类不同,自定义类的语义由用户定义,整个类定义必须可编译为 TorchScript IR,并受 TorchScript 类型检查规则的约束。
TSClassDef ::= [ "@torch.jit.script" ]
"class" ClassName [ "(object)" ] ":"
MethodDefinition |
[ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
MethodDefinition
Where:
类必须是新式类。Python 3 只支持新式类。在 Python 2.x 中,新式类通过从 object 继承来指定。
实例数据属性是静态类型的,并且实例属性必须在
__init__()方法内部的赋值语句中声明。不支持方法重载(即不能有多个具有相同方法名的函数)。
必须可编译为 TorchScript IR 并遵循 TorchScript 的类型检查规则(即所有方法都必须是有效的 TorchScript 函数,类属性定义也必须是有效的 TorchScript 语句)。
可以使用
torch.jit.ignore和torch.jit.unused来忽略不完全支持 torchscript 或应该被编译器忽略的方法或函数。
与 Python 相比
与其 Python 对应物相比,TorchScript 自定义类相当有限。Torchscript 自定义类:
不支持类属性。
除了接口类型或对象的子类化外,不支持子类化。
不支持方法重载。
必须在
__init__()中初始化所有实例属性;这是因为 TorchScript 通过在__init__()中推断属性类型来构建类的静态模式。必须只包含满足 TorchScript 类型检查规则且可编译为 TorchScript IR 的方法。
示例 1
如果 Python 类使用 @torch.jit.script 进行注解,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:
@torch.jit.script
class MyClass:
def __init__(self, x: int):
self.x = x
def inc(self, val: int):
self.x += val
示例 2
TorchScript 自定义类类型必须通过在 __init__() 中赋值“声明”所有实例属性。如果一个实例属性在 __init__() 中没有定义但在类的其他方法中被访问,则该类无法编译为 TorchScript 类,如下例所示:
import torch
@torch.jit.script
class foo:
def __init__(self):
self.y = 1
# ERROR: self.x is not defined in __init__
def assign_x(self):
self.x = torch.rand(2, 3)
该类将无法编译并报出以下错误:
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在本例中,一个 TorchScript 自定义类定义了一个不允许的类变量名:
import torch
@torch.jit.script
class MyClass(object):
name = "MyClass"
def __init__(self, x: int):
self.x = x
def fn(a: MyClass):
return a.name
导致以下编译时错误:
RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
File "test-class2.py", line 10
def fn(a: MyClass):
return a.name
~~~~~~ <--- HERE
枚举类型 ¶
与自定义类类似,枚举类型的语义由用户定义,整个类定义必须可编译为 TorchScript IR,并遵守 TorchScript 类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
位置:
值必须是类型为
int、float或str的 TorchScript 文字面量,并且必须是相同的 TorchScript 类型。TSEnumType是 TorchScript 枚举类型的名称。类似于 Python 枚举,TorchScript 允许受限的Enum子类化,即,如果枚举没有定义任何成员,则允许子类化枚举。
与 Python 相比
TorchScript 仅支持
enum.Enum。它不支持其他变体,如enum.IntEnum、enum.Flag、enum.IntFlag和enum.auto。TorchScript 枚举成员的值必须是同一类型,并且只能是
int、float或str类型,而 Python 枚举成员可以是任何类型。包含方法的枚举在 TorchScript 中会被忽略。
示例 1
以下示例定义了一个类 Color 为 Enum 类型:
import torch
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例展示了受限枚举子类化的情况,其中 BaseColor 没有定义任何成员,因此可以被 Color 子类化:
import torch
from enum import Enum
class BaseColor(Enum):
def foo(self):
pass
class Color(BaseColor):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类
TSModuleType 是一种特殊类类型,它从在 TorchScript 外部创建的对象实例推断而来。 TSModuleType 由对象实例的 Python 类命名。Python 类的 __init__() 方法不被视为 TorchScript 方法,因此它不需要遵守 TorchScript 的类型检查规则。
模块实例类的类型模式直接从实例对象(在 TorchScript 作用域外创建)构建,而不是从类似于自定义类的 __init__() 推断。可能存在两个相同实例类类型的对象遵循不同的类型模式。
在这种意义上, TSModuleType 并非真正的静态类型。因此,出于类型安全考虑, TSModuleType 不能用于 TorchScript 类型注解,也不能与 TSType 组合。
模块实例类
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。在脚本化 PyTorch 模块时,模块对象始终在 TorchScript 之外创建(即作为参数传递给 forward )。Python 模块类被视为模块实例类,因此 Python 模块类的 __init__() 方法不受 TorchScript 的类型检查规则约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
Where:
forward()和其他使用@torch.jit.export装饰的方法必须可编译为 TorchScript IR,并受 TorchScript 的类型检查规则约束。
与自定义类不同,只有模块类型的前向方法和其他使用 @torch.jit.export 装饰的方法需要可编译。最值得注意的是, __init__() 不被视为 TorchScript 方法。因此,模块类型构造函数不能在 TorchScript 的作用域内调用。相反,TorchScript 模块对象总是在外部构建并传递给 torch.jit.script(ModuleObj) 。
示例 1
本例展示了模块类型的一些功能:
TestModule实例是在 TorchScript 作用域之外创建的(即在调用torch.jit.script之前)。__init__()不是一个 TorchScript 方法,因此不需要进行注解,可以包含任意 Python 代码。此外,实例类的__init__()方法不能在 TorchScript 代码中调用。因为TestModule实例是在 Python 中实例化的,在这个例子中,TestModule(2.0)和TestModule(2)创建了具有不同数据属性类型的两个实例。self.x对于TestModule(2.0)是float类型,而self.y对于TestModule(2.0)是int类型。TorchScript 会自动编译通过
@torch.jit.export或forward()方法注解的方法调用的其他方法(例如mul())。火炬脚本程序的入口点可以是模块类型的
forward(),或者被标注为torch.jit.script的函数,或者被标注为torch.jit.export的方法。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, inc: int):
return self.x + inc
m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")
m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")
上面的示例生成了以下输出:
First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例展示了模块类型的不正确使用。具体来说,这个示例在 TorchScript 的作用域内调用了 TestModule 的构造函数。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, x: int):
return self.x + x
class MyModel:
def __init__(self, v: int):
self.val = v
@torch.jit.export
def doSomething(self, val: int) -> int:
# error: should not invoke the constructor of module type
myModel = TestModule(self.val)
return myModel(val)
# m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
# RuntimeError: Could not get name of python class object
类型注解 ¶
由于 TorchScript 是静态类型的,程序员需要在 TorchScript 代码的战略点上注解类型,以确保每个局部变量或实例数据属性都有静态类型,每个函数和方法都有静态类型的签名。
何时进行类型注解 ¶
通常情况下,只有在静态类型无法自动推断的地方才需要类型注解(例如,参数或有时方法或函数的返回类型)。局部变量和数据属性的类型通常可以从它们的赋值语句中自动推断。有时推断的类型可能过于严格,例如,通过赋值推断为 x ,而实际上使用的是 NoneType ,而 x = None 是作为 x 使用的。在这种情况下,可能需要类型注解来覆盖自动推断,例如, Optional 。请注意,即使可以自动推断类型,也可以安全地对局部变量或数据属性进行类型注解。注解的类型必须与 TorchScript 的类型检查一致。
当参数、局部变量或数据属性未进行类型注解且其类型无法自动推断时,TorchScript 会将其默认类型假设为 TensorType 、 List[TensorType] 或 Dict[str, TensorType] 。
注释函数签名
由于参数可能无法从函数体(包括函数和方法)中自动推断,因此它们需要类型注解。否则,它们会假设默认类型 TensorType 。
TorchScript 支持两种方法签名和函数签名类型注解的风格:
Python3 风格的注解直接在签名上标注类型。因此,允许单独的参数不进行注解(其类型将为默认类型
TensorType),或者允许返回类型不进行注解(其类型将自动推断)。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
FuncOrMethodBody
ParamAnnot ::= Identifier [ ":" TSType ] ","
ReturnAnnot ::= "->" TSType
注意,在使用 Python3 风格时,类型 self 将自动推断,不应进行注解。
Mypy 风格将类型注解作为注释直接标注在函数/方法声明下方。在 Mypy 风格中,由于参数名称不出现在注解中,所有参数都必须进行注解。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot ::= TSType ","
ReturnAnnot ::= "->" TSType
示例 1
在本例中:
a未进行标注,默认为TensorType的类型。b被标注为int类型。返回类型未进行标注,自动推断为返回值的类型
TensorType(基于返回值的类型)。
import torch
def f(a, b: int):
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例使用 Mypy 风格的注解。请注意,即使某些参数或返回值假设默认类型,也必须对它们进行注解。
import torch
def f(a, b):
# type: (torch.Tensor, int) → torch.Tensor
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
注释变量和数据属性
通常,数据属性的类型(包括类和实例数据属性)和局部变量可以从赋值语句中自动推断。有时,如果变量或属性与不同类型的值相关联(例如,作为 None 或 TensorType ),则可能需要显式地将其注释为更宽泛的类型,如 Optional[int] 或 Any 。
本地变量 ¶
本地变量可以根据 Python3 类型模块的注解规则进行注解,即
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
通常,本地变量的类型可以自动推断。然而,在某些情况下,您可能需要为可能关联不同具体类型的本地变量注解多类型。典型的多类型包括 Optional[T] 和 Any 。
示例
import torch
def f(a, setVal: bool):
value: Optional[torch.Tensor] = None
if setVal:
value = a
return value
ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性 ¶
对于 ModuleType 类,实例数据属性可以根据 Python3 类型注解模块的注解规则进行注解。实例数据属性可以通过 Final (可选)注解为最终属性。
"class" ClassIdentifier "(torch.nn.Module):"
InstanceAttrIdentifier ":" ["Final("] TSType [")"]
...
位置:
InstanceAttrIdentifier是一个实例属性的名称。Final表示该属性不能在__init__之外重新赋值或在子类中重写。
示例
import torch
class MyModule(torch.nn.Module):
offset_: int
def __init__(self, offset):
self.offset_ = offset
...
类型注解 API ¶
torch.jit.annotate(T, expr)¶
此 API 将类型 T 标注到表达式 expr 。这通常用于表达式的默认类型不是程序员所期望的类型时。例如,空列表(字典)的默认类型为 List[TensorType] ( Dict[TensorType, TensorType] ),但有时可能用于初始化其他类型的列表。另一个常见用例是标注 tensor.tolist() 的返回类型。请注意,它不能用于标注__init__中模块属性的类型; torch.jit.Attribute 应用于此。
示例
在此示例中, [] 通过 torch.jit.annotate 声明为整数列表(而不是假设 [] 为 List[TensorType] 的默认类型):
import torch
from typing import List
def g(l: List[int], val: int):
l.append(val)
return l
def f(val: int):
l = g(torch.jit.annotate(List[int], []), val)
return l
m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
更多信息请参阅 torch.jit.annotate() 。
类型注解附录 ¶
火炬脚本类型系统定义
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSMetaType ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"|
"torch.dtype" | "torch.nn.ModuleList" |
"torch.nn.ModuleDict" | ...
TSTensor ::= "torch.tensor" and subclasses
不支持的类型构造
火炬脚本不支持 Python3 类型模块的所有特性和类型。本文档未明确指定的类型模块功能均不支持。下表总结了在火炬脚本中不受支持或受限制支持的 typing 构造。
项目 |
描述 |
|
开发中 |
|
不支持 |
|
不支持 |
|
不支持 |
|
不支持 |
|
支持模块属性、类属性和注解,但不支持函数。 |
|
不支持 |
|
开发中 |
类型别名 |
不支持 |
命名类型 |
开发中 |
结构化类型 |
不支持 |
新类型 |
不支持 |
泛型 |
不支持 |
表达式 ¶
以下部分描述了在 TorchScript 中受支持的表达式语法。它模仿了 Python 语言参考中的表达式章节。
算术转换 §
在 TorchScript 中执行了许多隐式类型转换:
当一个具有
float或int数据类型的Tensor可以隐式转换为FloatType或IntType实例,前提是它的大小为 0,没有将require_grad设置为True,并且不会需要缩窄。StringType的实例可以隐式转换为DeviceType。上述两个要点中的隐式转换规则可以应用于
TupleType的实例,以生成包含适当类型实例的ListType。
可以使用接受原始数据类型作为参数的 float 、 int 、 bool 和 str 内置函数显式调用转换,如果用户定义的类型实现了 __bool__ 、 __str__ 等,则也可以接受。
原子
原子是表达式的最基本元素。
atom ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display
标识符 ¶
规定 TorchScript 中合法标识符的规则与它们的 Python 对应物相同。
文字常量 ¶
literal ::= stringliteral | integer | floatnumber
文字常量的评估会产生具有特定值的适当类型的对象(对于浮点数,根据需要应用近似值)。文字常量是不可变的,对相同文字常量的多次评估可能获得相同的对象或具有相同值的独立对象。字符串文字、整数和浮点数字的表示方式与它们的 Python 对应物相同。
括号形式 ¶
parenth_form ::= '(' [expression_list] ')'
括号表达式列表返回表达式列表的结果。如果列表中至少有一个逗号,则返回 Tuple ;否则,返回表达式列表内的单个表达式。一对空括号返回一个空的 Tuple 对象( Tuple[] )。
列表和字典显示 ¶
list_comprehension ::= expression comp_for
comp_for ::= 'for' target_list 'in' or_expr
list_display ::= '[' [expression_list | list_comprehension] ']'
dict_display ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list ::= key_datum (',' key_datum)*
key_datum ::= expression ':' expression
dict_comprehension ::= key_datum comp_for
列表和字典可以通过明确列出容器内容或通过提供一组循环指令来计算它们的方式(即推导式)来构建。推导式在语义上等同于使用 for 循环并将元素追加到正在进行的列表中。推导式隐式地创建自己的作用域,以确保目标列表的项不会泄漏到外部作用域。如果容器项明确列出,则表达式列表中的表达式按从左到右的顺序评估。如果在具有 key_datum_list 的 dict_display 中重复键,则结果字典使用使用重复键的列表中右侧数据点的值。
基础
primary ::= atom | attributeref | subscription | slicing | call
属性引用
attributeref ::= primary '.' identifier
primary 必须评估为一个支持具有名为 identifier 的属性引用的类型对象。
订阅
subscription ::= primary '[' expression_list ']'
primary 必须评估为一个支持订阅的对象。
如果主操作数是
List、Tuple或str,表达式列表必须评估为整数或切片。如果主操作数是
Dict,表达式列表必须评估为与Dict的键类型相同的对象。如果主操作数是
ModuleList,表达式列表必须是一个integer文本字面量。如果主操作数是
ModuleDict,则表达式必须是stringliteral。
切片操作
切片操作从 str 、 Tuple 、 List 或 Tensor 中选择一系列项。切片可以作为赋值或 del 语句中的表达式或目标使用。
slicing ::= primary '[' slice_list ']'
slice_list ::= slice_item (',' slice_item)* [',']
slice_item ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]
包含多个切片项的切片只能与评估为 Tensor 类型对象的初值一起使用。
呼叫
call ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args ::= [arg (',' arg)*]
kwargs ::= [kwarg (',' kwarg)*]
kwarg ::= arg '=' expression
arg ::= identifier
primary 必须解糖或评估为可调用对象。所有参数表达式在尝试调用之前都会被评估。
幂运算符
power ::= primary ['**' u_expr]
幂运算符与内置的 pow 函数(不支持)具有相同的语义;它计算其左操作数乘以其右操作数的幂。它比左边的单目运算符绑定得更紧密,但比右边的单目运算符绑定得松;即 -2 ** -3 == -(2 ** (-3)) 。左操作数和右操作数可以是 int , float 或 Tensor 。在标量与张量/张量与标量幂运算的情况下,标量会被广播,而张量与张量幂运算则是逐元素进行的,不进行广播。
一元和算术位运算 ¶
u_expr ::= power | '-' power | '~' power
一元 - 运算符返回其参数的否定。一元 ~ 运算符返回其参数的位反转。 - 可以与 int 和 float 的 int 、 float 和 Tensor 使用。 ~ 只能与 int 的 int 和 Tensor 使用。
二元算术运算 ¶
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可以作用于 Tensor 、 int 和 float 。对于张量-张量运算,两个参数必须具有相同的形状。对于标量-张量或张量-标量运算,标量通常广播到张量的尺寸。除法运算只能接受标量作为其右侧参数,不支持广播。 @ 运算符用于矩阵乘法,并且只作用于 Tensor 参数。乘法运算符( * )可以用列表和整数一起使用,以得到重复一定次数的原始列表的结果。
调整操作 ¶
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些运算符接受两个 int 参数,两个 Tensor 参数,或者一个 Tensor 参数和一个 int 或 float 参数。在所有情况下,右移 n 定义为向下取整除以 pow(2, n) ,而左移 n 定义为乘以 pow(2, n) 。当两个参数都是 Tensors 时,它们必须具有相同的形状。当一个参数是标量,另一个是 Tensor 时,标量将逻辑上广播以匹配 Tensor 的大小。
二进制位运算 ¶
and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr ::= xor_expr | or_expr '|' xor_expr
& 运算符计算其参数的按位与, ^ 按位异或, | 按位或。两个操作数必须是 int 或 Tensor ,或者左操作数必须是 Tensor ,而右操作数必须是 int 。当两个操作数都是 Tensor 时,它们必须具有相同的形状。当右操作数是 int ,而左操作数是 Tensor 时,右操作数将逻辑上广播以匹配 Tensor 的形状。
比较运算符
comparison ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较运算符返回一个布尔值( True 或 False ),如果其中一个操作数是 Tensor ,则返回一个布尔 Tensor 。只要比较运算符不返回包含多个元素的布尔 Tensors ,就可以任意地链式调用比较运算符。 a op1 b op2 c ... 等价于 a op1 b and b op2 c and ... 。
值比较
运算符 < 、 > 、 == 、 >= 、 <= 和 != 比较两个对象的价值。两个对象通常需要是同一类型,除非在对象之间存在隐式类型转换。如果用户定义的类型上定义了丰富的比较方法(例如, __lt__ ),则可以比较用户定义的类型。内置类型比较与 Python 的工作方式类似:
数字是按数学方式比较的。
字符串是按字典顺序比较的。
lists、tuples和dicts只能与其他相同类型的lists、tuples和dicts比较,并且使用对应元素的比较运算符进行比较。
成员资格测试操作 §
操作符 in 和 not in 用于测试成员资格。 x in s 如果 x 是 s 的成员,则评估为 True ,否则为 False 。 x not in s 等价于 not x in s 。此操作符支持 lists 、 dicts 和 tuples ,并且如果用户定义的类型实现了 __contains__ 方法,也可以使用。
身份比较
对于除 int 、 double 、 bool 和 torch.device 之外的所有类型,操作符 is 和 is not 测试对象的身份; x is y 是 True 当且仅当 x 和 y 是相同的对象。对于其他所有类型, is 等价于使用 == 比较它们。 x is not y 产生 x is y 的逆。
布尔运算
or_test ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户定义的对象可以通过实现一个 __bool__ 方法来自定义其转换为 bool 。操作符 not 如果其操作数为假,则产生 True ,否则产生 False 。表达式 x 和 y 首先评估 x ;如果是 False ,则返回其值( False );否则,评估 y 并返回其值( False 或 True )。表达式 x 或 y 首先评估 x ;如果是 True ,则返回其值( True );否则,评估 y 并返回其值( False 或 True )。
条件表达式
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression ::= conditional_expression
表达式 x if c else y 首先评估条件 c ,而不是 x。如果 c 是 True ,则评估 x 并返回其值;否则,评估 y 并返回其值。与 if 语句一样, x 和 y 必须评估为相同类型的值。
表达式列表
expression_list ::= expression (',' expression)* [',']
starred_item ::= '*' primary
星号项只能出现在赋值语句的左侧,例如 a, *b, c = ... 。
简单语句
以下部分描述了在 TorchScript 中支持的简单语句的语法。它模仿了 Python 语言参考中的简单语句章节。
表达式语句 §
expression_stmt ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item ::= assignment_expression | "*" or_expr
赋值语句 §
assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list ::= target ("," target)* [","]
target ::= identifier
| "(" [target_list] ")"
| "[" [target_list] "]"
| attributeref
| subscription
| slicing
| "*" target
增量赋值语句 §
augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget ::= identifier | attributeref | subscription
augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
"**="| ">>=" | "<<=" | "&=" | "^=" | "|="
注解的赋值语句 ¶
annotated_assignment_stmt ::= augtarget ":" expression
["=" (starred_expression)]
raise 语句 ¶
raise_stmt ::= "raise" [expression ["from" expression]]
在 TorchScript 中,抛出语句不支持 try\except\finally 。
assert 语句 ¶
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的断言语句不支持 try\except\finally 。
return 语句 ¶
return_stmt ::= "return" [expression_list]
TorchScript 中的返回语句不支持 try\except\finally 。
del 语句 ¶
del_stmt ::= "del" target_list
第 0#声明句
pass_stmt ::= "pass"
第 0#声明句
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
第 0#声明句
break_stmt ::= "break"
第 0#声明句:
continue_stmt ::= "continue"
复合语句 §
以下部分描述了 TorchScript 支持的复合语句的语法。本节还突出了 Torchscript 与常规 Python 语句的区别。它模仿了 Python 语言参考中的复合语句章节。
if 语句 §
Torchscript 支持基本的 if/else 和三元 if/else 。
基本语句①
if_stmt ::= "if" assignment_expression ":" suite
("elif" assignment_expression ":" suite)
["else" ":" suite]
语句①可以重复任意次数,但必须在语句①之前。
三元语句①
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
一个具有 1 维的 tensor 被提升为 bool :
import torch
@torch.jit.script
def fn(x: torch.Tensor):
if x: # The tensor gets promoted to bool
return True
return False
print(fn(torch.rand(1)))
上面的例子产生了以下输出:
True
示例 2
具有多维的 tensor 不会被提升为 bool :
import torch
# Multi dimensional Tensors error out.
@torch.jit.script
def fn():
if torch.rand(2):
print("Tensor is available")
if torch.rand(4,5,6):
print("Tensor is available")
print(fn())
运行上述代码将得到以下结果 RuntimeError 。
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
if torch.rand(2):
~~~~~~~~~~~~ <--- HERE
print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果一个条件变量被标注为 final ,则根据条件变量的评估结果,将评估真或假的分支。
示例 3
在本例中,仅评估 True 分支,因为 a 被标注为 final 并设置为 True :
import torch
a : torch.jit.final[Bool] = True
if a:
return torch.empty(2,3)
else:
return []
while 声明 ¶
while_stmt ::= "while" assignment_expression ":" suite
在 Torchscript 中不支持 while…else 语句。这会导致出现一个 RuntimeError 。
for-in 声明 ¶
for_stmt ::= "for" target_list "in" expression_list ":" suite
["else" ":" suite]
在 Torchscript 中不支持 for...else 语句。这会导致出现一个 RuntimeError 。
示例 1
对于元组中的循环:这些循环会展开,为元组的每个成员生成一个体。该体必须对每个成员进行类型检查。
import torch
from typing import Tuple
@torch.jit.script
def fn():
tup = (3, torch.ones(4))
for x in tup:
print(x)
fn()
上面的示例将产生以下输出:
3
1
1
1
1
[ CPUFloatType{4} ]
示例 2
列表中的循环:对于 nn.ModuleList 的循环会在编译时展开循环体,每个模块列表的成员都会执行一次。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
model = torch.jit.script(MyModule())
with 语句
with 语句用于将代码块执行封装在由上下文管理器定义的方法中。
with_stmt ::= "with" with_item ("," with_item) ":" suite
with_item ::= expression ["as" target]
如果在
with语句中包含了目标,上下文管理器的__enter__()的返回值将被分配给它。与 Python 不同,如果异常导致代码块退出,其类型、值和跟踪信息不会被作为参数传递给__exit__()。提供了三个None参数。try、except和finally语句不支持在with块内部。在
with块内抛出的异常无法被抑制。
tuple 语句 ¶
tuple_stmt ::= tuple([iterables])
TorchScript 中的可迭代类型包括
Tensors、lists、tuples、dictionaries、strings、torch.nn.ModuleList和torch.nn.ModuleDict。您无法通过这个内置函数将列表转换为元组。
将所有输出解包到元组中,由以下内容涵盖:
abc = func() # Function that returns a tuple
a,b = func()
getattr 语句 ¶
getattr_stmt ::= getattr(object, name[, default])
属性名称必须是字面字符串。
支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
hasattr 声明 ¶
hasattr_stmt ::= hasattr(object, name)
属性名称必须是字面字符串。
支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
zip 声明
zip_stmt ::= zip(iterable1, iterable2)
参数必须是可迭代对象。
支持相同外部容器类型但长度不同的两个可迭代对象。
示例 1
两个可迭代对象必须是同一容器类型:
a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works
示例 2
这个例子失败,因为可迭代对象是不同容器类型:
a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error
运行上述代码将产生以下 RuntimeError 。
RuntimeError: Can not iterate over a module list or
tuple with a value that does not have a statically determinable length.
示例 3
支持相同容器类型但数据类型不同的两个可迭代对象:
a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works
TorchScript 中的可迭代类型包括 Tensors 、 lists 、 tuples 、 dictionaries 、 strings 、 torch.nn.ModuleList 和 torch.nn.ModuleDict 。
enumerate 语句 ¶
enumerate_stmt ::= enumerate([iterable])
参数必须是可迭代的。
TorchScript 中的可迭代类型包括
Tensors、lists、tuples、dictionaries、strings、torch.nn.ModuleList和torch.nn.ModuleDict。
Python 值 ¶
解析规则 ¶
当给定一个 Python 值时,TorchScript 会尝试以下五种方式来解析它:
- 可编译的 Python 实现:
当一个 Python 值由一个可以被 TorchScript 编译的 Python 实现支持时,TorchScript 会编译并使用其底层的 Python 实现。
示例:
torch.jit.Attribute
- Op Python 包装器:
当一个 Python 值是原生 PyTorch 操作的包装器时,TorchScript 会发出相应的操作符。
示例:
torch.jit._logging.add_stat_value
- Python 对象身份匹配:
对于 TorchScript 支持的有限集的
torch.*API 调用(以 Python 值的形式),TorchScript 会尝试将 Python 值与集合中的每个项目进行匹配。匹配成功时,TorchScript 会生成一个相应的
SugaredValue实例,该实例包含这些值的降级逻辑。示例:
torch.jit.isinstance()
- 名称匹配:
对于 Python 内置函数和常量,TorchScript 通过名称识别它们,并创建相应的
SugaredValue实例以实现其功能。示例:
all()
- 值快照:
对于来自未识别模块的 Python 值,TorchScript 会尝试对该值进行快照并将其转换为图中的常量,该图是正在编译的函数或方法。
示例:
math.pi
Python 内置函数支持 ¶
内置函数 |
支持级别 |
笔记 |
|---|---|---|
|
部分支持 |
仅支持 |
|
完整 |
|
|
完整 |
|
|
无 |
|
|
部分支持 |
仅支持 |
|
部分支持 |
仅支持 |
|
无 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分翻译 |
仅支持 ASCII 字符集。 |
|
完整翻译 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
全 |
|
|
无 |
|
|
全部 |
|
|
全部 |
|
|
无 |
|
|
None |
|
|
None |
|
|
部分翻译 |
不遵守 |
|
部分翻译 |
不支持手动索引指定。|不支持格式类型修饰符。 |
|
无 |
|
|
部分翻译 |
属性名称必须是字符串字面量。 |
|
无 |
|
|
部分内容 |
属性名称必须是字符串字面量。 |
|
全部 |
|
|
部分支持 |
仅支持 |
|
全部 |
仅支持 |
|
无 |
|
|
部分翻译 |
不支持 |
|
完整翻译 |
提供与容器类型等检查时的更好支持。 |
|
None |
|
|
None |
|
|
全部 |
|
|
全面 |
|
|
部分翻译 |
仅支持 ASCII 字符集。 |
|
完整翻译 |
|
|
部分翻译 |
|
|
None |
|
|
完整 |
|
|
None |
|
|
无 |
|
|
部分翻译 |
不支持该参数。 |
|
无 |
|
|
无 |
|
|
全部 |
|
|
部分翻译 |
不支持该参数。 |
|
全部 |
|
|
部分支持 |
不支持 |
|
全部 |
|
|
部分内容 |
只能在 |
|
无 |
|
|
无 |
|
|
全部 |
|
|
无 |
Python 内置值支持 ¶
内置值 |
支持级别 |
笔记 |
|---|---|---|
|
完全 |
|
|
完全 |
|
|
完整 |
|
|
无 |
|
|
完整 |
torch.* API ¶
远程过程调用
TorchScript 支持 RPC API 的子集,该 API 支持在指定的远程工作者上运行函数,而不是本地运行。
具体来说,以下 API 完全支持:
torch.distributed.rpc.rpc_sync()rpc_sync()对远程工作者运行函数的阻塞式 RPC 调用。RPC 消息在 Python 代码执行并行发送和接收。更多关于其使用方法和示例的详细信息可以在
rpc_sync()中找到。
torch.distributed.rpc.rpc_async()rpc_async()向远程工作进程执行非阻塞 RPC 调用以运行函数。RPC 消息在 Python 代码执行并行发送和接收。更多关于其使用方法和示例的详细信息可以在
rpc_async()中找到。
torch.distributed.rpc.remote()remote.()在工作进程中执行远程调用并返回一个远程引用RRef。更多关于其用法和示例的详细信息可以在
remote()中找到。
异步执行 ¶
TorchScript 允许您创建异步计算任务,以更好地利用计算资源。这是通过支持一系列仅在 TorchScript 中可用的 API 来实现的:
torch.jit.fork()创建一个执行 func 并引用此执行结果的值的异步任务。Fork 将立即返回。
相当于
torch.jit._fork(),仅为了向后兼容而保留。更多关于其用法和示例的详细信息可以在
fork()中找到。
torch.jit.wait()强制完成一个
torch.jit.Future[T]异步任务,返回任务的结果。相当于
torch.jit._wait(),仅为了向后兼容而保留。更多关于其用法和示例的详细信息可以在
wait()中找到。
类型注解 ¶
TorchScript 是静态类型的。它提供并支持一组工具来帮助注解变量和属性:
torch.jit.annotate()为 TorchScript 提供类型提示,在 Python 3 风格的类型提示不适用的情况下。
一个常见的例子是为类似
[]的表达式标注类型。[]默认被视为List[torch.Tensor]。当需要不同的类型时,可以使用此代码来提示 TorchScript:torch.jit.annotate(List[int], [])。更多详细信息可以在
annotate()中找到。
torch.jit.Attribute常见用例包括为
torch.nn.Module属性提供类型提示。因为它们的__init__方法不会被 TorchScript 解析,所以在模块的__init__方法中应使用torch.jit.Attribute代替torch.jit.annotate。更多详细信息可以在
Attribute()中找到。
torch.jit.FinalPython 的别名。
torch.jit.Final仅保留为向后兼容。
元编程 ¶
TorchScript 提供了一套元编程的实用工具:
torch.jit.is_scripting()返回一个布尔值,指示当前程序是否由
torch.jit.script编译。当用于
assert或if语句时,torch.jit.is_scripting()评估为False的作用域或分支不会被编译。其值可以在编译时静态评估,因此通常用于
if语句以阻止 TorchScript 编译其中一个分支。更多详细信息及示例可在
is_scripting()查找
torch.jit.is_tracing()返回一个布尔值,指示当前程序是否被
torch.jit.trace/torch.jit.trace_module跟踪。更多详细信息请参阅
is_tracing()
@torch.jit.ignore此装饰器指示编译器忽略函数或方法,并将其保留为 Python 函数。
这允许您在模型中留下尚未与 TorchScript 兼容的代码。
如果从 TorchScript 调用由
@torch.jit.ignore装饰的函数,则忽略的函数将调用 Python 解释器来处理调用。无法导出具有忽略函数的模型。
更多信息和示例可以在
ignore()中找到。
@torch.jit.unused此装饰器指示编译器忽略函数或方法,并用抛出异常来替换。
这允许您在模型中留下尚未与 TorchScript 兼容的代码,同时仍然导出模型。
如果从 TorchScript 调用由
@torch.jit.unused装饰的函数,将引发运行时错误。更多信息和示例可以在
unused()中找到。
类型细化 ¶
torch.jit.isinstance()返回一个布尔值,表示变量是否为指定的类型。
更多关于其用法和示例的详细信息可以在
isinstance()中找到。