训练脚本 ¶
如果您的训练脚本与 torch.distributed.launch 兼容,则它将使用以下差异与 torchrun 兼容:
无需手动传递
RANK,WORLD_SIZE,MASTER_ADDR和MASTER_PORT。rdzv_backend和rdzv_endpoint可以提供。对于大多数用户来说,这将被设置为c10d(参见 rendezvous)。默认的rdzv_backend创建了一个非弹性的 rendezvous,其中rdzv_endpoint持有主地址。确保你的脚本中有
load_checkpoint(path)和save_checkpoint(path)逻辑。当任何数量的工作进程失败时,我们将使用相同的程序参数重新启动所有工作进程,因此你将丢失最近检查点之前的进度(参见弹性启动)。use_env标志已被删除。如果你是通过解析--local-rank选项来解析本地排名,你需要从环境变量LOCAL_RANK(例如int(os.environ["LOCAL_RANK"]))中获取本地排名。
下面是一个示例训练脚本,它在每个纪元上进行检查点,因此失败时的最坏情况进度丢失是一个完整纪元的训练。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
关于 torchelastic 兼容的训练脚本的具体示例,请访问我们的示例页面。