• 文档 >
  • 模块代码 >
  • torch >
  • torch.jit
快捷键

torch.jit 的源代码

# mypy: 允许未类型化定义
导入 警告
from collections.abc 导入 迭代器
from contextlib 导入 contextmanager
from 打字 导入 任何

导入 torch._C

# 这些导入是为了用户可以从`torch.jit`模块访问它们
from torch._jit_internal 导入 (
    _Await,
    _drop,
    _IgnoreContextManager,
    isinstance,
    超载_,
    重载方法,
    导出,
    最终,
    未来,
    忽略,
    是否正在脚本化,
    未使用,
)
from torch.jit._async 导入 分叉, 等待
from torch.jit._await 导入 _awaitable, _awaitable_nowait, 可等待的等待
from torch.jit._decomposition_utils 导入 _register_decomposition
from torch.jit._freeze 导入 冻结, 优化推理, 运行冻结优化
from torch.jit._fuser 导入 (
    燃烧器,
    最后执行优化的图,
    优化执行,
    设置融合策略,
)
from torch.jit._ir_utils 导入 _InsertPoint
from torch.jit._script 导入 (
    _ScriptProfile,
    解包可选,
    属性,
    编译单元,
    接口,
    递归脚本类,
    递归脚本模块,
    脚本,
    脚本方法,
    脚本函数,
    脚本模块,
    脚本警告,
)
from torch.jit._serialization 导入 (
    从 flatbuffer 生成 jit 模块,
    加载,
    保存,
    将 jit 模块保存到 flatbuffer,
)
from torch.jit._trace 导入 (
    _flatten,
    _get_trace_graph,
    _script_if_tracing,
    唯一状态字典,
    is_tracing,
    ONNX 跟踪模块,
    顶级跟踪模块,
    跟踪,
    跟踪模块,
    跟踪模块,
    跟踪警告,
    追踪检查错误,
)
from torch.utils 导入 设置模块


全部 = [
    属性,
    编译单元,
    "错误",
    未来,
    "脚本函数",
    "脚本模块",
    "标注",
    启用 ONEDNN 融合,
    "导出",
    导出操作名称,
    分叉,
    冻结,
    "界面",
    "忽略",
    "isinstance",
    加载,
    "onednn 融合启用",
    "优化推理",
    保存,
    "脚本",
    "脚本如果跟踪",
    "设置融合策略",
    "严格融合",
    "跟踪",
    trace_module,
    未使用,
    等待,
]

# 为向后兼容
分支 = 分支
等待 = 等待
设置融合策略 = 设置融合策略


定义 导出操作名称(m):
    r""
为脚本模块生成新的字节码。

返回基于当前代码库的脚本模块的 op 列表。

如果你有 LiteScriptModule 并且想要获取当前存在的
列表,请调用_export_operator_list。
    """
    返回 火炬._C._export_opnames(m._c)


torch.jit.Error
错误 = 火炬._C.JIT 异常
设置模块(错误, torch.jit)
# 这并不完美,但在常见情况下是可行的
错误.__name__ = "错误"
错误.__qualname__ = "错误"


# 用于 Python 中的注释
[文档]def annotate(the_type, the_value): """用于在 TorchScript 编译器中指定`the_value`的类型。 这种方法是一个透传函数,用于返回 `the_value`,用于向 TorchScript 编译器提示`the_value`的类型。当在 TorchScript 外部运行时,它是一个空操作(no-op)。 虽然 TorchScript 可以推断大多数 Python 表达式的正确类型,但有些情况下类型推断可能会出错,包括: 尽管 TorchScript 可以推断出大多数 Python 表达式的正确类型,但有些情况下类型推断可能会出错,包括: 类型推断可能会出错,包括: 空容器如 `[]` 和 `{}`,TorchScript 会假设它们是 `Tensor` 的容器 可选类型如 `Optional[T]`,但被赋予了有效的 `T` 类型值,TorchScript 会假设它是 `T` 类型而不是 `Optional[T]` 注意,`annotate()` 方法在 `torch.nn.Module` 子类的 `__init__` 方法中不起作用 因为 `annotate()` 方法不会在 `torch.nn.Module` 子类的 `__init__` 方法中提供帮助 在急切模式下执行。要注释 `torch.nn.Module` 属性的类型, 请使用 :meth:`~torch.jit.Attribute`。 示例: .. testcode:: import torch from typing import Dict @torch.jit.script def fn(): 告诉 TorchScript 这个空字典是一个(str -> int)字典 # 默认字典类型为(str -> Tensor)。 d = torch.jit.annotate(Dict[str, int], {}) # 如果没有上面的 torch.jit.annotate,下面的语句会因为 类型不匹配。 d["名称"] = 20 .. 清理测试:: 删除 fn Args: 类型:传递给 TorchScript 编译器的 Python 类型,作为`the_value`的类型提示 the_value:用于提示类型的值或表达式。 Returns: 返回值作为返回值传递。 """ 返回 the_value
[文档]def 跟踪脚本(fn): """ 编译 `fn` 时,在追踪期间首次调用。 `torch.jit.script` 在首次调用时由于许多编译器内置函数的懒加载,会有一个不可忽视的启动时间。因此,您不应该使用 懒加载的许多编译器内置函数。因此您不应该使用 在库代码中,然而,你可能希望库的部分功能在跟踪时也能工作。 即使它们使用控制流,在跟踪时也应使用。 使用 `@torch.jit.script_if_tracing` 来替代 `torch.jit.script`。 替换为 `torch.jit.script`。 Args: fn: 一个用于编译的函数。 Returns: 如果在追踪期间调用,则返回由 `torch.jit.script` 创建的 :class:`ScriptFunction`。 否则,返回原始函数 `fn`。 """ 返回 _script_if_tracing(fn)
# for torch.jit.isinstance
[文档]def isinstance(obj, target_type): """ 在 TorchScript 中提供容器类型细化。 它可以细化 List、Dict、Tuple 和 Optional 类型的参数化容器。例如 ``List[str]``, `Dict[str, List[torch.Tensor]]`, `Optional[Tuple[int,str,int]]`。它还可以 精炼基本类型,如 bools 和 ints,这些类型在 TorchScript 中可用。 Args: obj: 要精炼类型的对象 尝试将对象精炼到目标类型的目标类型 返回: ``bool``: 如果对象成功精炼到 target_type 类型,则为 True, 否则返回 False,没有新的类型精炼 使用 `torch.jit.isinstance` 进行类型细化示例: .. testcode:: 导入 torch 从 typing 模块导入 Any, Dict, List class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, input: Any): # note the Any type if torch.jit.isinstance(input, List[torch.Tensor]): for t in input: y = t.clamp(0, 0.5) elif torch.jit.isinstance(input, Dict[str, str]): for val in input.values(): print(val) m = torch.jit.script(MyModule()) x = [torch.rand(3,3), torch.rand(4,3)] m(x) y = {"key1":"val1","key2":"val2"} m(y) """ return _isinstance(obj, target_type)
[文档]class strict_fusion: """ 如果在推理中未对所有节点进行融合,或在训练中进行符号微分,则给出错误。 示例: 强制融合添加项。 .. 代码块 :: python @torch.jit.script def foo(x): with torch.jit.strict_fusion(): return x + x + x """ def __init__(self) -> None: 如果不是 torch._jit_internal.is_scripting(): 警告:仅在脚本模式下有效 def __enter__(self): pass def __exit__(self, type: Any, value: Any, tb: Any) -> None: pass
# 全局隐藏打印图形时源范围的上下文管理器。 # 注意,这些函数作为静态成员暴露给 Python。 # 图类,因此需要跳过 mypy 检查。 @contextmanager 定义
_隐藏源范围() -> 迭代器[] 旧版启用源范围 = 火炬._C..全局打印源范围 # 类型:忽略[已定义] 尝试: 火炬._C..设置全局打印源范围(False) # 类型:忽略[已定义] 产生 最后: 火炬._C..设置全局打印源范围(旧启用源范围) # 类型:忽略[已定义]
[文档]def 启用 onednn 融合(启用: bool): 根据参数 `enabled` 启用或禁用 onednn JIT 融合。 torch._C._jit_set_llga_enabled(enabled)
[文档]def onednn_fusion_enabled(): 返回 onednn JIT 融合是否启用。 return torch._C._jit_llga_enabled()
删除
任何 如果 not 火炬._C._jit_init(): 抛出 RuntimeError("JIT 初始化失败")

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源