torch.cuda.make_graphed_callables
- torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[source][source]
接受可调用对象(函数或
nn.Module
),并返回图形化版本。每个已图形化的可调用对象的正向传递运行其源可调用对象的前向 CUDA 工作,作为一个 CUDA 图在单个 autograd 节点内部。
可视化图中的可调用函数的前向传递还会向自动微分图中添加一个反向节点。在反向传播过程中,该节点以 CUDA 图的形式执行可调用函数的反向操作。
因此,每个可视化的可调用函数都应该是其源可调用函数在自动微分训练循环中的直接替换。
请参阅部分网络捕获以获取详细用法和限制。
如果您传递了多个可调用函数的元组,它们的捕获将使用相同的内存池。请参阅图内存管理,了解何时适用此操作。
- 参数:
可调用对象(torch.nn.Module 或 Python 函数,或这些对象的元组)- 可调用对象或可调用对象列表,用于构建图。有关何时使用可调用对象元组进行传递的图内存管理,请参阅相关文档。如果传递了可调用对象元组,则元组中元素的顺序必须与它们在实时工作负载中的运行顺序相同。
sample_args(张量元组,或张量元组的元组)- 为每个可调用对象采样参数。如果传递了一个单独的可调用对象,则
sample_args
必须是一个包含参数张量的单个元组。如果传递了可调用对象元组,则sample_args
必须是一个参数张量元组的元组。num_warmup_iters(整数)- 预热迭代次数。目前,
DataDistributedParallel
需要 11 次迭代进行预热。默认值:3
。allow_unused_input(布尔值)- 如果为 False,指定在计算输出时未使用(因此其梯度始终为零)的输入是一个错误。默认值为 False。
pool(可选)- 由
graph_pool_handle()
或other_Graph_instance.pool()
返回的令牌,提示此图可能与指定的池共享内存。请参阅图内存管理。
注意
每个 Tensor 在
sample_args
中的状态必须与训练循环中对应真实输入期望的状态相匹配。警告
此 API 处于测试阶段,未来版本中可能会有所变化。
警告
sample_args
中的每个可调用对象必须只包含 Tensor。不允许其他类型。警告
返回的可调用对象不支持高阶微分(例如,双重反向)。
警告
在传递给
make_graphed_callables()
的任何Module
中,只有参数可以是可训练的。缓冲区必须具有requires_grad=False
。警告
在通过
torch.nn.Module
和make_graphed_callables()
后,您不得添加或删除该模块的任何参数或缓冲区。警告
在将
torch.nn.Module
传递给make_graphed_callables()
时,这些模块不得注册钩子。然而,在通过make_graphed_callables()
传递模块后注册钩子是允许的。警告
当运行图调用时,必须以与调用中出现的顺序和格式相同的顺序传递其参数。
警告
自动混合精度仅在
make_graphed_callables()
中支持,并且缓存已禁用。torch.cuda.amp.autocast()上下文管理器必须具有 cache_enabled=False。