跳转至

Muon Optimizer + FSDP

导言

muon 优化器在FSDP场景下 ,xtuner以及业界先进方法是如何实现的。

xtuner的实现是根据相同shape串性把tensor来all2all,专家tensor不够fsdp_size时,还需要padding。并且内存快照时发现all2all要申请一个大buffer 35B 128卡 sp4 256k,好像有10GB左右。

这份文档围绕两个问题展开:

  1. FSDP + Muon 的 All2All 在内存快照中申请了约 10GB 大 buffer,该如何避免?是否可以分块通信?
  2. veScale-FSDP 中关于 RaggedShard + Muon 的设计有什么值得学习的地方?

先给结论:

  • 10GB All2All buffer 大概率不是 Muon 必然开销,而是 bucket 粒度过大、padding、临时 pack/unpack buffer、Newton-Schulz workspace 叠加导致的峰值。
  • 可以分块通信,但应优先“按矩阵组分块”,不要把单个 Muon 矩阵切碎后分别正交化,否则会改变优化器语义。
  • exact Muon 的下界是:某个 rank 至少要持有一个完整 2D 矩阵及其 Newton-Schulz 临时 workspace。这个下界无法通过普通 All2All 消除。
  • veScale-FSDP 的关键思想不是简单换一个 collective,而是把“结构感知布局”做进 FSDP:RaggedShard 表示不规则 sharding,Planner 减少 padding,DBuffer 做持久化零拷贝通信 buffer。

背景:FSDP 与 Muon 的冲突

Muon 的核心更新不是逐元素操作,而是对 2D 矩阵的 momentum 做近似正交化:

M_t = momentum_update(grad)
U_t = NewtonSchulz(M_t)
W_t = W_t - lr * U_t

AdamW 可以在 FSDP local shard 上直接做,因为它是 element-wise update。 但 Muon 不行,因为:

Orthogonalize([M_0; M_1; ...; M_{S-1}])_r
!=
Orthogonalize(M_r)

也就是说,在每个 FSDP rank 的 shard 上独立跑 Newton-Schulz,不等价于对完整矩阵跑 Muon

因此 exact FSDP + Muon 通常需要:

  1. 每个 rank 持有矩阵的一个 shard;
  2. 通过 All2All / gather 把完整矩阵重组到某些 rank;
  3. 这些 rank 执行 Newton-Schulz;
  4. 再把更新后的 shard 发回原 FSDP layout。

XTuner 那种“同 shape tensor 分桶,然后 All2All;数量不够 FSDP size 时 padding”的方案,本质上是在实现这个 exact 语义。


问题一:为什么 All2All buffer 会到 10GB

假设:

FSDP size = S
同 shape 矩阵数量 = B
单个完整矩阵大小 = P bytes
每个 rank 最终负责 C 个完整矩阵

如果实现一次性处理一个大 bucket,那么每个 rank 至少可能需要:

input pack buffer      ≈ C * P
output full buffer     ≈ C * P
Newton-Schulz workspace ≈ alpha * C * P
reverse All2All buffer  ≈ C * P
padding/alignment       ≈ extra

所以峰值近似是:

peak ≈ (2 + alpha + extra) * C * P

其中 alpha 取决于 Newton-Schulz 实现,可能是 1 到 3 甚至更高。

如果一个 35B 模型里某些 MLP 矩阵本身就有数百 MB,而一个 bucket 让每个 rank 同时接收多个完整矩阵,那么 10GB 峰值并不意外。

还要注意:SP4、256k sequence length 本身通常不直接决定 Muon All2All buffer 大小。Muon buffer 主要由模型参数形状、FSDP group、bucket 策略、dtype 决定。但 256k 长序列会让 activation、grad bucket、allocator reserved memory 压力更大,导致 optimizer step 的临时 buffer 更容易成为 OOM 触发点。


分块通信:可以,但要按矩阵分块

推荐分块方式

应该把一个大 Muon bucket 拆成多个 micro-bucket:

big bucket:
  [W0, W1, W2, ..., W127]

micro-bucket 0:
  [W0, W1, ..., W15]

micro-bucket 1:
  [W16, W17, ..., W31]

...

每个 micro-bucket 执行:

1. pack local momentum shards
2. All2All: shard layout -> full matrix layout
3. Newton-Schulz on full matrices
4. All2All: full update layout -> original shard layout
5. apply update
6. reuse buffer for next micro-bucket

关键是:每个 Muon 单元仍然是完整矩阵

不推荐的分块方式

不要把单个矩阵切成多个 row block 分别做 Muon:

W = [W_top; W_bottom]

NewtonSchulz(W_top) 和 NewtonSchulz(W_bottom)

这不再是原始 Muon,而是 block-wise Muon / approximate Muon。它可能可用,但优化器语义已经变了,loss 曲线和稳定性都需要重新验证。

例外情况是 fused tensor:

qkv_proj.weight = [Q; K; V]
experts.weight = [E, out, in]

这类 tensor 可以先按语义拆成多个逻辑矩阵,再分别做 Muon。这个拆分是合理的,因为 Muon 单元本来就应该是每个逻辑矩阵,而不是整个 fused 大 tensor。


内存控制方案

1. 给 Muon 设置独立 bucket cap

不要沿用 FSDP 的通信 bucket 大小,也不要把所有 same-shape tensor 一次性 All2All。

建议引入:

muon:
  comm_bucket_cap_mb: 512        # 起步建议 512MB 或 1GB
  max_inflight_buckets: 1        # 内存紧张时先设 1
  preallocate_workspace: true
  dtype: bf16

估算 micro-bucket 大小:

def estimate_muon_peak(full_bytes, ns_alpha=2.0, reverse_buffer=True):
    # full_bytes: 当前 micro-bucket 中每个 rank 需要持有的完整矩阵总 bytes
    base = 2 * full_bytes          # input + output
    ns = ns_alpha * full_bytes
    reverse = full_bytes if reverse_buffer else 0
    return base + ns + reverse

选择 micro-bucket 时应满足:

estimate_muon_peak(bucket) <= muon_workspace_cap

如果当前看到 10GB buffer,可以先把 cap 降到:

512MB ~ 1GB full-matrix bytes per rank

然后观察 optimizer step time 的变化。通常这会增加 All2All 次数,但能显著降低峰值显存。


2. 预分配并复用 workspace

不要每个 bucket 都:

tmp = torch.empty(...)

而应该在 optimizer 初始化时预分配固定 workspace:

workspace = MuonWorkspace(
    in_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
    out_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
    ns_buf=torch.empty(max_ns_bytes, dtype=torch.bfloat16, device="cuda"),
)

每个 micro-bucket 只使用 narrow/view

in_view = workspace.in_buf[:needed_numel]
out_view = workspace.out_buf[:needed_numel]

这样可以避免:

  • PyTorch caching allocator 反复申请大块内存;
  • stream lifetime 导致旧 buffer 不能及时复用;
  • memory fragmentation;
  • snapshot 中出现多个临时大块并存。

veScale-FSDP 的 DBuffer 思想也类似:用持久化 distributed buffer 和地址映射来避免反复 copy/alloc。


3. 限制 in-flight bucket 数量

为了 overlap,很多实现会同时挂多个异步 All2All:

chunk 0 communicating
chunk 1 packing
chunk 2 Newton-Schulz

这有利于性能,但会增加峰值显存。

在 35B、128 卡、SP4、256k 这种 activation 压力很高的场景,建议先关闭 aggressive overlap:

muon:
  max_inflight_buckets: 1

稳定后再尝试:

muon:
  max_inflight_buckets: 2

不要一开始就让 3 到 4 个 Muon chunks 同时在飞。


4. 尽量使用 BF16 通信和计算 buffer

检查 All2All buffer 的 dtype。

如果 momentum 或 update buffer 是 FP32:

10GB FP32 -> 5GB BF16

Muon 的大规模实现通常会尽量让通信和 Newton-Schulz 主路径使用 BF16 / FP16 / Tensor Core 友好的格式。 但需要注意:这可能影响数值稳定性。建议至少记录:

update RMS
grad norm
attention logits max
loss spike

如果 BF16 Muon 不稳定,可以只对 Newton-Schulz 的某些归一化标量保留 FP32,而不是让整个 All2All buffer 变成 FP32。


5. 用 all_to_all_single 的 split sizes 或 ragged all-to-all 减少 padding

XTuner 的 same-shape bucket 通常要求:

real_count padded 到 fsdp_size 的倍数

如果 expert tensor 数量少于 FSDP size,会产生严重 padding:

experts = 8
fsdp_size = 64
padding ratio = 87.5%

可以改成两条路径:

dense/common shape:
  equal-size all_to_all_single,走高带宽路径

expert/small/irregular shape:
  ragged all_to_all / all_to_all_single(split_sizes) / gather-to-root

如果后端支持 variable split sizes,优先避免 dummy tensor padding。 如果 variable all-to-all 性能不理想,小 bucket 可以退回 P2P gather/scatter 或 AdamW。


6. 跨 layer 合并 expert bucket

不要按 layer 做 expert bucket:

layer0: 8 experts -> pad to 64
layer1: 8 experts -> pad to 64
...

应该跨 layer 合并:

all_layers.experts.up_proj
all_layers.experts.gate_proj
all_layers.experts.down_proj

例如:

num_layers = 32
experts_per_layer = 8
fsdp_size = 64

按层 bucket:
  每层 8 个,padding 87.5%

跨层 bucket:
  32 * 8 = 256 个
  可切成 4 个 64-bucket,几乎无 padding

这是 MoE + FSDP + Muon 中非常关键的优化。


7. 对不划算的参数退回 AdamW

你前面已经观察到:在 Qwen 35B SFT 中,Muon 每步 loss 下降未必比 AdamW 快。

因此在 SFT 场景没必要追求 Muon 覆盖率 100%。建议:

Muon:
  attention q/k/v/o projection
  dense MLP gate/up/down
  大多数 routed expert MLP

AdamW:
  embedding
  lm_head
  norm
  bias
  router/gate
  LoRA 参数
  padding ratio 过高的小 expert bucket
  极大且导致显存峰值的个别矩阵

可以加策略:

if bucket.padding_ratio > 0.5:
    use_adamw(bucket)

if estimated_muon_peak(bucket) > hard_cap:
    split_bucket_or_use_adamw(bucket)

if sft and eval_not_improved_by_muon:
    reduce_muon_coverage()

8. HSDP 降低 FSDP group size

veScale-FSDP 论文也给出类似经验:不要盲目扩大 FSDP group size,必要时用 HSDP 控制 collective group。

例如总共 128 卡:

方案 A:
  fsdp_size = 128
  dp_replicas = 1

方案 B:
  fsdp_size = 64
  dp_replicas = 2

方案 C:
  fsdp_size = 32
  dp_replicas = 4

较小的 fsdp_size 通常可以:

  • 减少 expert padding;
  • 降低 collective group 复杂度;
  • 改善 NCCL latency 和 LCM rounding;
  • 让 bucket 更容易规划。

代价是:

  • 每卡参数 shard、grad shard、optimizer state 变大;
  • DP replica 之间还需要同步梯度或 optimizer state;
  • 总显存不一定下降,需要实测。

对于 35B SFT,如果 activation 才是主要压力,HSDP 未必能直接省显存;但如果 All2All padding 和大 bucket 是主要问题,HSDP 很值得试。


一个推荐的分块执行伪代码

class MuonFSDPExecutor:
    def __init__(self, fsdp_group, workspace_cap_bytes, max_inflight=1):
        self.fsdp_group = fsdp_group
        self.workspace = preallocate_workspace(workspace_cap_bytes)
        self.max_inflight = max_inflight

    def step_bucket(self, bucket):
        chunks = plan_micro_buckets(
            bucket,
            cap_bytes=self.workspace.cap_bytes,
            cost_fn=estimate_muon_peak,
        )

        for chunk in chunks:
            # 1. 本地 momentum update
            local_m_shards = update_momentum_local(chunk)

            # 2. pack 到持久化 input buffer
            in_view = self.workspace.pack(local_m_shards)

            # 3. shard layout -> full matrix layout
            out_view = self.workspace.alloc_output(chunk)
            dist.all_to_all_single(
                output=out_view,
                input=in_view,
                group=self.fsdp_group,
            )

            # 4. 每个 rank 对自己负责的完整矩阵跑 Newton-Schulz
            full_mats = unpack_full_matrices(out_view, chunk)
            updates = []
            for mat in full_mats:
                updates.append(newton_schulz(mat))

            # 5. pack update,反向 All2All 回原 FSDP shard layout
            update_view = self.workspace.pack(updates)
            shard_update_view = self.workspace.alloc_shard_output(chunk)
            dist.all_to_all_single(
                output=shard_update_view,
                input=update_view,
                group=self.fsdp_group,
            )

            # 6. apply local shard update
            apply_update_local(chunk.params, shard_update_view)

            # 7. workspace 逻辑释放,下一 chunk 复用
            self.workspace.reset()

重点:

分块对象:多个完整矩阵组成的 micro-bucket
不要分块对象:单个矩阵的 row shard

veScale-FSDP 的核心思想

veScale-FSDP 论文认为,传统 FSDP 的 element-wise 或 row-wise fixed sharding 难以支持结构感知训练,例如 Muon、Shampoo、block-wise quantization。它提出三个关键组件:

RaggedShard
Structure-aware Planner
DBuffer

RaggedShard

RaggedShard 是一种 DTensor placement,用来表达不规则 sharding。

传统 Shard(0) 通常要求均匀切分:

rank0: same size
rank1: same size
rank2: same size
...

RaggedShard 允许:

rank0: 0 block
rank1: 0 block
rank2: full tensor
rank3: 0 block

也允许:

rank0: 1 unit
rank1: 2 units
rank2: 1 unit
rank3: 1 unit

这对于 Muon 很有用,因为可以把某个完整矩阵重分布到一个 root rank:

original FSDP placement:
  each rank owns a shard

RaggedShard(root):
  only root owns the full 2D matrix
  other ranks own empty tensor

veScale 文档中也明确提到,Muon 的 Newton-Schulz 需要完整 2D 参数矩阵,RaggedShard 可以通过 DTensor.redistribute 表达 gather -> compute -> scatter 这个过程。


Structure-aware Planner

如果只是把 RaggedShard tensor 简单拼起来,可能出现:

block 被切碎
tensor 内部插 padding
每个 rank buffer 不均衡
通信 buffer 非连续

veScale 的 planner 目标是:

1. 不切碎结构块
2. 保持 tensor contiguous
3. 平衡每个设备的通信负载
4. 尽量把 padding 放在 tensor 之间,而不是 tensor 内部

这点对 Muon 和 MoE 都很重要。

对于你看到的 expert tensor padding,veScale 的思路不是简单“补 dummy tensor 到 fsdp_size”,而是做全局 layout planning,尽量减少 padding 和 LCM rounding。

论文中还给出经验:不要使用过大的 FSDP group size,可以通过 HSDP 控制 shard group,并通过离线模拟选择 padding 最小的 FSDP size。


DBuffer

DBuffer 是 veScale-FSDP 的通信 buffer 抽象。

它的目标是:

1. 持久化分配通信 buffer
2. 多 tensor group-level 操作
3. zero-copy access
4. in-place communication/computation
5. 降低 PyTorch allocator fragmentation

这正好对应你看到的 10GB 临时 All2All buffer 问题。

如果没有 DBuffer,一个朴素实现通常会反复:

torch.empty(...)
torch.cat(...)
torch.stack(...)
contiguous()
all_to_all(...)
unpack(...)

这会在 memory snapshot 中出现大量临时大块。

借鉴 DBuffer 后,应把 Muon 的通信区改成:

初始化时规划地址
初始化时分配最大 workspace
每个 step 使用 view/narrow
不在热路径中频繁申请大 tensor

veScale 的 Distributed Muon 流程

veScale-FSDP 论文中的 Muon 逻辑可以概括为:

for each 2D parameter w:
    g = grad(w)
    u = MomentumUpdate(g, m)
    p = original placement(u)

    r = SelectRoot()                      # 负载均衡选择 root
    o = Redistribute(u, RaggedShard(r))   # root 持有完整矩阵

    o = NewtonSchulz(o)                   # 只有 root 真正计算

    o = Redistribute(o, p)                # 回到原 FSDP shard
    w = w - lr * o

这个设计和 XTuner same-shape All2All 的目标类似,都是 exact Muon。 但抽象层次不同:

方案 核心思路 优点 风险
XTuner-style same-shape All2All 同 shape tensor 批量重排 简单,高带宽,容易实现 padding、大 bucket、大 buffer
veScale RaggedShard 用 placement 表达不规则 gather/scatter 语义清晰,减少 padding,适合结构感知 optimizer 需要 DTensor/RaggedShard/Planner 支撑
DBuffer 持久化通信 buffer 和地址映射 降低 allocator 峰值和 copy 工程复杂度更高
Rooted gather 每个矩阵选 root rank 避免同 shape 数量不足 padding 需要负载均衡和异步 overlap

推荐架构

如果你要从当前 XTuner-style FSDP + Muon 演进,我建议分三层做。

第一层:保留 exact All2All,但加 micro-bucket

这是最容易落地的改造。

目标:
  把 10GB 临时 buffer 降到 1GB ~ 2GB 可控范围

做法:
  1. same-shape bucket 保留
  2. bucket 内按 workspace_cap 切 micro-bucket
  3. max_inflight 先设为 1
  4. 所有临时 buffer 预分配并复用

建议配置:

muon:
  exact: true
  comm_dtype: bf16
  workspace_cap_mb: 1024
  max_inflight_buckets: 1
  preallocate_workspace: true
  fallback_min_fill_ratio: 0.5

第二层:MoE expert 改成全局规划

针对 expert tensor:

1. 跨 layer 合并 same-shape experts
2. padding ratio > 50% 的 bucket 退回 AdamW 或走 ragged path
3. 如果 expert 是 [E, out, in] fused layout,并且 shard 在 E 维,则本地 per-expert Muon,不需要 All2All

判断逻辑:

if is_expert_batch_sharded(param):
    # local shard already owns complete expert matrices
    run_local_per_expert_muon(param)
elif bucket.padding_ratio <= 0.5:
    run_all2all_muon(bucket)
else:
    run_adamw(bucket)

第三层:引入 RaggedShard / Rooted Muon

当 same-shape + padding 已经成为主要瓶颈时,再引入 veScale 风格设计:

1. 每个 Muon 矩阵选择一个 root rank
2. 使用 ragged placement 表示 root 持有完整矩阵
3. redistribute 到 root
4. root 上 Newton-Schulz
5. redistribute 回原 placement
6. 使用 planner 控制 root 负载和 buffer cap

root 选择可以按 estimated cost 做负载均衡:

def select_root(matrix, rank_load):
    cost = matrix.numel() * matrix.dtype.itemsize
    root = min(rank_load, key=rank_load.get)
    rank_load[root] += cost
    return root

不要简单 round-robin,因为不同矩阵大小差异很大。


针对 35B / 128 卡 / SP4 / 256k 的建议

优先级如下:

  1. 先确认 All2All buffer dtype 如果是 FP32,优先改成 BF16。

  2. 把 Muon bucket cap 降到 512MB 或 1GB 先牺牲一点 optimizer step time,换取显存稳定。

  3. 关闭多 chunk overlap max_inflight_buckets=1,稳定后再开到 2。

  4. 预分配 Muon workspace 不要在每个 step、每个 bucket 动态 torch.empty 大 tensor。

  5. 跨 layer 合并 expert bucket 避免每层 expert 数量小于 FSDP size 导致巨量 padding。

  6. padding ratio 高的 expert bucket 退回 AdamW SFT 中 Muon 未必带来收益,不值得为这些参数付出 10GB buffer。

  7. 评估 HSDP 比如:

fsdp_size = 64, dp = 2
fsdp_size = 32, dp = 4

用离线脚本估算 padding 和 per-rank state,再实测。

  1. 检查 optimizer step 前 activation 是否真正释放 256k 下 activation 压力极大。可以诊断性地在 optimizer 前插入同步,确认是否是 stream lifetime 导致临时 buffer 共存:
del loss, outputs
torch.cuda.synchronize()
optimizer.step()

这不是最终性能方案,但可以帮助定位峰值来源。


需要记录的指标

建议每个 Muon bucket 打印:

bucket_name
logical_shape
num_real_tensors
num_padded_tensors
padding_ratio
full_bytes_per_rank
estimated_comm_buffer_bytes
estimated_ns_workspace_bytes
actual_allocated_before
actual_allocated_after
actual_peak_allocated
all2all_in_time_ms
newton_schulz_time_ms
all2all_out_time_ms

示例:

logger.info(
    "[muon_bucket] key=%s real=%d padded=%d pad=%.2f "
    "full_rank=%.2fGB comm=%.2fGB ns=%.2fGB "
    "t_a2a_in=%.2fms t_ns=%.2fms t_a2a_out=%.2fms",
    bucket.key,
    bucket.real_count,
    bucket.padded_count,
    bucket.padding_ratio,
    full_bytes_per_rank / 2**30,
    comm_bytes / 2**30,
    ns_bytes / 2**30,
    t_in,
    t_ns,
    t_out,
)

没有这些指标,很难判断 10GB 是来自:

bucket 太大
padding 太多
dtype 不对
NS workspace 太大
多 chunk overlap
allocator fragmentation
activation 未释放

结论

对于你的场景,最实际的路线是:

短期:
  XTuner-style same-shape All2All 保留
  加 micro-bucket + workspace cap + BF16 + max_inflight=1
  padding 高的 expert bucket 退回 AdamW

中期:
  跨 layer expert bucket
  expert-batch-sharded fast path
  HSDP 调整 fsdp_size

长期:
  学 veScale-FSDP
  引入 RaggedShard / rooted redistribution / planner / persistent DBuffer

一句话总结:

FSDP + Muon 的核心不是“能不能 All2All”,而是“能否在保持完整矩阵 Muon 语义的同时,把重分布、padding、workspace 和 allocator 生命周期都纳入统一规划”。XTuner 的实现解决了 correctness,veScale-FSDP 的设计进一步解决了 layout、padding 和 buffer 生命周期问题。


参考资料

  • veScale-FSDP paper: veScale-FSDP: Flexible and High-Performance FSDP at Scale 1
  • veScale RaggedShard 文档: RaggedShard Placement 2
  • veScale GitHub: volcengine/veScale 3
  • PyTorch / TorchTitan FSDP2 notes: torchtitan FSDP documentation 4
  • Microsoft Dion / Muon distributed implementation: microsoft/dion 5

评论