• 文档 >
  • torch.utils.module_tracker
快捷键

torch.utils.module_tracker ¬

此实用程序可用于跟踪 torch.nn.Module 层次结构中的当前位置。它可以在其他跟踪工具中使用,以便能够轻松地将测量的量与用户友好的名称关联起来。这在 FlopCounterMode 中特别有用。

class torch.utils.module_tracker.ModuleTracker[source][source] ¬

ModuleTracker 是一个上下文管理器,用于在执行过程中跟踪 nn.Module 层次结构,以便其他系统可以查询当前正在执行哪个 Module(或其反向操作正在执行)。

您可以通过访问此上下文管理器的 parents 属性来获取当前正在执行的所有 Module 的集合,通过它们的 fqn(完全限定名,也用作 state_dict 中的键)。您可以通过访问 is_bw 属性来了解您是否正在执行反向操作。

注意, parents 从不为空,并且始终包含“Global”键。 is_bw 标志将在执行另一个 Module 之前保持 True 状态。如果您需要它更准确,请提交一个问题请求。从 fqn 到模块实例的映射是可能的,但尚未实现,如果您需要它,请提交一个问题请求。

演示用法

mod = torch.nn.Linear(2, 2)

with ModuleTracker() as tracker:
    # Access anything during the forward pass
    def my_linear(m1, m2, bias):
        print(f"Current modules: {tracker.parents}")
        return torch.mm(m1, m2.t()) + bias
    torch.nn.functional.linear = my_linear

    mod(torch.rand(2, 2))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源