ZeRO 究竟算数据并行还是模型并行?

 

Zero(ZeRO, Zero Redundancy Optimizer)总被问:它究竟是数据并行(DP)的技巧,还是模型并行(MP)的变体? 本文把 ZeRO 放回数据并行的语境,解释它如何削减显存冗余、为何仍维持 DP 语义,以及和真正模型并行的根本区别。

1 数据并行的痛点:冗余状态

传统 DP 在每张卡上复制 完整模型

  • 参数、梯度、优化器状态都各存一份;
  • Backward 结束后 All-Reduce 梯度,保持副本一致。

随着模型变大,显存中堆着 N 份(N=GPU 数)相同状态,浪费巨大。ZeRO 的目标就是保持 DP 训练语义不变,却让这些状态不再 N 份复制。

2 ZeRO 的核心:分区状态而非拆计算

ZeRO 把模型状态切成片,平均分到不同 GPU。分区力度随 stage 提升:

Stage 被分区的状态 额外通信 显存收益 能否突破单卡参数量
Stage 1 仅优化器状态 基本无 降低 optimizer 占用 否,参数仍完整驻留
Stage 2 优化器状态 + 梯度 backward 内部的 Reduce-Scatter/All-Gather(沿用原 DP) 再省去梯度副本 否,参数仍完整
Stage 3 参数 + 梯度 + 优化器状态 每次 forward/backward 前后都要 All-Gather / Reduce-Scatter 参数 几乎消除所有状态冗余 是,可按 sharding 数扩充模型

Stage 1/2 的节省上限固定在优化器 + 梯度,不改变参数驻留,因此无法突破单卡参数上限。只有 Stage 3 引入参数分区,才能继续扩展模型规模。

ZeRO 的计算路径仍是“每张卡拿到一批样本,执行完整模型的 forward/backward”,只是状态存储方式不同。

3 为什么 ZeRO 不是模型并行

模型并行的核心是计算图跨 GPU:不同层、张量或 token 由不同设备计算,并通过激活(activations)传递上下文。

ZeRO 即便到了 Stage 3,也只是“把参数临时聚合到本卡算完再还回去”,计算图从始至终没跨卡。对比如下:

比较项 普通 DP ZeRO(Stage 3) 模型并行
模型副本 每卡完整一份 参数/梯度/优化器按 shard 切开放在所有卡 每卡只负责部分层或张量
前向/反向计算 各卡独立完成全模型 依然完成全模型 计算图被拆分,激活跨卡流动
主要通信 梯度 All-Reduce 参数 All-Gather + Reduce-Scatter 激活 Send/Recv,在层间穿梭
是否改变计算图

把 ZeRO-3 和典型模型并行再拆:

比较项 ZeRO-3 模型并行
通信对象 参数 激活
通信时机 每次层计算前后聚合/拆散参数 层与层之间传递激活
通信目的 “借”到本卡一份参数来完成单卡计算 把计算图继续下放给下一张卡
类比 去别的 GPU 借书页(参数),读完还回去 每个 GPU 只写书的一章,下一章要读上一章的笔记(激活)

因此 ZeRO 属于 State-Sharded Data Parallelism:状态是分布式的,计算仍是 per-card 完整副本。

4 混合:ZeRO + 真实模型并行 = 3D 并行

在 Megatron-Deepspeed、DeepSeek 等系统里,常见组合是:

  1. DP + ZeRO Stage 1/2:用来分摊 optimizer/梯度;
  2. Tensor Parallel / Pipeline Parallel / Expert Parallel:拆模型计算本身;
  3. 必要时再叠 Sequence Parallel 等。

这种“3D 并行”让每条维度负责不同瓶颈:ZeRO 管显存,TP/PP/EP 管计算切分。ZeRO 不是模型并行的替代,而是 DP 的增强件,和 MP 完全可以叠加。

5 ZeRO-1 vs ZeRO-2:何时仍选 Stage 1?

很多人看到 Stage 2 通信成本接近 Stage 1,却多省了梯度显存,就疑惑 Stage 1 还有何意义。

  • Stage 2 把 All-Reduce 拆成 Reduce-Scatter + All-Gather,理论通信量与 DP 几乎一致,但会增加少量控制/内存管理开销。
  • 当启用梯度累积(Gradient Accumulation)且 micro-batch 很小时,Stage 2 需要频繁调度分片梯度,可能导致 dispatcher 成本反而更高。
  • Stage 1 因为只分 optimizer 状态,在高累积步数、长序列训练中更稳定;对老旧优化器(Adafactor 之类)也更容易集成。

实务建议:

  1. 需要突破参数显存 -> 直接上 Stage 3。
  2. 仍在单卡参数范围内,但希望减少 optimizer/梯度 -> Stage 2 默认首选。
  3. 大梯度累积、通信链路脆弱或定制优化器难以改写 -> Stage 1 反而更稳。

6 关键信息速览

  • ZeRO 解决的是 DP 的冗余状态,不是把计算图拆成模型并行。
  • Stage 1/2 只能节省 optimizer + 梯度,Stage 3 才能真正突破参数显存上限。
  • ZeRO-3 通信的对象是参数,而模型并行通信的是激活,两者作用完全不同。
  • 在 3D 并行体系里,ZeRO 与 TP/PP/EP 是互补关系,常同时启用。
  • Stage 1 仍有价值:高梯度累积、通信受限或优化器难以分片时,它的实现成本最低。

一句话总结:ZeRO 是数据并行的内存优化实现,让每张卡“借”到所需的模型状态,却从未改变每卡独立执行完整模型这一训练语义。