1 训练过程中显存的组成
在大模型训练中,显存大致被三类数据占用:
- P (Parameters):模型参数
- G (Gradients):反向传播时的梯度
- OS (Optimizer States):优化器的状态(如 Adam 的动量和方差信息)
典型的 Adam 优化器中,这三者的比例大致为 1 : 1 : 6
之所以 Optimizer States 占比高,是因为 Adam 需要为每个参数维护一阶矩(动量)和二阶矩(方差)的额外信息
例如,在 PyTorch 中优化器的典型构造如下:
optimizer = optim.Adam(model.parameters(), lr=0.001)
反向传播与参数更新的过程:
loss.backward()
→ 生成parameters.grad
optimizer.step()
→ 更新优化器状态(OS)
PyTorch 内部大致逻辑如下:
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
# Exponential moving average of gradient values
m = state['exp_avg'] # 动量参数
# Exponential moving average of squared gradient values
v = state['exp_avg_sq'] # 方差参数
2 混合精度训练下的显存占用
在混合精度(前向 FP16,反向和优化器 FP32)下,如果模型有 x
个参数,显存占用大致如下:
类型 | 占用 |
---|---|
Parameters(FP16 + FP32 拷贝) | 2x |
Gradients(FP32) | 2x |
Optimizer States(Adam,全 FP32) | 12x(4x 参数副本 + 4x momentum 动量 + 4x variance 方差) |
参考:Reducing Activation Recomputation in Large Transformer Models
下图是论文中针对不同优化器的显存占用对比(K 表示倍数, adam的时候K=12):
不同优化器的 K 值不同,而 DeepSpeed 的 ZeRO 优化器就是通过对 P/G/OS 进行分片,来降低显存占用的:
- ZeRO-1:分片 OS
- ZeRO-2:分片 OS + G
- ZeRO-3:分片 OS + G + P
更多关于 Adam 优化器的细节可参考官方文档:
3 DDP(Data Parallelism)数据并行
在 DDP 模式下,每张 GPU 保存完整的模型副本,各自独立执行前向与反向,然后通过 All-Reduce 汇聚梯度。
流程如下:
- 每张卡独立执行前向计算
- 各自计算梯度
- 使用 All-Reduce 对所有梯度求和,并广播给所有 GPU
- 每张卡使用本地优化器独立更新参数
4 PP(Pipeline Parallelism)流水线并行
PP 将模型按层拆分到不同 GPU 上,形成流水线。PyTorch 提供了 Pipe
API:
from torch.distributed.pipeline.sync import Pipe
# 初始化 RPC 框架
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)
# 构建流水线
fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
model = nn.Sequential(fc1, fc2)
# chunks: number of micro-batches (default: 1)
model = Pipe(model, chunks=8)
input = torch.rand(16, 16).cuda(0)
output_rref = model(input)
5 FSDP(Fully Sharded Data Parallel)并行
论文:Fully Sharded Data Parallel
FSDP 是一种“全分片”方案,相比 DDP,它能极大地降低显存峰值占用。
5.1 核心步骤
- 定义 FSDP Unit:确定按 layer / module / stage 的垂直切分单元
- Sharding:对 P / G / OS 进行水平切分
- All Gather:前向传播前收集参数
- Reduce Scatter:反向传播后梯度聚合并分发
例如,一个 6 层的模型,可以划分成 3 个 unit(如图 layer0+3、layer1+2、layer4+5)。共享参数的 layer 需放在同一 unit 中
5.2 FlatParameter 与 Sharding
FSDP 会把每个 unit 的权重和 bias 合并为一个 FlatParameter,然后在不同 GPU 上进行分片存储:
上图描述了sharding过程,
首先把 weight 和bias 都存成FlatParameter,可能会存在一定的padding
FlatParameter 存好之后,每张卡分到一份FlatParameter
- construct units 构造 units
- unit0
- unit1
- unit2
- sharding:
- 把unit 存成 FlatParameter
- split FlatParameter 到多个node/gpu
- torch.distributed.fsdp.FullyShardedDataParallel
- sharding_strategy
- FULL_SHARD: os + g + p
- shard_grad_OP: os + G
- sharding_strategy
5.3 ALL gather
上图描述了NCCL 的all gather 原语,每张卡存了不同的shard,主要是先concat ,再做广播
前向传播时,需要从不同 GPU 收集分片 → 拼接成完整参数 → 计算完后释放。 这个过程与计算可 overlap,即在计算 unit0 的前向时,可以并行 gather unit1 的参数,从而提升效率。
上图是fsdp ,分成了4份, 前向的时候gather起来,再广播, 算完forward backward再释放
每张卡继续保留部分权重
其中,这种前向反向的过程,是有一定的overlap处理的
首先,对unit0 gather, gather后算前向,算unit0前向的同时,可以对 unit1 做 gather,通信和计算可以同时进行,就是overlap
5.4 reduce-scatter
reduce 默认操作是加和,
反向传播后,各卡计算的梯度会通过 Reduce-Scatter 聚合并分发给各自负责的分片。
上图,gather完之后,算完不同的梯度,4卡的梯度加起来,不同的部分再分发到不同的卡上
5.5 DDP 和FSDP的区别
- DDP:每张 GPU 上存放完整模型,显存占用高
- FSDP:每张 GPU 只存部分模型,剩余显存可用于更大的 batch size
QA: FSDP 下是否仍会出现“峰值显存 = 整个模型大小”?
是的,在 计算某一层时,该层的完整参数会被 gather 到所有参与 GPU 上,所以峰值显存会瞬时包含该层的完整参数。但这是逐层 gather,不会像 DDP 那样同时存下整个模型的所有参数。