torch.distributed.tensor.debug 的源代码
# mypy: 允许未类型注解的导入 from torch.distributed.tensor.debug._comm_mode import CommDebugMode from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding __all__ = ["CommDebugMode", "visualize_sharding"] def _get_sharding_prop_cache_info(): """ 获取分片传播缓存的缓存信息,仅用于调试目的。 这将返回一个命名元组,显示分片传播器的缓存中的命中、未命中、最大大小和当前大小。 """ from torch.distributed.tensor._api import DTensor return ( DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] ) # 为公开的私有名称设置命名空间 CommDebugMode.__module__ = "torch.distributed.tensor.debug" visualize_sharding.__module__ = "torch.distributed.tensor.debug"