• 文档 >
  • 火炬分布式弹性 >
  • 训练脚本
快捷键

训练脚本 ¶

如果您的训练脚本与 torch.distributed.launch 兼容,则它将使用以下差异与 torchrun 兼容:

  1. 无需手动传递 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT

  2. rdzv_backendrdzv_endpoint 可以提供。对于大多数用户来说,这将被设置为 c10d (参见 rendezvous)。默认的 rdzv_backend 创建了一个非弹性的 rendezvous,其中 rdzv_endpoint 持有主地址。

  3. 确保你的脚本中有 load_checkpoint(path)save_checkpoint(path) 逻辑。当任何数量的工作进程失败时,我们将使用相同的程序参数重新启动所有工作进程,因此你将丢失最近检查点之前的进度(参见弹性启动)。

  4. use_env 标志已被删除。如果你是通过解析 --local-rank 选项来解析本地排名,你需要从环境变量 LOCAL_RANK (例如 int(os.environ["LOCAL_RANK"]) )中获取本地排名。

下面是一个示例训练脚本,它在每个纪元上进行检查点,因此失败时的最坏情况进度丢失是一个完整纪元的训练。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

关于 torchelastic 兼容的训练脚本的具体示例,请访问我们的示例页面。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源