FSDP 并行策略解析

 

1 训练过程中显存的组成

在大模型训练中,显存大致被三类数据占用:

  • P (Parameters):模型参数
  • G (Gradients):反向传播时的梯度
  • OS (Optimizer States):优化器的状态(如 Adam 的动量和方差信息)

典型的 Adam 优化器中,这三者的比例大致为 1 : 1 : 6

之所以 Optimizer States 占比高,是因为 Adam 需要为每个参数维护一阶矩(动量)和二阶矩(方差)的额外信息

例如,在 PyTorch 中优化器的典型构造如下:

optimizer = optim.Adam(model.parameters(), lr=0.001)

反向传播与参数更新的过程:

  • loss.backward() → 生成 parameters.grad
  • optimizer.step() → 更新优化器状态(OS)

PyTorch 内部大致逻辑如下:

for group in optimizer.param_groups:
    for p in group['params']:
        state = optimizer.state[p]

        # Exponential moving average of gradient values
        m = state['exp_avg']      # 动量参数

        # Exponential moving average of squared gradient values
        v = state['exp_avg_sq']   # 方差参数

2 混合精度训练下的显存占用

在混合精度(前向 FP16,反向和优化器 FP32)下,如果模型有 x 个参数,显存占用大致如下:

类型 占用
Parameters(FP16 + FP32 拷贝) 2x
Gradients(FP32) 2x
Optimizer States(Adam,全 FP32) 12x(4x 参数副本 + 4x momentum 动量 + 4x variance 方差)

参考:Reducing Activation Recomputation in Large Transformer Models

下图是论文中针对不同优化器的显存占用对比(K 表示倍数, adam的时候K=12):

Baseline

不同优化器的 K 值不同,而 DeepSpeed 的 ZeRO 优化器就是通过对 P/G/OS 进行分片,来降低显存占用的:

  • ZeRO-1:分片 OS
  • ZeRO-2:分片 OS + G
  • ZeRO-3:分片 OS + G + P

Adam 参数说明

更多关于 Adam 优化器的细节可参考官方文档:

torch.optim.Adam

3 DDP(Data Parallelism)数据并行

在 DDP 模式下,每张 GPU 保存完整的模型副本,各自独立执行前向与反向,然后通过 All-Reduce 汇聚梯度。

DDP AllReduce

流程如下:

  1. 每张卡独立执行前向计算
  2. 各自计算梯度
  3. 使用 All-Reduce 对所有梯度求和,并广播给所有 GPU
  4. 每张卡使用本地优化器独立更新参数

4 PP(Pipeline Parallelism)流水线并行

PP 将模型按层拆分到不同 GPU 上,形成流水线。PyTorch 提供了 Pipe API:


from torch.distributed.pipeline.sync import Pipe

# 初始化 RPC 框架
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)

# 构建流水线
fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
model = nn.Sequential(fc1, fc2)
# chunks: number of micro-batches (default: 1)
model = Pipe(model, chunks=8)

input = torch.rand(16, 16).cuda(0)
output_rref = model(input)

5 FSDP(Fully Sharded Data Parallel)并行

论文:Fully Sharded Data Parallel

FSDP 是一种“全分片”方案,相比 DDP,它能极大地降低显存峰值占用。

5.1 核心步骤

  1. 定义 FSDP Unit:确定按 layer / module / stage 的垂直切分单元
  2. Sharding:对 P / G / OS 进行水平切分
  3. All Gather:前向传播前收集参数
  4. Reduce Scatter:反向传播后梯度聚合并分发

fsdp_overall

例如,一个 6 层的模型,可以划分成 3 个 unit(如图 layer0+3、layer1+2、layer4+5)。共享参数的 layer 需放在同一 unit 中

5.2 FlatParameter 与 Sharding

FSDP 会把每个 unit 的权重和 bias 合并为一个 FlatParameter,然后在不同 GPU 上进行分片存储:

unit_sharding

上图描述了sharding过程,

首先把 weight 和bias 都存成FlatParameter,可能会存在一定的padding

FlatParameter 存好之后,每张卡分到一份FlatParameter

  • construct units 构造 units
    • unit0
    • unit1
    • unit2
  • sharding:
    • 把unit 存成 FlatParameter
    • split FlatParameter 到多个node/gpu
    • torch.distributed.fsdp.FullyShardedDataParallel
      • sharding_strategy
        • FULL_SHARD: os + g + p
        • shard_grad_OP: os + G

5.3 ALL gather

allgather

上图描述了NCCL 的all gather 原语,每张卡存了不同的shard,主要是先concat ,再做广播

前向传播时,需要从不同 GPU 收集分片 → 拼接成完整参数 → 计算完后释放。 这个过程与计算可 overlap,即在计算 unit0 的前向时,可以并行 gather unit1 的参数,从而提升效率。

FSDP allgather

上图是fsdp ,分成了4份, 前向的时候gather起来,再广播, 算完forward backward再释放

每张卡继续保留部分权重

overlap_comm_comp

其中,这种前向反向的过程,是有一定的overlap处理的

首先,对unit0 gather, gather后算前向,算unit0前向的同时,可以对 unit1 做 gather,通信和计算可以同时进行,就是overlap

5.4 reduce-scatter

reduce scatter

reduce 默认操作是加和,

反向传播后,各卡计算的梯度会通过 Reduce-Scatter 聚合并分发给各自负责的分片。

fsdp_sharding

fsdp_red_scatter

上图,gather完之后,算完不同的梯度,4卡的梯度加起来,不同的部分再分发到不同的卡上

5.5 DDP 和FSDP的区别

  • DDP:每张 GPU 上存放完整模型,显存占用高
  • FSDP:每张 GPU 只存部分模型,剩余显存可用于更大的 batch size

DDP vs FSDP

QA: FSDP 下是否仍会出现“峰值显存 = 整个模型大小”?

是的,在 计算某一层时,该层的完整参数会被 gather 到所有参与 GPU 上,所以峰值显存会瞬时包含该层的完整参数。但这是逐层 gather,不会像 DDP 那样同时存下整个模型的所有参数。

参考: