快捷键

torch.jit.fork

torch.jit.fork(func, *args, **kwargs)[source][source]

创建一个异步任务,执行 func 并引用此执行的结果值。

fork 将立即返回,因此 func 的返回值可能尚未计算。要强制完成任务并访问返回值,请在 Future 上调用 torch.jit.wait。使用返回类型为 T 的 func 调用的 fork 被类型化为 torch.jit.Future[T]。fork 调用可以是任意嵌套的,并且可以使用位置参数和关键字参数调用。异步执行仅在运行在 TorchScript 中时才会发生。如果在纯 Python 中运行,fork 不会并行执行。当在跟踪时调用,fork 也不会并行执行,但是 fork 和 wait 调用将被捕获在导出的 IR 图中。

警告

任务将非确定性地执行。我们建议只为不修改其输入、模块属性或全局状态的纯函数并行创建 fork 任务。

参数:
  • func (可调用或 torch.nn.Module) – 要调用的 Python 函数或 torch.nn.Module。如果在 TorchScript 中执行,它将异步执行,否则不会。fork 的跟踪调用将被捕获在 IR 中。

  • *args – 调用 func 时使用的参数。

  • **kwargs – 调用 func 时使用的参数。

返回值:

对 func 执行引用。值 T 只能通过强制 func 通过 torch.jit.wait 完成来访问。

返回类型:

torch.jit.Future[T]

示例(分叉一个自由函数):

import torch
from torch import Tensor


def foo(a: Tensor, b: int) -> Tensor:
    return a + b


def bar(a):
    fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)


script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

示例(分叉一个模块方法):

import torch
from torch import Tensor


class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b: int):
        return a + b


class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super(self).__init__()
        self.mod = AddMod()

    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)


input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源