torch.distributed.tensor.experimental 的源代码
# 版权(c)Meta Platforms, Inc. 和其关联公司。从 collections.abc 导入 Iterator,从 contextlib 导入 contextmanager,从 torch.distributed.tensor._api 导入 DTensor,从 torch.distributed.tensor.experimental._attention 导入 context_parallel,从 torch.distributed.tensor.experimental._func_map 导入 local_map,从 torch.distributed.tensor.experimental._register_sharding 导入 register_sharding __all__ = ["context_parallel", "implicit_replication", "local_map", "register_sharding"] @contextmanager def implicit_replication() -> Iterator[None]: """ 此上下文管理器允许 :class:`DTensor` 在运算符计算期间隐式地处理程序中的所有非 DTensor(``torch.Tensor``),将其复制为 :class:`DTensor`。 .. warning:: 如果 ``torch.Tensor`` 在实际中未复制,则可能产生不正确的结果,请谨慎使用。 尝试:将 DTensor._op_dispatcher._allow_implicit_replication 设置为 True yield finally: 将 DTensor._op_dispatcher._allow_implicit_replication 设置为 False # 为暴露的私有名称设置命名空间 context_parallel.__module__ = "torch.distributed.tensor.experimental" implicit_replication.__module__ = "torch.distributed.tensor.experimental" local_map.__module__ = "torch.distributed.tensor.experimental" register_sharding.__module__ = "torch.distributed.tensor.experimental"