| import copy |
| import logging |
| import time |
| from contextlib import nullcontext |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from optimizer.muon import Muon, get_default_muon_param_groups |
| from optimizer.newton_schulz import set_ns_compile |
| from torch.distributed.tensor import (DTensor, Replicate, Shard, |
| distribute_tensor) |
| from torch.profiler import ProfilerActivity, profile |
|
|
| from .utils import (ParallelDims, _apply_fsdp, assert_params_equal, |
| parallelize_motif, parallelize_qk_logits) |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO) |
|
|
|
|
def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon optimizer step with optional QK clipping.

    Assigns ``grads`` to the model's parameters (resharding dense grads to
    match DTensor placements where needed), builds a ``Muon`` optimizer and
    runs one step. With ``measure_perf`` it additionally times 20 extra
    steps and records peak CUDA memory; with ``do_profile`` those steps run
    under the torch profiler and a chrome trace is exported per rank.

    Args:
        model: Model to step; parameters may be plain tensors or DTensors.
        parallel_dims: Unused here; kept for call-site symmetry.
        grads: One gradient per parameter, in ``model.parameters()`` order.
        warmup_step: Muon warmup step count (callers pass -1 to disable).
        chunk_size: Muon chunk size (callers pass -1 to disable chunking).
        qk_logits: When not ``None``, enables QK clipping with a default
            clip config derived from ``model.config``.
        use_distributed_muon: Forwarded to the ``Muon`` constructor.
        measure_perf: When True, also measure avg step time / peak memory.
        do_profile: When True (and ``measure_perf``), export a chrome trace.
        test_name: Prefix for the exported trace file name.

    Returns:
        ``(model, timing_result)`` where ``timing_result`` is ``None``
        unless ``measure_perf`` is set, otherwise
        ``(avg_step_time_ms, peak_memory_bytes)``.
    """
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            # Dense grads are first wrapped as fully-replicated DTensors on
            # the parameter's mesh, then redistributed to the parameter's
            # exact placements so param.grad matches param sharding.
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad

    params = get_default_muon_param_groups(model)
    # Plain dict literal; the previous dict({...}) wrapper was redundant.
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim":
        model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
    optim = Muon(
        params=params,
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )

    optim.step(qk_logits=qk_logits)

    timing_result: tuple[float, float] | None = None

    if measure_perf:
        # One extra warmup step so lazy init / compilation is excluded
        # from the timed region.
        optim.step(qk_logits=qk_logits)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        torch.cuda.reset_peak_memory_stats()
        start.record()
        num_iters = 20

        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()

        with context as prof:
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)

        # NOTE(review): end.record() runs after the profiler context exits,
        # so profiler teardown is included in the measured time whenever
        # do_profile is set — treat profiled timings as approximate.
        end.record()
        end.synchronize()

        # nullcontext() yields None, so prof is only set when profiling.
        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")

        peak_memory = torch.cuda.max_memory_allocated()

        elapsed_time_ms = start.elapsed_time(end) / num_iters

        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result
|
|
|
|
@pytest.fixture(scope="session")
def sequential_muon_result(
    skip_verify,
    inputs
) -> dict[tuple[bool, bool], torch.nn.Module] | None:
    """Run the Muon optimizer on a sequential model for baseline results.

    Returns:
        Dict keyed by ``(apply_qk_clip, use_compile)`` mapping to the
        stepped model moved to CPU, or ``None`` when ``skip_verify`` is set
        (annotation now reflects this).
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    model, grads, qk_logits = inputs
    results: dict[tuple[bool, bool], torch.nn.Module] = {}

    for use_compile in [False, True]:
        set_ns_compile(use_compile)

        # Baseline without and with QK clipping, in that order (same call
        # order as before, just deduplicated into a loop).
        for apply_qk_clip in [False, True]:
            results[(apply_qk_clip, use_compile)] = apply_muon_step(
                model=copy.deepcopy(model).cuda(),
                parallel_dims=None,
                grads=grads,
                warmup_step=-1,
                chunk_size=-1,
                qk_logits=qk_logits if apply_qk_clip else None,
            )[0].cpu()

    # Restore the default so later tests see compiled Newton-Schulz.
    set_ns_compile(True)
    return results
|
|
|
|
# Warmup-step and chunk-size values swept by test_parallel_muon below.
OVERLAP_STEPS = [5]
CHUNK_SIZES = [2]
|
|
|
|
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(8, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 8), id="tp"),
    pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
    pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
@pytest.mark.parametrize("use_compile", [False, True])
def test_parallel_muon(
    request,
    sequential_muon_result: dict[tuple[bool, bool], torch.nn.Module],
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    use_compile: bool,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Run one parallel Muon step and compare against the sequential baseline.

    With ``measure_perf`` enabled it logs timing/peak-memory instead of
    checking correctness.
    """
    # Chunk size / warmup step are no-ops for distributed Muon, so skip
    # redundant parametrizations. (Skip messages previously read
    # "does not effected by" — grammar fixed.)
    if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
        pytest.skip("Distributed Muon is not affected by chunk size")
    if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
        pytest.skip("Distributed Muon is not affected by warmup step")

    set_ns_compile(use_compile)

    model, grads, qk_logits = inputs

    if not apply_qk_clip:
        qk_logits = None

    model = copy.deepcopy(model).cuda()

    parallelized_model = parallelize_motif(model, parallel_dims)

    if qk_logits is not None:
        # Deep copy so parallelization never mutates the shared fixture.
        qk_logits = copy.deepcopy(qk_logits)
        qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)

    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f},"
        )

    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        # Compiled Newton-Schulz is numerically close but not bitwise
        # identical to eager, so tolerances are relaxed only then.
        atol = 1e-5 if use_compile else 0
        rtol = 1e-2 if use_compile else 0
        assert_params_equal(parallelized_model,
                            sequential_muon_result[(apply_qk_clip,
                                                    use_compile)],
                            atol=atol,
                            rtol=rtol)
|
|
|
def test_parallel_muon_empty_shard(init_dist):
    """Regression test: parallel Muon with zero-size local shards.

    Sharding a dim-0 of size 4 over 8 ranks leaves ranks 4-7 with empty
    local shards; ``_launch_gather`` used to trip ``assert total_send > 0``
    on exactly this layout.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))

    set_ns_compile(False)

    small_dim = 4
    num_params = 4
    torch.manual_seed(42)

    def _shard0(dense):
        # Shard along dim 0 so the high ranks end up with empty pieces.
        return distribute_tensor(dense, mesh, [Shard(0)])

    muon_params = []
    muon_names = []
    for i in range(num_params):
        weight = torch.nn.Parameter(
            _shard0(torch.randn(small_dim, 64, device="cuda")))
        weight.grad = _shard0(torch.randn(small_dim, 64, device="cuda"))
        muon_params.append(weight)
        muon_names.append(f"layer.{i}.weight")

    param_groups = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]

    optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
    optim.step()

    # Second step with fresh gradients exercises the momentum path too.
    for p in muon_params:
        p.grad = _shard0(torch.randn(small_dim, 64, device="cuda"))
    optim.step()

    set_ns_compile(True)
    logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)
|
|
|
|
@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
    pytest.param(11, id="11"),
])
def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
    """Test that parallel Muon produces correct results when parameter
    dimensions are not evenly divisible by the number of shard ranks.

    For example, dim=33 with 8 ranks gives 7 ranks with 4 rows and
    1 rank with 5 rows. This exercises the remainder-handling logic
    in ``get_slices_of_dtensor`` and the all-to-all pipeline.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))

    set_ns_compile(False)
    torch.manual_seed(42)

    other_dim = 64
    num_params = 3

    def _muon_group(params, names):
        # Single source of truth for the Muon hyper-parameters; previously
        # this dict was duplicated for the parallel and sequential runs,
        # which risked the two configs silently drifting apart.
        return [{
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]

    # --- Parallel setup: DTensor params sharded on dim 0 -----------------
    muon_params = []
    muon_names = []
    full_params_snapshot = []
    full_grads = []

    for i in range(num_params):
        full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_params_snapshot.append(full.clone())
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        grad_full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_grads.append(grad_full.clone())
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layer.{i}.weight")

    optim_par = Muon(params=_muon_group(muon_params, muon_names),
                     chunk_size=1,
                     warmup_step=0)
    optim_par.step()

    # --- Sequential baseline on identical initial weights and grads ------
    seq_params = []
    seq_names = []
    for i in range(num_params):
        p = torch.nn.Parameter(full_params_snapshot[i].clone())
        p.grad = full_grads[i].clone()
        seq_params.append(p)
        seq_names.append(f"layer.{i}.weight")

    optim_seq = Muon(params=_muon_group(seq_params, seq_names))
    optim_seq.step()

    # --- Results must match bitwise (atol=0, rtol=0) ----------------------
    for i in range(num_params):
        par_full = muon_params[i].data.full_tensor()
        seq_full = seq_params[i].data
        torch.testing.assert_close(par_full, seq_full, atol=0, rtol=0)

    set_ns_compile(True)
    logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)",
                uneven_dim, rank)
|
|
|
|
def test_pp_dp_replicate_no_deadlock(init_dist, inputs):
    """PP regression test using real Motif model.

    PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the
    Motif-2.6B-4layer model across 2 pipeline stages following the
    torchtitan pattern (deep copy → delete non-stage layers → per-stage
    FSDP). Each stage independently runs Muon optimizer and the result
    is verified against a sequential baseline (atol=0, rtol=0).

    Without use_local_synchronization=True in construct_shard_mesh(),
    different stages would deadlock on dist.new_group() because they
    call it for different parameters.
    """
    import re

    import torch.nn as nn
    from optimizer.distributed.utils import _ranks_to_dist_cache

    rank = dist.get_rank()
    assert dist.get_world_size() == 8

    set_ns_compile(False)
    # Clear the cached process groups so this test builds its own.
    _ranks_to_dist_cache.clear()

    model_orig, grads_orig, _ = inputs

    # Name -> grad lookup so each stage can select grads for its params.
    grad_dict = {
        name: grad
        for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
    }

    # 8 GPUs: 2 pipeline stages x (2-way replicate x 2-way shard) data mesh.
    full_mesh = dist.init_device_mesh(
        "cuda",
        (2, 2, 2),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
    )
    dp_mesh = full_mesh["dp_replicate", "dp_shard"]
    pp_rank = full_mesh.get_local_rank("pp")

    def _split_motif(model):
        """Split Motif model per PP stage (torchtitan pattern).

        Stage 0: embed_tokens + layers[0:2]
        Stage 1: layers[2:4] + norm + output
        Non-stage components replaced with nn.Identity (no params).
        """
        all_layers = list(model.model.layers)
        if pp_rank == 0:
            model.model.layers = nn.ModuleList(all_layers[:2])
            model.model.norm = nn.Identity()
            if hasattr(model, "output"):
                model.output = nn.Identity()
            if hasattr(model, "lm_head"):
                model.lm_head = nn.Identity()
        else:
            model.model.layers = nn.ModuleList(all_layers[2:])
            model.model.embed_tokens = nn.Identity()
        return model

    # Stage 1's local layer indices start at 0 but correspond to the
    # original layers 2..3, hence the offset used by _remap below.
    layer_offset = 0 if pp_rank == 0 else 2

    def _remap(name):
        """Map stage param name → original param name (layer index offset).

        Also handles weight tying: Motif ties lm_head.weight to
        model.embed_tokens.weight, so named_parameters() only lists the
        latter. After stage-split, stage 1 loses embed_tokens but keeps
        lm_head, so we remap it back.
        """
        if name == "lm_head.weight":
            return "model.embed_tokens.weight"

        # Stage 0 names are already the original names.
        if layer_offset == 0:
            return name

        def _replace(m):
            return f"layers.{int(m.group(1)) + layer_offset}."

        return re.sub(r"layers\.(\d+)\.", _replace, name)

    def _stage_grads(model):
        """Build grads list aligned with stage model parameters."""
        return [grad_dict[_remap(n)] for n, _ in model.named_parameters()]

    # Parallel run: per-stage FSDP over the dp mesh, overlapped Muon.
    par_model = _split_motif(copy.deepcopy(model_orig).cuda())
    _apply_fsdp(par_model, dp_mesh)
    par_model, _ = apply_muon_step(
        model=par_model,
        parallel_dims=None,
        grads=_stage_grads(par_model),
        warmup_step=5,
        chunk_size=2,
        qk_logits=None,
    )

    # Sequential baseline: identical stage split, no FSDP, no overlap.
    seq_model = _split_motif(copy.deepcopy(model_orig).cuda())
    seq_model, _ = apply_muon_step(
        model=seq_model,
        parallel_dims=None,
        grads=_stage_grads(seq_model),
        warmup_step=-1,
        chunk_size=-1,
        qk_logits=None,
    )

    # Both stages must match their baseline bitwise.
    assert_params_equal(par_model, seq_model, atol=0, rtol=0)

    set_ns_compile(True)
    logger.info(
        "test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
        pp_rank)
|
|