快捷键

图 ¶

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source][source]

捕获 CUDA 工作并将其存储到 torch.cuda.CUDAGraph 对象中,以便稍后回放。

请参阅 CUDA 图,以获取一般介绍、详细用法和限制。

参数:
  • cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的图对象。

  • pool (可选) – 不透明令牌(由 graph_pool_handle()other_Graph_instance.pool() 调用返回),提示此图的捕获可能共享指定池的内存。请参阅图内存管理。

  • stream (torch.cuda.Stream, 可选) – 如果提供,将设置为上下文中的当前流。如果不提供, graph 将设置其内部侧流作为上下文中的当前流。

  • capture_error_mode (str, 可选) – 指定图捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda 图捕获期间,某些操作,如 cudaMalloc,可能是不安全的。“global”将在其他线程中的操作上引发错误,“thread_local”将仅对当前线程中的操作引发错误,“relaxed”将不对操作引发错误。除非您熟悉 cudaStreamCaptureMode,否则请不要更改此设置。

注意

为了有效共享内存,如果您传递了一个之前捕获中使用的 pool ,并且之前的捕获使用了显式的 stream 参数,那么您应该将相同的 stream 参数传递给这次捕获。

警告

此 API 处于测试阶段,未来版本中可能会有所变化。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源