torch.futures¶
此包提供了一个 Future
类型,用于封装异步执行和一组简化 Future
对象操作的实用函数。目前,该 Future
类型主要用于分布式 RPC 框架。
- class torch.futures.Future(*, devices=None)¶
包裹一个异步执行可调用的
torch._C.Future
,例如rpc_async()
。它还公开了一组 API 来添加回调函数和设置结果。警告
GPU 支持是测试版功能,可能发生变化。
- add_done_callback(callback)[source][source]¶
将给定的回调函数添加到这个
Future
中,当Future
完成时将运行。可以添加多个回调到同一个Future
中,但无法保证它们的执行顺序。回调函数必须接受一个参数,即对这个Future
的引用。回调函数可以使用value()
方法来获取值。注意,如果这个Future
已经完成,则将直接运行给定的回调。我们建议您使用
then()
方法,因为它提供了一种在回调完成后同步的方式。如果您的回调不返回任何内容,add_done_callback
可能更便宜。但then()
和add_done_callback
在底层都使用相同的回调注册 API。关于 GPU 张量,此方法的行为与
then()
相同。- 参数:
callback (
Future
) – 一个接受一个参数的Callable
,该参数是此Future
的引用。
注意
注意,如果回调函数抛出异常,无论是通过原始 future 完成并调用
fut.wait()
,还是通过回调中的其他代码,都必须仔细处理错误处理。例如,如果此回调后来完成其他 future,这些 future 不会被标记为带有错误的完成,并且用户负责独立处理这些 future 的完成/等待。- 示例::
>>> def callback(fut): ... print("This will run after the future has finished.") ... print(fut.wait()) >>> fut = torch.futures.Future() >>> fut.add_done_callback(callback) >>> fut.set_result(5) This will run after the future has finished. 5
- done()[source][source]¶
如果这个
Future
已经完成,则返回True
。一个Future
完成是指它有结果或异常。如果值包含位于 GPU 上的张量,
Future.done()
将返回True
,即使填充这些张量的异步内核尚未在设备上完成运行,因为在那个阶段结果已经可用,前提是执行适当的同步操作(见wait()
)。- 返回类型:
- set_exception(result)[source][source]¶
为此
Future
设置一个异常,这将标记此Future
为出错完成并触发所有附加回调。请注意,当在此Future
上调用 wait()/value() 时,此处设置的异常将立即抛出。- 参数:
结果(BaseException)- 此
Future
的异常。
- 示例::
>>> fut = torch.futures.Future() >>> fut.set_exception(ValueError("foo")) >>> fut.wait() Traceback (most recent call last): ... ValueError: foo
- set_result(result)[source][source]¶
为此
Future
设置结果,这将标记此Future
为完成并触发所有附加回调。请注意,一个Future
不能被标记完成两次。如果结果包含位于 GPU 上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,也可以调用此方法,前提是当调用此方法时,那些内核入队的流被设置为当前流。简单来说,只要在之间不更改流,就可以在启动这些内核后立即调用此方法,无需任何额外的同步。此方法将在所有相关当前流上记录事件,并使用它们来确保所有此
Future
的消费者都能得到适当的调度。- 参数:
结果(对象)- 此
Future
的结果对象。
- 示例::
>>> import threading >>> import time >>> def slow_set_future(fut, value): ... time.sleep(0.5) ... fut.set_result(value) >>> fut = torch.futures.Future() >>> t = threading.Thread( ... target=slow_set_future, ... args=(fut, torch.ones(2) * 3) ... ) >>> t.start() >>> print(fut.wait()) tensor([3., 3.]) >>> t.join()
- then(callback)[source][source]
将给定的回调函数添加到这个
Future
,当Future
完成时将执行该回调。可以将多个回调添加到同一个Future
,但无法保证它们的执行顺序(要强制执行特定顺序,请考虑链式操作:fut.then(cb1).then(cb2)
)。回调函数必须接受一个参数,即对Future
的引用。回调函数可以使用value()
方法获取值。注意,如果Future
已经完成,则立即在行内运行给定的回调。如果
Future
的值包含位于 GPU 上的张量,则回调可能在填充这些张量的异步内核尚未在设备上完成执行时被调用。然而,回调将以一些专用流(从全局池中获取)作为当前流被调用,这些流将与那些内核同步。因此,回调对这些张量进行的任何操作都将安排在内核完成后在设备上执行。换句话说,只要回调不切换流,就可以安全地操作结果而无需任何额外的同步。这与wait()
的非阻塞行为类似。同样,如果回调返回的值包含位于 GPU 上的张量,即使生成这些张量的内核仍在设备上运行,它也可以这样做,只要回调在其执行过程中没有更改流。如果想要更改流,必须小心地将它们与原始流重新同步,即回调被调用时的当前流。
- 参数:
回调(
Callable
)- 一个接受此Callable
作为唯一参数的Future
。- 返回:
一个新的
Future
对象,它持有callback
的返回值,并且当给定的callback
完成时将被标记为完成。- 返回类型:
Future[S]
注意
注意,如果回调函数抛出异常,无论是通过原始 future 完成并调用
fut.wait()
,还是通过回调中的其他代码,then
返回的 future 将被适当地标记为遇到错误。然而,如果此回调后来完成其他 future,这些 future 不会被标记为带有错误完成,并且用户负责独立处理这些 future 的完成/等待。- 示例::
>>> def callback(fut): ... print(f"RPC return value is {fut.wait()}.") >>> fut = torch.futures.Future() >>> # The inserted callback will print the return value when >>> # receiving the response from "worker1" >>> cb_fut = fut.then(callback) >>> chain_cb_fut = cb_fut.then( ... lambda x : print(f"Chained cb done. {x.wait()}") ... ) >>> fut.set_result(5) RPC return value is 5. Chained cb done. None
- value()[来源][来源] ¶
获取已完成的 future 的值。
此方法应在调用
wait()
完成之后调用,或在内部传递给then()
的回调函数中调用。在其他情况下,Future
可能尚未持有值,调用value()
可能会失败。如果值包含位于 GPU 上的张量,则此方法不会执行任何额外的同步。这应该在之前单独通过调用
wait()
来完成(除非在回调中,因为then()
已经处理了)。- 返回:
此
Future
所持有的值。如果创建值的函数(回调或 RPC)抛出了错误,则此value()
方法也会抛出错误。- 返回类型:
T
- torch.futures.collect_all(futures)[source][source]
将提供的对象收集到单个组合对象中,当所有子未来都完成时,该组合对象完成。
- 参数:
futures (列表) - 一个对象列表。
- 返回:
将一个
Future
对象返回为传入的 Futures 列表。- 返回类型:
- 示例::
>>> fut0 = torch.futures.Future() >>> fut1 = torch.futures.Future() >>> fut = torch.futures.collect_all([fut0, fut1]) >>> fut0.set_result(0) >>> fut1.set_result(1) >>> fut_list = fut.wait() >>> print(f"fut0 result = {fut_list[0].wait()}") fut0 result = 0 >>> print(f"fut1 result = {fut_list[1].wait()}") fut1 result = 1