torch.jit._async 的源代码
# mypy: 允许未类型化定义
异步 API。
此模块包含 TorchScript 中的并行 API,特别是:
* torch.jit.fork
* torch.jit.wait
这不是直接导入的意图;请使用 `torch.jit` 中公开的功能。
请勿直接导入,请使用 `torch.jit` 中提供的功能。
""
导入
火炬
from torch._jit_internal 导入
未来
from torch.jit._builtins 导入 _register_builtin
from torch.utils 导入
设置模块
设置模块(
未来,
torch.jit)
[文档]def
分叉(
函数, *
参数, **kwargs):
r""
创建一个异步任务执行 `func` 并返回执行结果的引用。
`fork` 将立即返回,因此 `func` 的返回值可能尚未计算完成。为了强制完成,请...
任务完成后访问返回值,在 Future 上调用`torch.jit.wait`。调用`fork`
使用返回 `T` 的 `func` 被类型化为 `torch.jit.Future[T]`。`fork` 调用可以是任意数量的。
嵌套的,并且可以用位置参数和关键字参数调用。
异步执行仅在以 TorchScript 运行时才会发生。如果以纯 Python 运行,
`fork` 不会并行执行。当调用时,`fork` 也不会并行执行。
而在跟踪时,然而 `fork` 和 `wait` 调用将被捕获在导出的 IR 图中。
.. 警告::
`fork` 任务将非确定性地执行。我们建议仅为主函数(纯函数)创建并行 `fork` 任务。
仅对不修改其输入的纯函数执行并行 `fork` 任务。
模块属性或全局状态。
参数:
func (可调用或 torch.nn.Module):一个 Python 函数或`torch.nn.Module`
将被调用的。如果在 TorchScript 中执行,它将异步执行,否则不会。fork 的跟踪调用将被捕获在 IR 中。
否则它将不会。Traced invocations of fork will be captured in the IR.
``*args``,``**kwargs``:调用 `func` 时的参数。
返回:
`torch.jit.Future[T]`:对 `func` 执行的引用。值 `T`
只能通过通过 `torch.jit.wait` 强制完成 `func` 来访问。
示例(分叉一个自由函数):
.. 代码块 :: python
导入 torch
从 torch 导入 Tensor
定义 foo(a: Tensor, b: int) -> Tensor
返回 a + b
定义 bar(a):
fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# 仅脚本版本异步执行
assert script_bar(input) == bar(input)
# 跟踪不是异步运行,但分叉在 IR 中被捕获
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)
示例(分叉模块方法):
.. 代码块 :: python
导入 torch
from torch import Tensor
class AddMod(torch.nn.Module):
def forward(self, a: Tensor, b: int):
return a + b
class Mod(torch.nn.Module):
def __init__(self) -> None:
super(self).__init__()
self.mod = AddMod()
def forward(self, input):
fut = torch.jit.fork(self.mod, a, b=2)
return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)
"文档"
返回
火炬._C.
分叉(
函数, *
参数, **kwargs)
[文档]def wait(future):
r"""
强制完成一个 `torch.jit.Future[T]` 异步任务,返回任务的结果。
请参阅 :func:`~fork` 以获取文档和示例。
参数:
future (torch.jit.Future[T]): 通过 `torch.jit.fork` 创建的异步任务引用。
返回值:
`T`:完成任务的返回值
"""
return torch._C.wait(future)
注册内置函数(
等待,
aten::等待)