import copy
import logging
import time
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from optimizer.newton_schulz import set_ns_compile
from torch.distributed.tensor import (DTensor, Replicate, Shard,
                                      distribute_tensor)
from torch.profiler import ProfilerActivity, profile

from .utils import (ParallelDims, _apply_fsdp, assert_params_equal,
                    parallelize_motif, parallelize_qk_logits)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step with optional QK clipping.

    Returns the updated model and, when ``measure_perf`` is set, a
    ``(avg_step_time_ms, peak_memory_bytes)`` tuple.
    """
    # 1. Apply gradients to model parameters
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad

    # 2. Setup Muon optimizer
    params = get_default_muon_param_groups(model)
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim": model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
    optim = Muon(
        params=params,
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )
    optim.step(qk_logits=qk_logits)

    # 3. Optionally measure per-step time and peak memory
    timing_result: tuple[float, float] | None = None
    if measure_perf:
        # extra warm up
        optim.step(qk_logits=qk_logits)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats()
        start.record()

        num_iters = 20
        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()
        with context as prof:
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)
        end.record()
        end.synchronize()

        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")

        peak_memory = torch.cuda.max_memory_allocated()
        elapsed_time_ms = start.elapsed_time(end) / num_iters
        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result


@pytest.fixture(scope="session")
def sequential_muon_result(
    skip_verify,  # from conftest.py
    inputs,  # from conftest.py
) -> dict[tuple[bool, bool], torch.nn.Module] | None:
    """Run the Muon optimizer on the non-parallelized (sequential) model to
    produce baseline results.

    Returns a dict keyed by ``(apply_qk_clip, use_compile)``.
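
    The values are CPU copies of the model after a single optimizer step;
    ``test_parallel_muon`` looks up its baseline as (the ``baseline`` name is
    illustrative)::

        baseline = sequential_muon_result[(apply_qk_clip, use_compile)]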
""" if skip_verify: logger.info("Skipping verification tests as per user request") return None model, grads, qk_logits = inputs results: dict[tuple[bool, bool], torch.nn.Module] = {} for use_compile in [False, True]: set_ns_compile(use_compile) results[(False, use_compile)] = apply_muon_step( model=copy.deepcopy(model).cuda(), parallel_dims=None, grads=grads, warmup_step=-1, chunk_size=-1, qk_logits=None, )[0].cpu() results[(True, use_compile)] = apply_muon_step( model=copy.deepcopy(model).cuda(), parallel_dims=None, grads=grads, warmup_step=-1, chunk_size=-1, qk_logits=qk_logits, )[0].cpu() set_ns_compile(True) # restore default return results OVERLAP_STEPS = [5] CHUNK_SIZES = [2] @pytest.mark.parametrize("parallel_dims", [ pytest.param(ParallelDims(8, 1, 1), id="base"), pytest.param(ParallelDims(1, 8, 1), id="fsdp"), pytest.param(ParallelDims(2, 4, 1), id="hsdp"), pytest.param(ParallelDims(1, 1, 8), id="tp"), pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"), pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"), ]) @pytest.mark.parametrize("apply_qk_clip", [False, True]) @pytest.mark.parametrize("use_distributed_muon", [False]) @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) @pytest.mark.parametrize("chunk_size", CHUNK_SIZES) @pytest.mark.parametrize("use_compile", [False, True]) def test_parallel_muon( request, sequential_muon_result: dict[tuple[bool, bool], torch.nn.Module], parallel_dims: ParallelDims, apply_qk_clip: bool, use_distributed_muon: bool, warmup_step: int, chunk_size: int, use_compile: bool, inputs: tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]], # from conftest.py measure_perf, # from conftest.py do_profile, # from conftest.py ) -> None: if use_distributed_muon and chunk_size != CHUNK_SIZES[0]: pytest.skip("Distributed Muon does not effected by chunk size") if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]: pytest.skip("Distributed Muon does not effected by warmup step") set_ns_compile(use_compile) model, grads, qk_logits = inputs if not apply_qk_clip: qk_logits = None # Deepcopy the model to avoid in-place modification model = copy.deepcopy(model).cuda() parallelized_model = parallelize_motif(model, parallel_dims) if qk_logits is not None: # Deepcopy the qk logits to avoid in-place modification qk_logits = copy.deepcopy(qk_logits) qk_logits = parallelize_qk_logits(qk_logits, parallel_dims) parallelized_model, timing_result = apply_muon_step( model=parallelized_model, parallel_dims=parallel_dims, grads=grads, warmup_step=warmup_step, chunk_size=chunk_size, qk_logits=qk_logits, use_distributed_muon=use_distributed_muon, measure_perf=measure_perf, do_profile=do_profile, test_name=request.node.name, ) if measure_perf: assert timing_result is not None avg_time_ms, peak_memory = timing_result logger.info( f"\nParallel dims: {parallel_dims}, " f"\nUse distributed Muon: {use_distributed_muon}, " f"\nApply QK clip: {apply_qk_clip} => " f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):" f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f}," ) if sequential_muon_result is None: logger.info("Skipping correctness check as sequential result is None") elif measure_perf: logger.info("Skipping correctness check as timing is enabled") else: atol = 1e-5 if use_compile else 0 rtol = 1e-2 if use_compile else 0 assert_params_equal(parallelized_model, sequential_muon_result[(apply_qk_clip, use_compile)], atol=atol, rtol=rtol) def test_parallel_muon_empty_shard(init_dist): """Regression: parallel Muon must handle chunks 
    where some ranks have empty local shards (dim-0 < world_size).

    With 8-way Shard(0) and dim-0 of size 4, ranks 4-7 get 0-element local
    shards. Previously ``_launch_gather`` hit ``assert total_send > 0``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    set_ns_compile(False)

    # dim-0 = 4 < 8 ranks → ranks 4-7 have empty local shards with Shard(0)
    small_dim = 4
    num_params = 4
    torch.manual_seed(42)

    muon_params = []
    muon_names = []
    for i in range(num_params):
        full = torch.randn(small_dim, 64, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        grad_full = torch.randn(small_dim, 64, device="cuda")
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layer.{i}.weight")

    param_groups = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)

    # Must not raise AssertionError: total_send > 0
    optim.step()

    # Run a second step to verify the cached path also works
    for p in muon_params:
        grad_full = torch.randn(small_dim, 64, device="cuda")
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
    optim.step()

    set_ns_compile(True)
    logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)


@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
    pytest.param(11, id="11"),
])
def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
    """Test that parallel Muon produces correct results when parameter
    dimensions are not evenly divisible by the number of shard ranks.

    For example, dim=33 with 8 ranks gives 7 ranks with 4 rows and 1 rank
    with 5 rows. This exercises the remainder-handling logic in
    ``get_slices_of_dtensor`` and the all-to-all pipeline.
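
    (Worked split for the dim=33 case described above: ``33 = 7 * 4 + 5``
    rows, so the local slices differ in length across ranks.)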
""" rank = dist.get_rank() world_size = dist.get_world_size() mesh = dist.init_device_mesh("cuda", (world_size, ), mesh_dim_names=("dp", )) set_ns_compile(False) torch.manual_seed(42) other_dim = 64 num_params = 3 # --- Build sharded params + grads --- muon_params = [] muon_names = [] full_params_snapshot = [] full_grads = [] for i in range(num_params): full = torch.randn(uneven_dim, other_dim, device="cuda") full_params_snapshot.append(full.clone()) dt = distribute_tensor(full, mesh, [Shard(0)]) p = torch.nn.Parameter(dt) grad_full = torch.randn(uneven_dim, other_dim, device="cuda") full_grads.append(grad_full.clone()) p.grad = distribute_tensor(grad_full, mesh, [Shard(0)]) muon_params.append(p) muon_names.append(f"layer.{i}.weight") # --- Parallel path (all2all pipeline) --- param_groups_par = [{ "params": muon_params, "names": muon_names, "use_muon": True, "lr": 0.02, "weight_decay": 0.01, "momentum": 0.95, "nesterov": True, "ns_steps": 5, "none_grad": False, }] optim_par = Muon(params=param_groups_par, chunk_size=1, warmup_step=0) optim_par.step() # --- Sequential baseline (base path, no sharding) --- seq_params = [] seq_names = [] for i in range(num_params): p = torch.nn.Parameter(full_params_snapshot[i].clone()) p.grad = full_grads[i].clone() seq_params.append(p) seq_names.append(f"layer.{i}.weight") param_groups_seq = [{ "params": seq_params, "names": seq_names, "use_muon": True, "lr": 0.02, "weight_decay": 0.01, "momentum": 0.95, "nesterov": True, "ns_steps": 5, "none_grad": False, }] optim_seq = Muon(params=param_groups_seq) optim_seq.step() # --- Compare: parallel result (gathered) must match sequential --- for i in range(num_params): par_full = muon_params[i].data.full_tensor() seq_full = seq_params[i].data torch.testing.assert_close(par_full, seq_full, atol=0, rtol=0) set_ns_compile(True) logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)", uneven_dim, rank) def test_pp_dp_replicate_no_deadlock(init_dist, inputs): """PP regression test using real Motif model. PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the Motif-2.6B-4layer model across 2 pipeline stages following the torchtitan pattern (deep copy → delete non-stage layers → per-stage FSDP). Each stage independently runs Muon optimizer and the result is verified against a sequential baseline (atol=0, rtol=0). Without use_local_synchronization=True in construct_shard_mesh(), different stages would deadlock on dist.new_group() because they call it for different parameters. """ import re import torch.nn as nn from optimizer.distributed.utils import _ranks_to_dist_cache rank = dist.get_rank() assert dist.get_world_size() == 8 set_ns_compile(False) _ranks_to_dist_cache.clear() model_orig, grads_orig, _ = inputs # Build name→grad mapping from original model grad_dict = { name: grad for (name, _), grad in zip(model_orig.named_parameters(), grads_orig) } # Full mesh: PP=2, dp_replicate=2, dp_shard=2 full_mesh = dist.init_device_mesh( "cuda", (2, 2, 2), mesh_dim_names=("pp", "dp_replicate", "dp_shard"), ) dp_mesh = full_mesh["dp_replicate", "dp_shard"] pp_rank = full_mesh.get_local_rank("pp") # -- Helpers ---------------------------------------------------------- def _split_motif(model): """Split Motif model per PP stage (torchtitan pattern). Stage 0: embed_tokens + layers[0:2] Stage 1: layers[2:4] + norm + output Non-stage components replaced with nn.Identity (no params). 
""" all_layers = list(model.model.layers) if pp_rank == 0: model.model.layers = nn.ModuleList(all_layers[:2]) model.model.norm = nn.Identity() if hasattr(model, "output"): model.output = nn.Identity() if hasattr(model, "lm_head"): model.lm_head = nn.Identity() else: model.model.layers = nn.ModuleList(all_layers[2:]) model.model.embed_tokens = nn.Identity() return model layer_offset = 0 if pp_rank == 0 else 2 def _remap(name): """Map stage param name → original param name (layer index offset). Also handles weight tying: Motif ties lm_head.weight to model.embed_tokens.weight, so named_parameters() only lists the latter. After stage-split, stage 1 loses embed_tokens but keeps lm_head, so we remap it back. """ # Weight tying: lm_head.weight ↔ model.embed_tokens.weight if name == "lm_head.weight": return "model.embed_tokens.weight" if layer_offset == 0: return name def _replace(m): return f"layers.{int(m.group(1)) + layer_offset}." return re.sub(r"layers\.(\d+)\.", _replace, name) def _stage_grads(model): """Build grads list aligned with stage model parameters.""" return [grad_dict[_remap(n)] for n, _ in model.named_parameters()] # -- Parallel path: split → FSDP → Muon step ------------------------- par_model = _split_motif(copy.deepcopy(model_orig).cuda()) _apply_fsdp(par_model, dp_mesh) par_model, _ = apply_muon_step( model=par_model, parallel_dims=None, grads=_stage_grads(par_model), warmup_step=5, chunk_size=2, qk_logits=None, ) # -- Sequential baseline: split → no FSDP → base Muon ---------------- seq_model = _split_motif(copy.deepcopy(model_orig).cuda()) seq_model, _ = apply_muon_step( model=seq_model, parallel_dims=None, grads=_stage_grads(seq_model), warmup_step=-1, chunk_size=-1, qk_logits=None, ) # Correctness: parallel must match sequential exactly assert_params_equal(par_model, seq_model, atol=0, rtol=0) set_ns_compile(True) logger.info( "test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank, pp_rank)