• 教程 >
  • 使用全局分片数据并行(FSDP)的高级模型训练
快捷键

使用完全分片数据并行(FSDP)的高级模型训练 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

你将学到什么
  • PyTorch 的完全分片数据并行模块:跨分片模块参数的包装器

数据并行工作者。

前提条件
  • PyTorch 1.12 或更高版本

  • 了解 FSDP API。

本教程介绍了 PyTorch 1.12 版本中 FSDP(完全分片数据并行)的更多高级功能。要熟悉 FSDP,请参阅 FSDP 入门教程。

在本教程中,我们将以文本摘要为例,使用 FSDP 微调 HuggingFace (HF) T5 模型,作为一个工作示例。

示例使用 Wikihow,为了简单起见,我们将展示在单个节点、P4dn 实例上使用 8 个 A100 GPU 进行训练。我们现在有几篇关于大规模 FSDP 在多节点集群上训练的博客文章((链接 1),(链接 2))和一篇论文。

FSDP 是一个面向生产的软件包,注重易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 的内存占用。这使得与 DDP 相比,可以以更低的总体内存训练更大的模型,并利用计算和通信的重叠来高效地训练模型。这种降低的内存压力可以用来训练更大的模型或增加批大小,从而可能有助于整体训练吞吐量。您可以在 PyTorch FSDP 这里了解更多信息。

本教程中的 FSDP 功能

  • 变压器自动包装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 向后预取

  • 通过流式传输到 CPU 进行模型检查点保存

FSDP 工作原理概述 ¶

从高层次来看,FDSP 的工作原理如下:

在构造函数中

  • 分片模型参数,每个 rank 只保留自己的分片

在前向传播过程中

  • 执行 all_gather 以收集所有 rank 的所有 shard,以恢复此 FSDP 单元的完整参数并执行前向计算

  • 释放内存,丢弃它刚刚收集的非拥有参数碎片

在反向传播过程中

  • 在此 FSDP 单元中运行 all_gather 以收集所有 rank 的所有碎片以恢复完整参数并运行反向计算

  • 丢弃非拥有参数以释放内存。

  • 运行 reduce_scatter 以同步梯度

HF T5 微调教程

HF T5 预训练模型有四种不同的大小,从小型 6000 万参数到 XXL 的 110 亿参数不等。在本教程中,我们将使用 WikiHow 数据集演示 T5 3B 模型的微调,并使用 FSDP 进行文本摘要。本教程的重点是突出 FSDP 中对训练超过 30 亿参数的大型模型有帮助的不同功能。同时,我们还涵盖了针对基于 Transformer 的模型的具体功能。本教程的代码可在 Pytorch 示例中找到。

设置

1.1 安装最新的 PyTorch

pip3 install torch torchvision torchaudio

1.2 数据集设置

请创建一个数据文件夹,从 wikihowAll.csv 和 wikihowSep.cs 下载 WikiHow 数据集,并将它们放置在数据文件夹中。我们将使用 summarization_dataset 中的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本“T5_training.py”中。

备注

本教程的完整源代码可在 PyTorch 示例中找到。

1.3 导入必要的包:

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。在这里,我们使用两个辅助函数来初始化分布式训练的进程,并在训练完成后进行清理。在本教程中,我们将使用 torch elastic,通过 torchrun 自动设置 worker RANK 和 WORLD_SIZE。

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型:

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

我们还添加了几个用于日期和格式化内存指标的帮助函数。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义训练函数:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义验证函数:

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个分布式训练函数,该函数使用 FSDP 包装模型:

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

解析参数并设置主函数:

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

使用 torchrun 运行训练:

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 包装策略¶

如前一个教程中所述,auto_wrap_policy 是 FSDP 功能之一,它使得自动分片给定模型并将模型、优化器和梯度分片放入不同的 FSDP 单元变得容易。

对于像 Transformer 编码器-解码器这样的某些架构,模型的一些部分,如嵌入表,被编码器和解码器共享。在这种情况下,我们需要将嵌入表放置在外部 FSDP 单元中,以便可以从编码器和解码器中访问它。此外,通过为 transformer 注册层类,可以使分片计划更加通信高效。在 PyTorch 1.12 中,FSDP 增加了这项支持,现在我们有了 transformer 的包装策略。

可以按照以下方式创建,其中 T5Block 代表 T5 Transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装后的模型,您可以轻松地打印模型,并直观地检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许使用任意降低精度的类型(例如 fp16 或 bfloat16)。目前 BFloat16 仅在 Ampere GPU 上可用,因此在使用之前您需要确认其原生支持。例如,在 V100 上,BFloat16 仍然可以运行,但由于它不是原生运行,可能会导致显著的性能下降。

要检查 BFloat16 是否原生支持,您可以使用以下方法:

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的优点之一是提供对参数、梯度和缓冲区不同精度级别的细粒度控制,如下所示:

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

注意,如果未指定某些类型(参数、归约、缓冲区),则不会进行任何类型转换。

这种灵活性允许用户进行细粒度控制,例如仅将梯度通信设置为以降低精度进行,所有参数/缓冲区计算均在全精度下完成。这在节点间通信是主要瓶颈且参数/缓冲区必须以全精度进行以避免精度问题时可能很有用。这可以通过以下策略实现:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 版本中,我们只是将相关的混合精度策略添加到了 FSDP 包装器中:

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,我们观察到使用 BFloat16 进行训练可以带来高达 4 倍的速度提升,在某些实验中内存减少约 30%,这可以用于增加批处理大小。

在设备上初始化 FSDP 模型 ¶

在 1.12 版本中,FSDP 支持一个 device_id 参数,用于在指定的设备上初始化输入 CPU 模块。这对于整个模型无法适应单个 GPU,但可以适应主机 CPU 内存的情况非常有用。当指定 device_id 时,FSDP 将根据每个 FSDP 单元将模型移动到指定的设备,避免初始化时的 GPU 内存溢出问题,并且比基于 CPU 的初始化快得多:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略 ¶

FSDP 的分片策略默认设置为完全分片模型参数,梯度优化器状态将在所有 rank 之间分片。(也称为 Zero3 分片)。如果您对 Zero2 分片策略感兴趣,其中只有优化器状态和梯度被分片,FSDP 通过使用“ShardingStrategy.SHARD_GRAD_OP”而不是“ShardingStrategy.FULL_SHARD”作为 FSDP 初始化的参数来支持此功能,如下所示:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 中的通信开销,在这种情况下,它会在前向和反向传播后保留全部参数。

这在反向传播期间节省了一个 all_gather 操作,从而减少了通信量,但代价是更高的内存占用。请注意,在反向传播结束时释放全部模型参数,并在下一个前向传播中进行 all_gather 操作。

反向预取

反向预取设置控制了请求下一个 FSDP 单元参数的时间。通过将其设置为 BACKWARD_PRE,下一个 FSDP 单元的参数可以开始请求并提前到达,在当前单元的计算开始之前。这可以重叠 all_gather 通信和梯度计算,从而提高训练速度,但会略微增加内存消耗。它可以在 2.4 版本的 FSDP 包装器中如下使用:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式,BACKWARD_PRE 和 BACKWARD_POST。BACKWARD_POST 表示在当前 FSDP 单元处理完成之前,不会请求下一个 FSDP 单元的参数,从而最小化内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于更大的模型,速度提升更为显著。

模型检查点保存,通过流式传输到 Rank0 CPU

要使用 FULL_STATE_DICT 保存方式保存模型检查点,该方式与本地模型保存方式相同,PyTorch 1.12 提供了一些实用工具来支持保存更大的模型。

首先,可以指定 FullStateDictConfig,允许仅在 rank 0 上填充 state_dict 并将其卸载到 CPU。

当使用此配置时,FSDP 将逐个将模型参数 allgather 到 CPU 上,仅在 rank 0 上执行。当最终保存 state_dict 时,它将仅在 rank 0 上填充,并包含 CPU 张量。这避免了对于大于单个 GPU 内存的模型潜在的 OOM 问题,并允许用户检查点大小大致等于用户机器上可用 CPU RAM 的模型。

此功能可以按以下方式运行:

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

摘要

在本教程中,我们介绍了 Pytorch 1.12 中 FSDP 的许多新功能,并使用 HF T5 作为运行示例。使用适当的包装策略,特别是对于 transformer 模型,以及混合精度和向后预取,可以加快您的训练运行。此外,将模型初始化在设备上,以及通过流式传输到 CPU 保存检查点等功能,有助于避免处理大型模型时的 OOM 错误。

我们正在积极开发 FSDP 的新功能,以便在下一个版本中发布。如果您有任何反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch GitHub 仓库中创建问题与我们联系。


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源