跳转至

RL Data Flow

导言

这篇文章只回答一个问题:一条 RL 样本从 prompt 进入系统,到 rollout、reward、logprob、advantage、loss、backward,最后回到下一轮训练时,数据到底怎么流、shape 怎么变、显存为什么涨。

先建立三张表

写这篇文章时,先不要急着解释算法名词,优先建立三张表:

  1. shape 表:logical shape、physical shard shape、mask shape。
  2. memory 表:参数、梯度、优化器状态、激活值、KV cache、临时 buffer。
  3. lifetime 表:每个 tensor 在哪个阶段生成、在哪个阶段被消费、在哪个阶段释放。

1. 为什么要先讲数据流

  • shape mismatch 的根源通常不是公式错,而是没有把“逻辑 batch”和“物理 shard”分开。
  • OOM 不是单纯显存不够,而是没有把峰值阶段和张量生命周期拆开。
  • 多卡不均衡 不是一句“通信慢”能解释的,而要看 token 分布、micro batch 切分和 stage 时间。

2. RL 的端到端链路

2.1 主链路

prompt batch -> rollout / generation -> reward -> old/ref logprob -> advantage / return -> actor update -> next step

2.2 需要在文中讲清楚的语义

  • sample 维度:一条 prompt 对应多少个 response / candidate。
  • token 维度:每个阶段是按 token 处理,还是按 sample 聚合。
  • group 维度:GRPO / group sampling 中一个 prompt 下面的多个 response 如何展开。
  • micro batch 维度:训练时实际送进 forward/backward 的最小批。
  • dynamic batch 维度:按 token 预算、max length、显存预算动态变化的 batch。

3. 推理输入到推理输出

3.1 输入侧

  • prompt_ids
  • attention_mask
  • position_ids
  • 可能还有:prompt_lengthsample_idgroup_id

3.2 输出侧

  • response_ids
  • response_mask
  • full_sequence
  • generation metadata
  • logits / logprobs(如果 rollout 后端需要返回)
  • abort / timeout 状态

3.3 典型 shape

下面是建议的“典型形式”,具体实现要回到代码确认。

  • prompt_ids: [B_prompt, S_prompt]
  • response_ids: [B_prompt * G, S_resp]
  • full_sequence: [B_prompt * G, S_prompt + S_resp]
  • response_mask: [B_prompt * G, S_resp]
  • logprobs: [B_prompt * G, S_resp]

3.4 关键解释点

  • 动态 bs 不是固定样本数,而是受 token 数与调度器共同影响。
  • padding 会把逻辑 shape 拉大,但不一定增加有效 token 计算。
  • mask 决定了哪些 token 参与 loss、哪些 token 只是上下文。

4. reward 与 advantage 的数据流

4.1 reward 的来源

  • reward model
  • rule-based reward
  • function reward
  • 多 reward 融合

4.2 需要说明的 shape

  • sample-level reward:[B_prompt, G][B_prompt * G]
  • token-level reward:[B_prompt * G, S_resp] 或 broadcast 后的同形矩阵
  • advantage / return:常见为按 sample 计算后广播到 response token

4.3 关键解释点

  • 为什么有些 reward 是标量,有些是 token 级。
  • 为什么 advantage 常常先按 sample 算,再扩展到 token 维度。
  • response_mask 如何屏蔽 prompt token。

5. 训练输入到训练输出

5.1 训练阶段输入

  • old_logprob
  • ref_logprob
  • new_logprob
  • advantages
  • returns
  • response_mask
  • loss_mask

5.2 训练阶段输出

  • policy loss
  • kl loss
  • entropy loss
  • grad norm
  • updated params

5.3 需要在文中明确的 shape 关系

  • old_logprob: [B, S_resp]
  • ref_logprob: [B, S_resp]
  • new_logprob: [B, S_resp]
  • advantages: [B][B, S_resp]
  • loss_mat: [B, S_resp]
  • loss_mask: [B, S_resp]

5.4 文章里要解释的点

  • token-level loss 与 sample-level reward 的映射关系。
  • 为什么 old_logprob 要先保存下来。
  • 为什么 KLentropy 是稳定性指标,不只是 loss 的附属项。

6. Shape ledger

6.1 这一节的目标

把每个阶段的 tensor 记录成统一表格,避免只靠脑补推 shape。

6.2 建议表头

stage tensor logical shape local shard shape mask / broadcast owner rank lifetime common bug

6.3 建议至少覆盖的张量

  • prompt / response / full_sequence
  • attention mask / response mask
  • old / ref / new logprob
  • reward / advantage / return
  • loss matrix / loss mask
  • hidden states / activations
  • KV cache

7. Memory ledger

7.1 显存拆分

  • 参数:model weights。
  • 梯度:backward 期间的梯度张量。
  • 优化器状态:moment / variance 等。
  • 激活值:forward 保存的中间状态。
  • KV cache:rollout / inference 阶段常见大头。
  • 通信 buffer:all-gather / reduce-scatter / ring 通信临时缓冲。
  • 临时 tensor:loss、mask、拼接、索引等短生命周期张量。

7.2 估算原则

  • activation_bytes ≈ layers × local_tokens × hidden_size × bytes × factor
  • kv_bytes ≈ layers × batch_local × seq_local × kv_heads × head_dim × 2 × bytes
  • peak_memory ≈ params + grads + optimizer + activations + kv_cache + buffers

7.3 文章里要解释的点

  • 哪一阶段最可能成为峰值。
  • allocatedreserved 的差异意味着什么。
  • 动态 batch 如何改变实际峰值。

8. 并行切分总图

8.1 需要对齐的并行维度

  • DP:切 batch。
  • TP:切 hidden / head / linear 权重。
  • PP:切 layer。
  • SP:切 sequence 相关激活或中间状态。
  • CP:切 context / sequence,并在 attention 语义上做通信。

8.2 这篇文章要强调的核心区别

  • batch 切分 影响的是“有多少条样本同时跑”。
  • sequence 切分 影响的是“每条样本内部怎么切 token”。
  • shape mismatch 常常就是这两个切分维度混淆了。

9. SP 与 CP 的逻辑差异

9.1 SP:更偏激活切分

  • 重点是降低激活显存。
  • 目标是让 sequence 相关的中间状态在多个 rank 之间分片。
  • 适合在 MLP / norm / residual 等可分块区域做通信隐藏。

9.2 CP:更偏上下文切分

  • 重点是让长上下文 attention 能在更长序列上跑起来。
  • 语义上更接近“把上下文拆开再重组”。
  • 通常通信模式比 SP 更重,也更依赖 attention 的具体实现。

9.3 文中必须回答的三个问题

  1. 逻辑 shape 是什么:完整 [B, S, H] 如何映射到 local shard。
  2. 哪些算子能本地算:哪些算子可以完全在 local shard 上执行。
  3. 通信怎样被掩盖:什么时候 T_compute(chunk) >= T_comm(chunk),能把通信藏进计算里。

9.4 判断准则

  • 当 chunk 粒度足够大、通信可流水化、算子可以分块时,更容易实现 compute / comm overlap。
  • 如果 attention 需要全局上下文而本地 shard 不足,则通信更难完全掩盖。
  • 文中要明确:SP 和 CP 都可能切 sequence,但切分目的、通信位置、mask / position / KV 语义不同。

10. DFX 设计

10.1 设计目标

DFX 不只是日志,而是要同时回答:

  • 正确性:shape 对不对、mask 对不对、token 对不对。
  • 稳定性:KL、entropy、grad norm、abort ratio 是否健康。
  • 性能:step time、tokens/s、stage time、MFU / SMA。
  • 显存:allocated / reserved、峰值、碎片率。
  • 负载:rank 间 token 与时间是否均衡。
  • 数据质量:prompt length、response length、clip ratio、reward 分布。

10.2 建议的指标层次

  • E2E 指标:step time、throughput、total tokens。
  • 阶段指标:rollout / reward / logprob / ref / update 的耗时。
  • 张量指标:shape、bytes、mask ratio、lifetime。
  • 并行指标:rank 负载、通信时长、bubble、queue depth。
  • 稳定性指标:KL、entropy、grad norm、clipfrac。

10.3 建议的日志字段

  • stage
  • rank
  • tensor_name
  • logical_shape
  • local_shape
  • dtype
  • numel
  • bytes
  • mask_ratio
  • lifetime
  • comm_type
  • latency_ms
  • owner

10.4 告警规则草案

  • shape mismatch:shape checksum 不一致。
  • OOM:reserved 快接近物理上限,且碎片率升高。
  • load imbalance:per-rank token 或 step time 偏差过大。
  • training instability:KL / grad norm / entropy 同时异常波动。

11. 调试顺序

11.1 shape mismatch

  1. 先看 logical shape。
  2. 再看 local shard shape。
  3. 再看 mask / broadcast / reshape。
  4. 最后才看算子实现和并行配置。

11.2 OOM

  1. 先找峰值阶段。
  2. 再拆参数、梯度、优化器、激活、KV、buffer。
  3. 再看 dynamic batch / sequence split 是否改变了峰值。

11.3 多卡不均衡

  1. 先看 token 分布。
  2. 再看动态 batch 和 micro batch 切分。
  3. 再看 SP / CP / PP / TP 的通信等待。
  4. 最后看异步队列与 straggler。

12. 收束

  • RL Infra 的核心不是某一个公式,而是 数据流 + shape + 生命周期 + 并行切分 + DFX 的统一视角。
  • 只要把这五件事统一起来,后面再看 verl 的训练、推理、异步和 checkpoint,很多问题都会变得可解释。
  • 第一篇的最终目标,是给后面所有文章提供一张共同的“坐标系”。

评论