torch.distributed.elastic.control_plane 源代码
导入
操作系统
来自 collections.abc
导入
生成器
来自 contextlib
导入
上下文管理器,
退出栈
来自 torch.distributed.elastic.multiprocessing.errors
导入
记录
全部 = [
"worker_main",
]
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
定义
工作服务器(
套接字路径:
字符串) ->
生成器[
无,
无,
无
]
来自 torch._C._distributed_c10d
导入 _WorkerServer
服务器 =
工作服务器(
套接字路径)
try:
产生
最后:
服务器.
关闭()
[文档]上下文管理器
@record
def 工作器主函数() -> 生成器[None, None, None]:
"""
这是一个包装您主入口函数的上下文管理器。这结合了
现有的 `errors.record` 逻辑以及一个新的 `_WorkerServer`,该服务器通过
由 `Torch_WORKER_SERVER_SOCKET` 指定的 Unix 套接字公开处理程序。
``Torch_WORKER_SERVER_SOCKET``.
示例
::
@worker_main()
def main():
pass
if __name__ == "__main__":
main()
"""
with ExitStack() as stack:
socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
if socket_path is not None:
stack.enter_context(_worker_server(socket_path))
yield