ZeRO++: Extremely Efficient Collective Communication for Giant Model Training

本论文是Guanhua Wang, Heyang Qin, Sam Ade Jacobs, Connor Holmes, Samyam Rajbhandari, Olatunji Ruwase, Feng Yan, Lei Yang, Yuxiong He发表在 ICML'24上的工作。

本工作主要通过量化和分级通信,分别优化了训练中 ZeRO3 中的三种不同集合通信过程,使跨节点通信的总量从 3M 降低到 0.75M.

Entry:Zotero link URL link

Motivation

ZeRO 尽管能够高效的训练大规模模型,然而其在某些场景下依然面临着性能不高效的问题:

  • cluster with low bandwidth:如今的硬件在 intra-node 和 inter-node 的 bandwidth 相差巨大,因此,对于具有较多通信的训练场景下,跨节点通信很容易成为训练瓶颈
  • very small batch sizes per GPU:当 global batch size 保持不变时,随着 GPU cluster 的规模增大,每个 GPU 能够处理的 batch size 会逐渐减少。而这会使得通信和计算的比例逐渐增大,从而造成通信瓶颈

在 ZeRO3 的过程中,一次forward+backward迭代需要进行三次集合通信:

  1. all-gather paramters in forward
  2. all-gather parameters in backward
  3. reduce-scatter gradients in backward

因此文章提出ZeRO++进一步减少这些通信的通信量,主要包括:

  1. Quantized Weight Communication for ZeRO(qwZ):在forward做weight all-gather时,将weight进行量化以此来减少通信量
  2. Hierarchical Partitioning for ZeRO(hpZ):在backward做weight all-gather时,将weight的切分限制在sub-group中(一般为节点内),从而减少节点间的通信
  3. Quantized Gradients Communication for ZeRO(qgZ):在做gradients reduce-scatter时,将gradients进行量化以此来减少通信量

Design

Quantized Weight Communication for ZeRO

使用分组量化策略,对于一个 weight 矩阵通过按行/列分组量化的形式,减少量化误差

在前向传播做weight all-gather之前,将 weight 从 fp16 量化为 int8 类型,做完 all-gather 通信后,又将 weight 从 int8 还原为 fp16 类型

Hierarchical Partitioning for ZeRO

在ZeRO中,参数的切分发生在所有的 GPU 之间,即被切分成 world size 份,那么这些参数所涉及到通信自然也在所有 GPU 之间发生。hpZ 则主要优化的是反向传播中的 all-gather 通信,通过将参数切分限制在 secondary-group 内,通常而言是限制在单个节点内,节点之间的参数则是 replica,通过牺牲显存的方式来避免在所有 GPU 之间做通信,从而提升通信效率。

如图所示,在ZeRO3中,paramter 被切分成4份,每个 GPU 拥有一份,自然 parameter 的通信发生在四个 GPU 之间(这里会涉及到跨节点通信);在 hpZ 中,paramter 被切分成两份,G0 和 G2 拥有的 paramter 是相同的,这样 paramter 的通信只会发生在节点内,而这需要更大的显存来确保 GPU 能够放下 paramter 的切片。

一次iteration过程:

  1. weight在primary group(指所有的gpu)内被切分
  2. all-gather weight in primary group before forward
  3. forward
  4. partition weight in secondary group
  5. all-gather weight in secondary group
  6. backward

Quantized Gradients Communication for ZeRO

做 reduce-scatter 之前将 gradients 从 fp16 转化为 int4 类型,做完r educe-scatter 通信后,将 gradients 从 int4 类型还原为 fp16 类型

为了实现这一目标,主要做了三件事情:

  1. 使用 all2all 通信代替 reduce-scatter 来减少量化误差和量化时间
  2. 设计了分级的 all2all 通信来进减少节点间通信量
  3. 使用 tensor slice reordering 算法来保证通信结果的正确性

Implement

实现过程中,通过把数据切成多个 data chunk,可以进行 intra-node 和 inter-node 通信的 overlap:


Last modified on 2024-11-13