Kernels
wyldecat Claude Opus 4.6 committed on
Commit
67f7e11
·
1 Parent(s): a4d1f34

Replace toy PP tests with real-model-based pipeline tests [skip-build]

Browse files

Use Motif-2.6B (dense) and torchtitan Llama4 MoE (MoE) models with
realistic PP model splitting (deep copy → delete non-stage layers →
per-stage FSDP) matching actual torchtitan training pipeline. MoE test
uses torchtitan's parallelize_llama() directly. Both tests verify
correctness against sequential baseline with atol=0, rtol=0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. test/test_muon.py +95 -102
  2. test/test_muon_moe.py +91 -135
test/test_muon.py CHANGED
@@ -12,8 +12,8 @@ from torch.distributed.tensor import (DTensor, Replicate, Shard,
12
  distribute_tensor)
13
  from torch.profiler import ProfilerActivity, profile
14
 
15
- from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
16
- parallelize_qk_logits)
17
 
18
  logger = logging.getLogger(__name__)
19
  logging.basicConfig(level=logging.INFO)
@@ -393,126 +393,119 @@ def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
393
  uneven_dim, rank)
394
 
395
 
396
- def test_pp_dp_replicate_no_deadlock(init_dist):
397
- """Regression: PP-like setup where different rank subsets call
398
- construct_shard_mesh for different parameters must not deadlock.
399
- Also verifies correctness (atol=0, rtol=0) against sequential baseline.
400
 
401
- Simulates PP=2 with dp_replicate=2, dp_shard=2. Each PP stage has
402
- 4 ranks with a (2,2) mesh and [Replicate, Shard(0)] placements
403
- (created via fully_shard, matching the real HSDP pattern).
404
- Stages create different numbers of layers, forcing
405
- construct_shard_mesh to be called independently per stage.
406
- Without use_local_synchronization=True in dist.new_group(),
407
- this would deadlock.
 
 
408
  """
 
 
 
409
  from optimizer.distributed.utils import _ranks_to_dist_cache
410
- from optimizer.newton_schulz import set_ns_compile
411
- from torch.distributed.fsdp import fully_shard
412
 
413
  rank = dist.get_rank()
414
- world_size = dist.get_world_size()
415
- assert world_size == 8
416
 
417
  set_ns_compile(False)
418
-
419
- # Clear cache to ensure dist.new_group is actually called
420
  _ranks_to_dist_cache.clear()
421
 
422
- # Create full mesh: PP=2, dp_replicate=2, dp_shard=2
 
 
 
 
 
 
 
 
423
  full_mesh = dist.init_device_mesh(
424
  "cuda",
425
  (2, 2, 2),
426
  mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
427
  )
428
-
429
- stage_mesh = full_mesh["dp_replicate", "dp_shard"]
430
  pp_rank = full_mesh.get_local_rank("pp")
431
 
432
- # Asymmetric layer counts per stage (mimics PP)
433
- num_layers = 3 if pp_rank == 0 else 5
434
- hidden = 64
435
-
436
- # Same seed per stage so all ranks in a stage get identical init weights
437
- torch.manual_seed(42 + pp_rank)
438
-
439
- # Create model and save initial state for sequential baseline
440
- model = torch.nn.Sequential(*[
441
- torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)
442
- ]).cuda()
443
-
444
- init_state = {n: p.data.clone() for n, p in model.named_parameters()}
445
- grads = {n: torch.randn_like(p) for n, p in model.named_parameters()}
446
-
447
- # Apply FSDP (creates proper DTensors with [Replicate, Shard(0)])
448
- for layer in model:
449
- fully_shard(layer, mesh=stage_mesh)
450
- fully_shard(model, mesh=stage_mesh)
451
- model.reshard()
452
-
453
- # Apply grads with proper DTensor redistribution
454
- for n, p in model.named_parameters():
455
- g = grads[n]
456
- if isinstance(p.data, DTensor):
457
- ug = DTensor.from_local(
458
- g,
459
- device_mesh=p.data.device_mesh,
460
- placements=[Replicate()] * p.data.device_mesh.ndim,
461
- )
462
- p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
463
- placements=p.data.placements)
464
  else:
465
- p.grad = g
466
-
467
- # Parallel Muon step — must not deadlock
468
- muon_names = [n for n, _ in model.named_parameters()]
469
- muon_params = [p for _, p in model.named_parameters()]
470
- param_groups = [{
471
- "params": muon_params,
472
- "names": muon_names,
473
- "use_muon": True,
474
- "lr": 0.02,
475
- "weight_decay": 0.01,
476
- "momentum": 0.95,
477
- "nesterov": True,
478
- "ns_steps": 5,
479
- "none_grad": False,
480
- }]
481
- optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
482
- optim.step()
483
-
484
- # Sequential baseline (base path, no sharding)
485
- torch.manual_seed(42 + pp_rank)
486
- model_seq = torch.nn.Sequential(*[
487
- torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)
488
- ]).cuda()
489
-
490
- for n, p in model_seq.named_parameters():
491
- p.grad = grads[n].clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
- seq_names = [n for n, _ in model_seq.named_parameters()]
494
- seq_params = [p for _, p in model_seq.named_parameters()]
495
- param_groups_seq = [{
496
- "params": seq_params,
497
- "names": seq_names,
498
- "use_muon": True,
499
- "lr": 0.02,
500
- "weight_decay": 0.01,
501
- "momentum": 0.95,
502
- "nesterov": True,
503
- "ns_steps": 5,
504
- "none_grad": False,
505
- }]
506
- optim_seq = Muon(params=param_groups_seq)
507
- optim_seq.step()
508
 
509
  # Correctness: parallel must match sequential exactly
510
- for (n_par, p_par), (n_seq, p_seq) in zip(model.named_parameters(),
511
- model_seq.named_parameters()):
512
- par_data = p_par.data
513
- if isinstance(par_data, DTensor):
514
- par_data = par_data.full_tensor()
515
- torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
516
 
517
  set_ns_compile(True)
518
  logger.info(
 
12
  distribute_tensor)
13
  from torch.profiler import ProfilerActivity, profile
14
 
15
+ from .utils import (ParallelDims, _apply_fsdp, assert_params_equal,
16
+ parallelize_motif, parallelize_qk_logits)
17
 
18
  logger = logging.getLogger(__name__)
19
  logging.basicConfig(level=logging.INFO)
 
393
  uneven_dim, rank)
394
 
395
 
396
+ def test_pp_dp_replicate_no_deadlock(init_dist, inputs):
397
+ """PP regression test using real Motif model.
 
 
398
 
399
+ PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the
400
+ Motif-2.6B-4layer model across 2 pipeline stages following the
401
+ torchtitan pattern (deep copy → delete non-stage layers → per-stage
402
+ FSDP). Each stage independently runs Muon optimizer and the result
403
+ is verified against a sequential baseline (atol=0, rtol=0).
404
+
405
+ Without use_local_synchronization=True in construct_shard_mesh(),
406
+ different stages would deadlock on dist.new_group() because they
407
+ call it for different parameters.
408
  """
409
+ import re
410
+
411
+ import torch.nn as nn
412
  from optimizer.distributed.utils import _ranks_to_dist_cache
 
 
413
 
414
  rank = dist.get_rank()
415
+ assert dist.get_world_size() == 8
 
416
 
417
  set_ns_compile(False)
 
 
418
  _ranks_to_dist_cache.clear()
419
 
420
+ model_orig, grads_orig, _ = inputs
421
+
422
+ # Build name→grad mapping from original model
423
+ grad_dict = {
424
+ name: grad
425
+ for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
426
+ }
427
+
428
+ # Full mesh: PP=2, dp_replicate=2, dp_shard=2
429
  full_mesh = dist.init_device_mesh(
430
  "cuda",
431
  (2, 2, 2),
432
  mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
433
  )
434
+ dp_mesh = full_mesh["dp_replicate", "dp_shard"]
 
435
  pp_rank = full_mesh.get_local_rank("pp")
436
 
437
+ # -- Helpers ----------------------------------------------------------
438
+ def _split_motif(model):
439
+ """Split Motif model per PP stage (torchtitan pattern).
440
+
441
+ Stage 0: embed_tokens + layers[0:2]
442
+ Stage 1: layers[2:4] + norm + output
443
+ Non-stage components replaced with nn.Identity (no params).
444
+ """
445
+ all_layers = list(model.model.layers)
446
+ if pp_rank == 0:
447
+ model.model.layers = nn.ModuleList(all_layers[:2])
448
+ model.model.norm = nn.Identity()
449
+ if hasattr(model, "output"):
450
+ model.output = nn.Identity()
451
+ if hasattr(model, "lm_head"):
452
+ model.lm_head = nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  else:
454
+ model.model.layers = nn.ModuleList(all_layers[2:])
455
+ model.model.embed_tokens = nn.Identity()
456
+ return model
457
+
458
+ layer_offset = 0 if pp_rank == 0 else 2
459
+
460
+ def _remap(name):
461
+ """Map stage param name → original param name (layer index offset).
462
+
463
+ Also handles weight tying: Motif ties lm_head.weight to
464
+ model.embed_tokens.weight, so named_parameters() only lists the
465
+ latter. After stage-split, stage 1 loses embed_tokens but keeps
466
+ lm_head, so we remap it back.
467
+ """
468
+ # Weight tying: lm_head.weight ↔ model.embed_tokens.weight
469
+ if name == "lm_head.weight":
470
+ return "model.embed_tokens.weight"
471
+
472
+ if layer_offset == 0:
473
+ return name
474
+
475
+ def _replace(m):
476
+ return f"layers.{int(m.group(1)) + layer_offset}."
477
+
478
+ return re.sub(r"layers\.(\d+)\.", _replace, name)
479
+
480
+ def _stage_grads(model):
481
+ """Build grads list aligned with stage model parameters."""
482
+ return [grad_dict[_remap(n)] for n, _ in model.named_parameters()]
483
+
484
+ # -- Parallel path: split → FSDP → Muon step -------------------------
485
+ par_model = _split_motif(copy.deepcopy(model_orig).cuda())
486
+ _apply_fsdp(par_model, dp_mesh)
487
+ par_model, _ = apply_muon_step(
488
+ model=par_model,
489
+ parallel_dims=None,
490
+ grads=_stage_grads(par_model),
491
+ warmup_step=5,
492
+ chunk_size=2,
493
+ qk_logits=None,
494
+ )
495
 
496
+ # -- Sequential baseline: split → no FSDP → base Muon ----------------
497
+ seq_model = _split_motif(copy.deepcopy(model_orig).cuda())
498
+ seq_model, _ = apply_muon_step(
499
+ model=seq_model,
500
+ parallel_dims=None,
501
+ grads=_stage_grads(seq_model),
502
+ warmup_step=-1,
503
+ chunk_size=-1,
504
+ qk_logits=None,
505
+ )
 
 
 
 
 
506
 
507
  # Correctness: parallel must match sequential exactly
508
+ assert_params_equal(par_model, seq_model, atol=0, rtol=0)
 
 
 
 
 
509
 
510
  set_ns_compile(True)
511
  logger.info(
test/test_muon_moe.py CHANGED
@@ -404,157 +404,113 @@ def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
404
  uneven_dim, rank)
405
 
406
 
407
- def test_pp_dp_replicate_moe_no_deadlock(init_dist):
408
- """Regression: PP-like MoE setup where different stages have different
409
- parameter types must not deadlock in construct_shard_mesh.
410
- Also verifies correctness (atol=0, rtol=0) against sequential baseline.
411
-
412
- Simulates PP=2 with dp_replicate=2, dp_shard=2. Stage 0 has only
413
- non-expert 2D FSDP-sharded params; stage 1 has 2D FSDP-sharded params
414
- plus 3D expert plain-tensor params. This mirrors real PP+MoE where
415
- expert layers exist only in certain stages.
 
 
 
 
416
  """
417
  from optimizer.distributed.utils import _ranks_to_dist_cache
418
  from optimizer.newton_schulz import set_ns_compile
419
- from torch.distributed.fsdp import fully_shard
 
 
420
 
421
  rank = dist.get_rank()
422
- world_size = dist.get_world_size()
423
- assert world_size == 8
424
 
425
  set_ns_compile(False)
426
-
427
- # Clear cache to ensure dist.new_group is actually called
428
  _ranks_to_dist_cache.clear()
429
 
430
- # Create full mesh: PP=2, dp_replicate=2, dp_shard=2
431
- full_mesh = dist.init_device_mesh(
432
- "cuda",
433
- (2, 2, 2),
434
- mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  )
436
 
437
- stage_mesh = full_mesh["dp_replicate", "dp_shard"]
438
- pp_rank = full_mesh.get_local_rank("pp")
439
-
440
- num_dense = 2 if pp_rank == 0 else 3
441
- num_experts = 4
442
- hidden = 64
443
-
444
- torch.manual_seed(42 + pp_rank)
445
-
446
- # Create model with dense layers (+ expert param for stage 1)
447
- model = torch.nn.Sequential(*[
448
- torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_dense)
449
- ]).cuda()
450
-
451
- # Save init state and grads for sequential baseline
452
- init_state = {n: p.data.clone() for n, p in model.named_parameters()}
453
- dense_grads = {n: torch.randn_like(p) for n, p in model.named_parameters()}
454
-
455
- # Expert param (stage 1 only, plain tensor — not FSDP-sharded)
456
- expert_data = None
457
- expert_grad = None
458
- if pp_rank == 1:
459
- expert_data = torch.randn(num_experts, hidden, hidden, device="cuda")
460
- expert_grad = torch.randn(num_experts, hidden, hidden, device="cuda")
461
-
462
- # Apply FSDP to dense layers
463
- for layer in model:
464
- fully_shard(layer, mesh=stage_mesh)
465
- fully_shard(model, mesh=stage_mesh)
466
- model.reshard()
467
-
468
- # Apply dense grads with DTensor redistribution
469
- for n, p in model.named_parameters():
470
- g = dense_grads[n]
471
- if isinstance(p.data, DTensor):
472
- ug = DTensor.from_local(
473
- g,
474
- device_mesh=p.data.device_mesh,
475
- placements=[Replicate()] * p.data.device_mesh.ndim,
476
- )
477
- p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
478
- placements=p.data.placements)
479
  else:
480
- p.grad = g
481
-
482
- # Build param groups: dense (FSDP DTensors) + expert (plain tensor)
483
- muon_names = [n for n, _ in model.named_parameters()]
484
- muon_params = list(model.parameters())
485
-
486
- if pp_rank == 1:
487
- expert_p = torch.nn.Parameter(expert_data.clone())
488
- expert_p.grad = expert_grad.clone()
489
- muon_params.append(expert_p)
490
- muon_names.append("experts.w1.weight")
491
-
492
- param_groups = [{
493
- "params": muon_params,
494
- "names": muon_names,
495
- "use_muon": True,
496
- "lr": 0.02,
497
- "weight_decay": 0.01,
498
- "momentum": 0.95,
499
- "nesterov": True,
500
- "ns_steps": 5,
501
- "none_grad": False,
502
- }]
503
-
504
- # Must not deadlock
505
- optim = Muon(params=param_groups,
506
- chunk_size=1,
507
- warmup_step=0,
508
- expert_keys=["experts"])
509
- optim.step()
510
-
511
- # Sequential baseline
512
- torch.manual_seed(42 + pp_rank)
513
- model_seq = torch.nn.Sequential(*[
514
- torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_dense)
515
- ]).cuda()
516
-
517
- seq_names = [n for n, _ in model_seq.named_parameters()]
518
- seq_params = list(model_seq.parameters())
519
-
520
- for n, p in model_seq.named_parameters():
521
- p.grad = dense_grads[n].clone()
522
 
523
- if pp_rank == 1:
524
- expert_p_seq = torch.nn.Parameter(expert_data.clone())
525
- expert_p_seq.grad = expert_grad.clone()
526
- seq_params.append(expert_p_seq)
527
- seq_names.append("experts.w1.weight")
528
 
529
- param_groups_seq = [{
530
- "params": seq_params,
531
- "names": seq_names,
532
- "use_muon": True,
533
- "lr": 0.02,
534
- "weight_decay": 0.01,
535
- "momentum": 0.95,
536
- "nesterov": True,
537
- "ns_steps": 5,
538
- "none_grad": False,
539
- }]
540
- optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
541
- optim_seq.step()
542
 
543
  # Correctness: parallel must match sequential exactly
544
- # Dense params
545
- for (n_par, p_par), (n_seq, p_seq) in zip(model.named_parameters(),
546
- model_seq.named_parameters()):
547
- par_data = p_par.data
548
- if isinstance(par_data, DTensor):
549
- par_data = par_data.full_tensor()
550
- torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
551
-
552
- # Expert params (stage 1 only)
553
- if pp_rank == 1:
554
- torch.testing.assert_close(muon_params[-1].data,
555
- seq_params[-1].data,
556
- atol=0,
557
- rtol=0)
558
 
559
  set_ns_compile(True)
560
  logger.info(
 
404
  uneven_dim, rank)
405
 
406
 
407
+ def test_pp_dp_replicate_moe_no_deadlock(init_dist, moe_inputs):
408
+ """PP regression test using real torchtitan Llama4 MoE model.
409
+
410
+ PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the Llama4 MoE
411
+ model (4 layers, 8 experts) across 2 pipeline stages following the
412
+ torchtitan pattern. Uses torchtitan's ``parallelize_llama`` for
413
+ realistic FSDP application (same function as real training).
414
+
415
+ Each stage independently runs Muon optimizer with expert_keys and
416
+ the result is verified against a sequential baseline (atol=0, rtol=0).
417
+
418
+ Without use_local_synchronization=True in construct_shard_mesh(),
419
+ different stages would deadlock on dist.new_group().
420
  """
421
  from optimizer.distributed.utils import _ranks_to_dist_cache
422
  from optimizer.newton_schulz import set_ns_compile
423
+ from torchtitan.config import JobConfig
424
+ from torchtitan.distributed import ParallelDims as TTParallelDims
425
+ from torchtitan.models.llama4.infra.parallelize import parallelize_llama
426
 
427
  rank = dist.get_rank()
428
+ assert dist.get_world_size() == 8
 
429
 
430
  set_ns_compile(False)
 
 
431
  _ranks_to_dist_cache.clear()
432
 
433
+ model_orig, grads_orig = moe_inputs
434
+
435
+ # Build name→grad mapping from original model
436
+ grad_dict = {
437
+ name: grad
438
+ for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
439
+ }
440
+
441
+ # torchtitan ParallelDims with PP=2 (same as real training config)
442
+ tt_dims = TTParallelDims(
443
+ dp_replicate=2,
444
+ dp_shard=2,
445
+ cp=1,
446
+ tp=1,
447
+ pp=2,
448
+ ep=1,
449
+ etp=1,
450
+ world_size=8,
451
  )
452
 
453
+ # Accessing world_mesh triggers build_mesh() (lazy init).
454
+ # All ranks participate in init_device_mesh (collective).
455
+ pp_rank = tt_dims.world_mesh.get_local_rank("pp")
456
+
457
+ job_config = JobConfig()
458
+ job_config.training.mixed_precision_param = "float32"
459
+ job_config.activation_checkpoint.mode = "none"
460
+ job_config.compile.enable = False
461
+ job_config.parallelism.disable_loss_parallel = True
462
+
463
+ # -- Helpers ----------------------------------------------------------
464
+ def _split_llama4(model):
465
+ """Split Llama4 MoE model per PP stage (torchtitan pattern).
466
+
467
+ Stage 0: tok_embeddings + layers["0"], ["1"]
468
+ Stage 1: layers["2"], ["3"] + norm + output
469
+ ModuleDict preserves keys → param names unchanged.
470
+ torchtitan model natively supports None modules in forward().
471
+ """
472
+ if pp_rank == 0:
473
+ for key in ["2", "3"]:
474
+ if key in model.layers:
475
+ del model.layers[key]
476
+ model.norm = None
477
+ model.output = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  else:
479
+ for key in ["0", "1"]:
480
+ if key in model.layers:
481
+ del model.layers[key]
482
+ model.tok_embeddings = None
483
+ return model
484
+
485
+ def _stage_grads(model):
486
+ """Build grads list aligned with stage model parameters."""
487
+ return [grad_dict[n] for n, _ in model.named_parameters()]
488
+
489
+ # -- Parallel path: split → parallelize_llama → Muon step -------------
490
+ par_model = _split_llama4(copy.deepcopy(model_orig).cuda())
491
+ parallelize_llama(par_model, tt_dims, job_config)
492
+
493
+ par_model, _ = apply_muon_step_moe(
494
+ model=par_model,
495
+ parallel_dims=None,
496
+ grads=_stage_grads(par_model),
497
+ warmup_step=5,
498
+ chunk_size=2,
499
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
+ # -- Sequential baseline: split → no parallelization → base Muon ------
502
+ seq_model = _split_llama4(copy.deepcopy(model_orig).cuda())
 
 
 
503
 
504
+ seq_model, _ = apply_muon_step_moe(
505
+ model=seq_model,
506
+ parallel_dims=None,
507
+ grads=_stage_grads(seq_model),
508
+ warmup_step=-1,
509
+ chunk_size=-1,
510
+ )
 
 
 
 
 
 
511
 
512
  # Correctness: parallel must match sequential exactly
513
+ assert_params_equal(par_model, seq_model, atol=0, rtol=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  set_ns_compile(True)
516
  logger.info(