# optimizer/test/test_muon.py
# Replace toy PP tests with real-model-based pipeline tests [skip-build]
# (commit 67f7e11)
import copy
import logging
import time
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from optimizer.newton_schulz import set_ns_compile
from torch.distributed.tensor import (DTensor, Replicate, Shard,
distribute_tensor)
from torch.profiler import ProfilerActivity, profile
from .utils import (ParallelDims, _apply_fsdp, assert_params_equal,
parallelize_motif, parallelize_qk_logits)
# Module-level logging for test diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step with optional QK clipping.

    Args:
        model: Model whose parameters receive ``grads`` (parameters may be
            DTensors when the model is parallelized).
        parallel_dims: Parallelism layout; not read here, kept for caller
            signature compatibility.
        grads: One full (unsharded) gradient per parameter, in
            ``model.parameters()`` order.
        warmup_step: Muon warmup-step setting (-1 selects the base path).
        chunk_size: Muon chunk size (-1 selects the base path).
        qk_logits: Per-layer QK logits; enables QK clipping when not None.
        use_distributed_muon: Route through the distributed Muon impl.
        measure_perf: Time 20 extra steps and record peak CUDA memory.
        do_profile: Export a chrome trace of the timed steps.
        test_name: Prefix for the exported trace file name.

    Returns:
        ``(model, timing)`` where ``timing`` is
        ``(avg_step_ms, peak_memory_bytes)`` when ``measure_perf`` is set,
        else ``None``.
    """
    # 1. Attach gradients: wrap each full grad as a fully-replicated
    #    DTensor, then redistribute to the parameter's own placements.
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad
    # 2. Setup Muon optimizer; QK clipping targets every attention head.
    params = get_default_muon_param_groups(model)
    # Plain dict literal — the original `dict({...})` built the literal and
    # then copied it for no benefit.
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim":
        model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
    optim = Muon(
        params=params,
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )
    optim.step(qk_logits=qk_logits)
    timing_result: tuple[float, float] | None = None
    if measure_perf:
        # extra warm up so the timed loop measures steady state
        optim.step(qk_logits=qk_logits)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats()
        start.record()
        num_iters = 20
        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()
        with context as prof:
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)
        end.record()
        end.synchronize()
        # nullcontext() yields None, so prof is set only when profiling.
        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")
        peak_memory = torch.cuda.max_memory_allocated()
        elapsed_time_ms = start.elapsed_time(end) / num_iters
        timing_result = (elapsed_time_ms, peak_memory)
    return model, timing_result
@pytest.fixture(scope="session")
def sequential_muon_result(
        skip_verify,  # from conftest.py
        inputs  # from conftest.py
) -> dict[tuple[bool, bool], torch.nn.Module]:
    """Run Muon optimizer to sequential model for baseline results.

    Returns dict keyed by ``(apply_qk_clip, use_compile)``, or ``None``
    when verification was skipped via the ``skip_verify`` option.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None
    model, grads, qk_logits = inputs
    results: dict[tuple[bool, bool], torch.nn.Module] = {}
    # Sweep both compile modes; within each, run without and with QK clip
    # (same order as before: clip=False first, then clip=True).
    for use_compile in (False, True):
        set_ns_compile(use_compile)
        for apply_qk_clip in (False, True):
            stepped_model, _ = apply_muon_step(
                model=copy.deepcopy(model).cuda(),
                parallel_dims=None,
                grads=grads,
                warmup_step=-1,
                chunk_size=-1,
                qk_logits=qk_logits if apply_qk_clip else None,
            )
            results[(apply_qk_clip, use_compile)] = stepped_model.cpu()
    set_ns_compile(True)  # restore default
    return results
# Warmup-step and chunk-size values swept by the parallel Muon tests below.
OVERLAP_STEPS = [5]
CHUNK_SIZES = [2]
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(8, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 8), id="tp"),
    pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
    pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
@pytest.mark.parametrize("use_compile", [False, True])
def test_parallel_muon(
    request,
    sequential_muon_result: dict[tuple[bool, bool], torch.nn.Module],
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    use_compile: bool,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],  # from conftest.py
    measure_perf,  # from conftest.py
    do_profile,  # from conftest.py
) -> None:
    """Run one parallel Muon step and compare against the sequential baseline.

    The model is parallelized per ``parallel_dims``, stepped once, and its
    parameters are compared with the result from the
    ``sequential_muon_result`` fixture. When ``measure_perf`` is enabled,
    timing/memory is logged and the correctness check is skipped.
    """
    # chunk_size / warmup_step only affect the overlap pipeline, so sweep
    # them a single time when using distributed Muon.
    if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
        pytest.skip("Distributed Muon is not affected by chunk size")
    if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
        pytest.skip("Distributed Muon is not affected by warmup step")
    set_ns_compile(use_compile)
    model, grads, qk_logits = inputs
    if not apply_qk_clip:
        qk_logits = None
    # Deepcopy the model to avoid in-place modification
    model = copy.deepcopy(model).cuda()
    parallelized_model = parallelize_motif(model, parallel_dims)
    if qk_logits is not None:
        # Deepcopy the qk logits to avoid in-place modification
        qk_logits = copy.deepcopy(qk_logits)
        qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)
    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )
    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f},"
        )
    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        # Compiled Newton–Schulz may reorder float ops, so allow a small
        # tolerance there; the eager path must match exactly.
        atol = 1e-5 if use_compile else 0
        rtol = 1e-2 if use_compile else 0
        assert_params_equal(parallelized_model,
                            sequential_muon_result[(apply_qk_clip,
                                                    use_compile)],
                            atol=atol,
                            rtol=rtol)
def test_parallel_muon_empty_shard(init_dist):
    """Regression: parallel Muon must handle chunks where some ranks have
    empty local shards (dim-0 < world_size).

    With 8-way Shard(0) and dim-0 of size 4, ranks 4-7 get 0-element local
    shards. Previously ``_launch_gather`` hit ``assert total_send > 0``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    set_ns_compile(False)
    # dim-0 = 4 < 8 ranks → ranks 4-7 have empty local shards with Shard(0)
    small_dim = 4
    num_params = 4
    torch.manual_seed(42)

    def _shard0(tensor):
        # Distribute along dim 0 over the full dp mesh.
        return distribute_tensor(tensor, mesh, [Shard(0)])

    muon_params = []
    muon_names = []
    for idx in range(num_params):
        param = torch.nn.Parameter(
            _shard0(torch.randn(small_dim, 64, device="cuda")))
        param.grad = _shard0(torch.randn(small_dim, 64, device="cuda"))
        muon_params.append(param)
        muon_names.append(f"layer.{idx}.weight")
    param_groups = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
    # Must not raise AssertionError: total_send > 0
    optim.step()
    # Run a second step to verify cached path also works
    for param in muon_params:
        param.grad = _shard0(torch.randn(small_dim, 64, device="cuda"))
    optim.step()
    set_ns_compile(True)
    logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)
@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
    pytest.param(11, id="11"),
])
def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
    """Test that parallel Muon produces correct results when parameter
    dimensions are not evenly divisible by the number of shard ranks.

    For example, dim=33 with 8 ranks gives 7 ranks with 4 rows and
    1 rank with 5 rows. This exercises the remainder-handling logic
    in ``get_slices_of_dtensor`` and the all-to-all pipeline.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    set_ns_compile(False)
    torch.manual_seed(42)
    other_dim = 64
    num_params = 3

    def _make_groups(params, names):
        # Shared Muon hyperparameters for both the parallel and the
        # sequential optimizer instances.
        return [{
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]

    # --- Build sharded params + grads ---
    muon_params = []
    muon_names = []
    full_params_snapshot = []
    full_grads = []
    for idx in range(num_params):
        full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_params_snapshot.append(full.clone())
        param = torch.nn.Parameter(distribute_tensor(full, mesh, [Shard(0)]))
        grad_full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_grads.append(grad_full.clone())
        param.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(param)
        muon_names.append(f"layer.{idx}.weight")

    # --- Parallel path (all2all pipeline) ---
    optim_par = Muon(params=_make_groups(muon_params, muon_names),
                     chunk_size=1,
                     warmup_step=0)
    optim_par.step()

    # --- Sequential baseline (base path, no sharding) ---
    seq_params = []
    seq_names = []
    for idx in range(num_params):
        param = torch.nn.Parameter(full_params_snapshot[idx].clone())
        param.grad = full_grads[idx].clone()
        seq_params.append(param)
        seq_names.append(f"layer.{idx}.weight")
    optim_seq = Muon(params=_make_groups(seq_params, seq_names))
    optim_seq.step()

    # --- Compare: parallel result (gathered) must match sequential ---
    for idx in range(num_params):
        par_full = muon_params[idx].data.full_tensor()
        seq_full = seq_params[idx].data
        torch.testing.assert_close(par_full, seq_full, atol=0, rtol=0)
    set_ns_compile(True)
    logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)",
                uneven_dim, rank)
def test_pp_dp_replicate_no_deadlock(init_dist, inputs):
    """PP regression test using real Motif model.

    PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the
    Motif-2.6B-4layer model across 2 pipeline stages following the
    torchtitan pattern (deep copy → delete non-stage layers → per-stage
    FSDP). Each stage independently runs Muon optimizer and the result
    is verified against a sequential baseline (atol=0, rtol=0).

    Without use_local_synchronization=True in construct_shard_mesh(),
    different stages would deadlock on dist.new_group() because they
    call it for different parameters.
    """
    import re

    import torch.nn as nn
    from optimizer.distributed.utils import _ranks_to_dist_cache
    rank = dist.get_rank()
    assert dist.get_world_size() == 8
    set_ns_compile(False)
    # Drop process groups cached by earlier tests so this test builds its
    # own groups from a clean slate.
    _ranks_to_dist_cache.clear()
    model_orig, grads_orig, _ = inputs
    # Build name→grad mapping from original model
    grad_dict = {
        name: grad
        for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
    }
    # Full mesh: PP=2, dp_replicate=2, dp_shard=2
    full_mesh = dist.init_device_mesh(
        "cuda",
        (2, 2, 2),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
    )
    dp_mesh = full_mesh["dp_replicate", "dp_shard"]
    pp_rank = full_mesh.get_local_rank("pp")

    # -- Helpers ----------------------------------------------------------
    def _split_motif(model):
        """Split Motif model per PP stage (torchtitan pattern).

        Stage 0: embed_tokens + layers[0:2]
        Stage 1: layers[2:4] + norm + output
        Non-stage components replaced with nn.Identity (no params).
        """
        all_layers = list(model.model.layers)
        if pp_rank == 0:
            model.model.layers = nn.ModuleList(all_layers[:2])
            model.model.norm = nn.Identity()
            # Head module name varies by model version — clear whichever
            # attribute exists.
            if hasattr(model, "output"):
                model.output = nn.Identity()
            if hasattr(model, "lm_head"):
                model.lm_head = nn.Identity()
        else:
            model.model.layers = nn.ModuleList(all_layers[2:])
            model.model.embed_tokens = nn.Identity()
        return model

    # Stage-1 local layer indices start at 2 in the original model.
    layer_offset = 0 if pp_rank == 0 else 2

    def _remap(name):
        """Map stage param name → original param name (layer index offset).

        Also handles weight tying: Motif ties lm_head.weight to
        model.embed_tokens.weight, so named_parameters() only lists the
        latter. After stage-split, stage 1 loses embed_tokens but keeps
        lm_head, so we remap it back.
        """
        # Weight tying: lm_head.weight ↔ model.embed_tokens.weight
        if name == "lm_head.weight":
            return "model.embed_tokens.weight"
        if layer_offset == 0:
            return name

        def _replace(m):
            return f"layers.{int(m.group(1)) + layer_offset}."

        return re.sub(r"layers\.(\d+)\.", _replace, name)

    def _stage_grads(model):
        """Build grads list aligned with stage model parameters."""
        return [grad_dict[_remap(n)] for n, _ in model.named_parameters()]

    # -- Parallel path: split → FSDP → Muon step -------------------------
    par_model = _split_motif(copy.deepcopy(model_orig).cuda())
    _apply_fsdp(par_model, dp_mesh)
    par_model, _ = apply_muon_step(
        model=par_model,
        parallel_dims=None,
        grads=_stage_grads(par_model),
        warmup_step=5,
        chunk_size=2,
        qk_logits=None,
    )
    # -- Sequential baseline: split → no FSDP → base Muon ----------------
    seq_model = _split_motif(copy.deepcopy(model_orig).cuda())
    seq_model, _ = apply_muon_step(
        model=seq_model,
        parallel_dims=None,
        grads=_stage_grads(seq_model),
        warmup_step=-1,
        chunk_size=-1,
        qk_logits=None,
    )
    # Correctness: parallel must match sequential exactly
    assert_params_equal(par_model, seq_model, atol=0, rtol=0)
    set_ns_compile(True)
    logger.info(
        "test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
        pp_rank)