| # Expert Parallelism in torchtitan |
|
|
| torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서. |
| Muon optimizer의 MoE 지원에 필요한 배경 지식. |
|
|
| Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py` |
|
|
| ## Overview |
|
|
| torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공: |
|
|
| | Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch | |
| |--------|----|----|-----|--------------------------|----------------| |
| | TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None | |
| | EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all | |
| | EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP | |
| | EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP | |
|
|
| Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2). |
|
|
| ## EP가 dp_shard를 빌리는 구조 |
| |
| EP는 새로운 물리적 차원이 아니라 `dp_shard`를 분해해서 사용: |
|
|
| ``` |
| dp_shard = dp_shard_mod_ep * dp_shard_in_ep |
| |
| ETP=TP일 때: ep = dp_shard_in_ep * cp |
| ETP=1일 때: ep = dp_shard_in_ep * cp * tp |
| ``` |
|
|
| 기존 mesh `[pp, dp_replicate, dp_shard, cp, tp]`가 EP 활성화 시: |
|
|
| ``` |
| [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp] |
| ``` |
|
|
| 로 확장됨. `dp_shard_mod_ep`는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성). |
|
|
| ### 예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1 |
| |
| ``` |
| dp_shard_in_ep = ep / cp = 4 |
| dp_shard_mod_ep = dp_shard * cp / ep = 2 |
|
|
| mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4] |
| EP mesh: [dp_shard_in_ep=4] → expert들을 4-way로 분배 |
| FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP는 2-way로 shard |
| ``` |
| |
| ## Submesh 매핑 |
| |
| ```python |
| # Data loading (no communication) |
| dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep] |
| |
| # Non-expert parameter sharding (FSDP) |
| dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp] |
| |
| # Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외 |
| dp_mod_ep = [dp_replicate?, dp_shard_mod_ep] |
|
|
| # Expert parallelism mesh |
| ep = [dp_shard_in_ep, cp, (tp if etp==1)] |
| |
| # Loss all-reduce |
| dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp] |
| ``` |
| |
| ## 4가지 전략 상세 |
| |
| ### 1. TensorParallel (TP Only, EP=1) |
| |
| EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding: |
| |
| ```python |
| # expert_parallel.py: TensorParallel |
| w1: [Shard(1)] on TP mesh # column-wise (out_dim) |
| w2: [Shard(2)] on TP mesh # row-wise (out_dim, 3D에서 dim 2) |
| w3: [Shard(1)] on TP mesh # column-wise (out_dim) |
| ``` |
| |
| Token dispatch 없음. 일반 TP와 동일하게 동작. |
| |
| ### 2. ExpertParallel (EP Only, TP=1) |
| |
| Expert dim (dim 0)으로 sharding. Token all-to-all dispatch: |
| |
| ```python |
| # expert_parallel.py: ExpertParallel |
| w1, w2, w3: [Shard(0)] on EP mesh # expert dim으로 분배 |
| ``` |
| |
| Forward pass: |
| 1. Router가 각 token을 expert에 할당 |
| 2. `all_to_all_single`으로 token을 해당 expert의 rank로 dispatch |
| 3. 각 rank가 local expert에서 compute |
| 4. `all_to_all_single`으로 결과를 원래 rank로 combine |
| |
| ### 3. ExpertTensorParallel (EP+TP, ETP=TP) |
| |
| EP와 TP를 동시에 2D로 적용: |
| |
| ```python |
| # expert_parallel.py: ExpertTensorParallel (extends ExpertParallel) |
| w1: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column |
| w2: [Shard(0), Shard(2)] on [EP, TP] mesh # expert + row |
| w3: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column |
| ``` |
| |
| Token dispatch: |
| 1. TP mesh에서 input을 Replicate (gradient는 Partial) |
| 2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일) |
| 3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리 |
| |
| ### 4. ReordererSequenceParallel (EP+TP, ETP=1) |
| |
| TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작: |
| |
| ```python |
| # expert_parallel.py: ReordererSequenceParallel |
| # Expert weights: [Shard(0)] on EP mesh (TP 안 씀) |
| # Token split: batch*seq_len을 TP rank 수로 나눠서 분배 |
| |
| # EP mesh = [dp_shard_in_ep, cp, tp] ← tp가 EP에 포함됨 |
| ``` |
| |
| TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음. |
| |
| ## EFSDP (Expert FSDP) |
| |
| Expert parameter에 대한 FSDP는 non-expert parameter와 **다른 mesh**를 사용: |
| |
| ```python |
| # parallelize.py: apply_fsdp |
| # Non-expert: dp_shard_cp mesh 전체로 shard |
| fully_shard(transformer_block, mesh=dp_shard_cp_mesh) |
| |
| # Expert (EP 활성화 시): dp_mod_ep mesh로만 shard |
| # dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외 |
| fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh) |
| ``` |
| |
| ### Dynamic shard placement |
| |
| Expert 수보다 `dp_mod_ep * ep`가 클 때 (expert dim으로 더 쪼갤 수 없을 때), |
| dim 0 대신 dim 1로 shard. |
|
|
| **torchtitan 코드** (`torchtitan/models/llama4/infra/parallelize.py:339-359`): |
|
|
| ```python |
| # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). |
| # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding |
| # causes inefficiency, so we choose to do FSDP sharding on dim-1. |
| _experts_shard_placement_fn = None |
| if ( |
| dp_mod_ep_mesh.size() * ep_degree |
| > transformer_block.moe.experts.num_experts |
| ): |
| _experts_shard_placement_fn = lambda param: Shard(1) |
| |
| fully_shard( |
| transformer_block.moe.experts, |
| **fsdp_mod_ep_config, # mesh=dp_mod_ep_mesh |
| reshard_after_forward=reshard_after_forward, |
| shard_placement_fn=_experts_shard_placement_fn, |
| ) |
| ``` |
|
|
| `dp_mod_ep_mesh` 구성 (`parallelize.py:140-159`): |
|
|
| ```python |
| dp_mod_ep_mesh_dim_names = [] |
| if parallel_dims.ep_enabled: |
| if parallel_dims.dp_replicate_enabled: |
| dp_mod_ep_mesh_dim_names.append("dp_replicate") |
| dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") |
| # → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] |
| ``` |
|
|
| ### 실제 placement 검증 결과 |
|
|
| 8 GPUs, `num_experts=2`, `etp=1` 기준: |
|
|
| #### num_experts=8 (기본) |
| |
| 모든 config에서 expert weights는 **dim 0 (expert dim)으로만 shard**: |
| |
| | Config | Expert Placements | Mesh | |
| |--------|-------------------|------| |
| | ep=8 | `[Shard(0)]` | `[ep=8]` | |
| | ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` | |
| | ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` | |
| | ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` | |
|
|
| EFSDP는 `_StridedShard(dim=0)`, EP는 `Shard(dim=0)`. 비-dim-0 shard 없음. |
|
|
| #### num_experts=2 (expert 수 < EFSDP shard count) |
| |
| `dp_mod_ep * ep > num_experts` 조건 충족 시 **EFSDP가 Shard(1)로 전환**: |
|
|
| | Config | 조건 | Expert Placements | Mesh | |
| |--------|------|-------------------|------| |
| | ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` | |
| | ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` | |
|
|
| - EFSDP: `Shard(1)` on `dp_shard_mod_ep` → out_dim을 shard (w1: 2816/4=704) |
| - EP: `Shard(0)` on `ep` → expert dim을 shard (2/2=1) |
| - `_StridedShard`가 아닌 일반 `Shard` 사용 |
|
|
| ## Gradient Clipping with EP |
|
|
| EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산: |
|
|
| ```python |
| # distributed/utils.py: _clip_grad_norm_with_ep |
| ep_norm = get_total_norm(ep_grads, ...) |
| non_ep_norm = get_total_norm(non_ep_grads, ...) |
| total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p) |
| ``` |
|
|
| EP parameter 판별: `device_mesh.mesh_dim_names`에 "ep" 포함 여부. |
|
|
| ## Muon optimizer에서의 처리 |
|
|
| 현재 Muon optimizer의 MoE 지원: |
|
|
| 1. **`_expand_expert_params`**: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장 |
| 2. **TP가 있을 때**: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap |
| - 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh |
| 3. **`construct_shard_mesh` fast path**: 1D submesh에서 `dist.new_group()` deadlock 방지 |
| |
| ### Muon이 지원하는 config |
| |
| | Config | 지원 | 비고 | |
| |--------|------|------| |
| | TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 | |
| | EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) | |
| | FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim | |
| | HSDP + TP | O | Replicate + FSDP + TP | |
| | EP Only (많은 experts) | O | EFSDP `Shard(0)` → plain tensor | |
| | EP + FSDP (적은 experts) | 미테스트 | EFSDP `Shard(1)` → 아래 참조 | |
| | EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor `[Shard(0), Shard(1/2)]` | |
| | EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 | |
| |
| ### EFSDP Shard(1)과 Muon의 호환성 |
| |
| Muon은 placement-agnostic. `_expand_expert_params`의 non-dim-0 shard 처리 로직이 |
| TP뿐 아니라 EFSDP `Shard(1)`에도 동일하게 적용됨 (변수명만 `tp_*`일 뿐 로직은 generic): |
| |
| ``` |
| 3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2] |
| local shape: (1, 704, 2048) |
| |
| _expand_expert_params: |
| 1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep |
| 2. submesh 추출 → dp_shard_mod_ep (1D, size 4) |
| 3. dim 0 split → (704, 2048) |
| 4. DTensor wrap → Shard(0) on dp_shard_mod_ep |
| = 일반 FSDP sharded 2D 텐서와 동일 |
| |
| → parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리. |
| construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음). |
| ``` |
| |