Replace toy PP tests with real-model-based pipeline tests [skip-build]
Use Motif-2.6B (dense) and torchtitan Llama4 MoE (MoE) models with
realistic PP model splitting (deep copy → delete non-stage layers →
per-stage FSDP) matching actual torchtitan training pipeline. MoE test
uses torchtitan's parallelize_llama() directly. Both tests verify
correctness against sequential baseline with atol=0, rtol=0.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- test/test_muon.py +95 -102
- test/test_muon_moe.py +91 -135
test/test_muon.py
CHANGED
|
@@ -12,8 +12,8 @@ from torch.distributed.tensor import (DTensor, Replicate, Shard,
|
|
| 12 |
distribute_tensor)
|
| 13 |
from torch.profiler import ProfilerActivity, profile
|
| 14 |
|
| 15 |
-
from .utils import (ParallelDims,
|
| 16 |
-
parallelize_qk_logits)
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -393,126 +393,119 @@ def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
|
|
| 393 |
uneven_dim, rank)
|
| 394 |
|
| 395 |
|
| 396 |
-
def test_pp_dp_replicate_no_deadlock(init_dist):
|
| 397 |
-
"""
|
| 398 |
-
construct_shard_mesh for different parameters must not deadlock.
|
| 399 |
-
Also verifies correctness (atol=0, rtol=0) against sequential baseline.
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
(
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
|
|
|
|
|
|
| 408 |
"""
|
|
|
|
|
|
|
|
|
|
| 409 |
from optimizer.distributed.utils import _ranks_to_dist_cache
|
| 410 |
-
from optimizer.newton_schulz import set_ns_compile
|
| 411 |
-
from torch.distributed.fsdp import fully_shard
|
| 412 |
|
| 413 |
rank = dist.get_rank()
|
| 414 |
-
|
| 415 |
-
assert world_size == 8
|
| 416 |
|
| 417 |
set_ns_compile(False)
|
| 418 |
-
|
| 419 |
-
# Clear cache to ensure dist.new_group is actually called
|
| 420 |
_ranks_to_dist_cache.clear()
|
| 421 |
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
full_mesh = dist.init_device_mesh(
|
| 424 |
"cuda",
|
| 425 |
(2, 2, 2),
|
| 426 |
mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
|
| 427 |
)
|
| 428 |
-
|
| 429 |
-
stage_mesh = full_mesh["dp_replicate", "dp_shard"]
|
| 430 |
pp_rank = full_mesh.get_local_rank("pp")
|
| 431 |
|
| 432 |
-
#
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
for layer in model:
|
| 449 |
-
fully_shard(layer, mesh=stage_mesh)
|
| 450 |
-
fully_shard(model, mesh=stage_mesh)
|
| 451 |
-
model.reshard()
|
| 452 |
-
|
| 453 |
-
# Apply grads with proper DTensor redistribution
|
| 454 |
-
for n, p in model.named_parameters():
|
| 455 |
-
g = grads[n]
|
| 456 |
-
if isinstance(p.data, DTensor):
|
| 457 |
-
ug = DTensor.from_local(
|
| 458 |
-
g,
|
| 459 |
-
device_mesh=p.data.device_mesh,
|
| 460 |
-
placements=[Replicate()] * p.data.device_mesh.ndim,
|
| 461 |
-
)
|
| 462 |
-
p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
|
| 463 |
-
placements=p.data.placements)
|
| 464 |
else:
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
"
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
"
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
"ns_steps": 5,
|
| 504 |
-
"none_grad": False,
|
| 505 |
-
}]
|
| 506 |
-
optim_seq = Muon(params=param_groups_seq)
|
| 507 |
-
optim_seq.step()
|
| 508 |
|
| 509 |
# Correctness: parallel must match sequential exactly
|
| 510 |
-
|
| 511 |
-
model_seq.named_parameters()):
|
| 512 |
-
par_data = p_par.data
|
| 513 |
-
if isinstance(par_data, DTensor):
|
| 514 |
-
par_data = par_data.full_tensor()
|
| 515 |
-
torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
|
| 516 |
|
| 517 |
set_ns_compile(True)
|
| 518 |
logger.info(
|
|
|
|
| 12 |
distribute_tensor)
|
| 13 |
from torch.profiler import ProfilerActivity, profile
|
| 14 |
|
| 15 |
+
from .utils import (ParallelDims, _apply_fsdp, assert_params_equal,
|
| 16 |
+
parallelize_motif, parallelize_qk_logits)
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 393 |
uneven_dim, rank)
|
| 394 |
|
| 395 |
|
| 396 |
+
def test_pp_dp_replicate_no_deadlock(init_dist, inputs):
|
| 397 |
+
"""PP regression test using real Motif model.
|
|
|
|
|
|
|
| 398 |
|
| 399 |
+
PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the
|
| 400 |
+
Motif-2.6B-4layer model across 2 pipeline stages following the
|
| 401 |
+
torchtitan pattern (deep copy → delete non-stage layers → per-stage
|
| 402 |
+
FSDP). Each stage independently runs Muon optimizer and the result
|
| 403 |
+
is verified against a sequential baseline (atol=0, rtol=0).
|
| 404 |
+
|
| 405 |
+
Without use_local_synchronization=True in construct_shard_mesh(),
|
| 406 |
+
different stages would deadlock on dist.new_group() because they
|
| 407 |
+
call it for different parameters.
|
| 408 |
"""
|
| 409 |
+
import re
|
| 410 |
+
|
| 411 |
+
import torch.nn as nn
|
| 412 |
from optimizer.distributed.utils import _ranks_to_dist_cache
|
|
|
|
|
|
|
| 413 |
|
| 414 |
rank = dist.get_rank()
|
| 415 |
+
assert dist.get_world_size() == 8
|
|
|
|
| 416 |
|
| 417 |
set_ns_compile(False)
|
|
|
|
|
|
|
| 418 |
_ranks_to_dist_cache.clear()
|
| 419 |
|
| 420 |
+
model_orig, grads_orig, _ = inputs
|
| 421 |
+
|
| 422 |
+
# Build name→grad mapping from original model
|
| 423 |
+
grad_dict = {
|
| 424 |
+
name: grad
|
| 425 |
+
for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
# Full mesh: PP=2, dp_replicate=2, dp_shard=2
|
| 429 |
full_mesh = dist.init_device_mesh(
|
| 430 |
"cuda",
|
| 431 |
(2, 2, 2),
|
| 432 |
mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
|
| 433 |
)
|
| 434 |
+
dp_mesh = full_mesh["dp_replicate", "dp_shard"]
|
|
|
|
| 435 |
pp_rank = full_mesh.get_local_rank("pp")
|
| 436 |
|
| 437 |
+
# -- Helpers ----------------------------------------------------------
|
| 438 |
+
def _split_motif(model):
|
| 439 |
+
"""Split Motif model per PP stage (torchtitan pattern).
|
| 440 |
+
|
| 441 |
+
Stage 0: embed_tokens + layers[0:2]
|
| 442 |
+
Stage 1: layers[2:4] + norm + output
|
| 443 |
+
Non-stage components replaced with nn.Identity (no params).
|
| 444 |
+
"""
|
| 445 |
+
all_layers = list(model.model.layers)
|
| 446 |
+
if pp_rank == 0:
|
| 447 |
+
model.model.layers = nn.ModuleList(all_layers[:2])
|
| 448 |
+
model.model.norm = nn.Identity()
|
| 449 |
+
if hasattr(model, "output"):
|
| 450 |
+
model.output = nn.Identity()
|
| 451 |
+
if hasattr(model, "lm_head"):
|
| 452 |
+
model.lm_head = nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
else:
|
| 454 |
+
model.model.layers = nn.ModuleList(all_layers[2:])
|
| 455 |
+
model.model.embed_tokens = nn.Identity()
|
| 456 |
+
return model
|
| 457 |
+
|
| 458 |
+
layer_offset = 0 if pp_rank == 0 else 2
|
| 459 |
+
|
| 460 |
+
def _remap(name):
|
| 461 |
+
"""Map stage param name → original param name (layer index offset).
|
| 462 |
+
|
| 463 |
+
Also handles weight tying: Motif ties lm_head.weight to
|
| 464 |
+
model.embed_tokens.weight, so named_parameters() only lists the
|
| 465 |
+
latter. After stage-split, stage 1 loses embed_tokens but keeps
|
| 466 |
+
lm_head, so we remap it back.
|
| 467 |
+
"""
|
| 468 |
+
# Weight tying: lm_head.weight ↔ model.embed_tokens.weight
|
| 469 |
+
if name == "lm_head.weight":
|
| 470 |
+
return "model.embed_tokens.weight"
|
| 471 |
+
|
| 472 |
+
if layer_offset == 0:
|
| 473 |
+
return name
|
| 474 |
+
|
| 475 |
+
def _replace(m):
|
| 476 |
+
return f"layers.{int(m.group(1)) + layer_offset}."
|
| 477 |
+
|
| 478 |
+
return re.sub(r"layers\.(\d+)\.", _replace, name)
|
| 479 |
+
|
| 480 |
+
def _stage_grads(model):
|
| 481 |
+
"""Build grads list aligned with stage model parameters."""
|
| 482 |
+
return [grad_dict[_remap(n)] for n, _ in model.named_parameters()]
|
| 483 |
+
|
| 484 |
+
# -- Parallel path: split → FSDP → Muon step -------------------------
|
| 485 |
+
par_model = _split_motif(copy.deepcopy(model_orig).cuda())
|
| 486 |
+
_apply_fsdp(par_model, dp_mesh)
|
| 487 |
+
par_model, _ = apply_muon_step(
|
| 488 |
+
model=par_model,
|
| 489 |
+
parallel_dims=None,
|
| 490 |
+
grads=_stage_grads(par_model),
|
| 491 |
+
warmup_step=5,
|
| 492 |
+
chunk_size=2,
|
| 493 |
+
qk_logits=None,
|
| 494 |
+
)
|
| 495 |
|
| 496 |
+
# -- Sequential baseline: split → no FSDP → base Muon ----------------
|
| 497 |
+
seq_model = _split_motif(copy.deepcopy(model_orig).cuda())
|
| 498 |
+
seq_model, _ = apply_muon_step(
|
| 499 |
+
model=seq_model,
|
| 500 |
+
parallel_dims=None,
|
| 501 |
+
grads=_stage_grads(seq_model),
|
| 502 |
+
warmup_step=-1,
|
| 503 |
+
chunk_size=-1,
|
| 504 |
+
qk_logits=None,
|
| 505 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
# Correctness: parallel must match sequential exactly
|
| 508 |
+
assert_params_equal(par_model, seq_model, atol=0, rtol=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
|
| 510 |
set_ns_compile(True)
|
| 511 |
logger.info(
|
test/test_muon_moe.py
CHANGED
|
@@ -404,157 +404,113 @@ def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
|
|
| 404 |
uneven_dim, rank)
|
| 405 |
|
| 406 |
|
| 407 |
-
def test_pp_dp_replicate_moe_no_deadlock(init_dist):
|
| 408 |
-
"""
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
"""
|
| 417 |
from optimizer.distributed.utils import _ranks_to_dist_cache
|
| 418 |
from optimizer.newton_schulz import set_ns_compile
|
| 419 |
-
from
|
|
|
|
|
|
|
| 420 |
|
| 421 |
rank = dist.get_rank()
|
| 422 |
-
|
| 423 |
-
assert world_size == 8
|
| 424 |
|
| 425 |
set_ns_compile(False)
|
| 426 |
-
|
| 427 |
-
# Clear cache to ensure dist.new_group is actually called
|
| 428 |
_ranks_to_dist_cache.clear()
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
)
|
| 436 |
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
# Apply FSDP to dense layers
|
| 463 |
-
for layer in model:
|
| 464 |
-
fully_shard(layer, mesh=stage_mesh)
|
| 465 |
-
fully_shard(model, mesh=stage_mesh)
|
| 466 |
-
model.reshard()
|
| 467 |
-
|
| 468 |
-
# Apply dense grads with DTensor redistribution
|
| 469 |
-
for n, p in model.named_parameters():
|
| 470 |
-
g = dense_grads[n]
|
| 471 |
-
if isinstance(p.data, DTensor):
|
| 472 |
-
ug = DTensor.from_local(
|
| 473 |
-
g,
|
| 474 |
-
device_mesh=p.data.device_mesh,
|
| 475 |
-
placements=[Replicate()] * p.data.device_mesh.ndim,
|
| 476 |
-
)
|
| 477 |
-
p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
|
| 478 |
-
placements=p.data.placements)
|
| 479 |
else:
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
"none_grad": False,
|
| 502 |
-
}]
|
| 503 |
-
|
| 504 |
-
# Must not deadlock
|
| 505 |
-
optim = Muon(params=param_groups,
|
| 506 |
-
chunk_size=1,
|
| 507 |
-
warmup_step=0,
|
| 508 |
-
expert_keys=["experts"])
|
| 509 |
-
optim.step()
|
| 510 |
-
|
| 511 |
-
# Sequential baseline
|
| 512 |
-
torch.manual_seed(42 + pp_rank)
|
| 513 |
-
model_seq = torch.nn.Sequential(*[
|
| 514 |
-
torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_dense)
|
| 515 |
-
]).cuda()
|
| 516 |
-
|
| 517 |
-
seq_names = [n for n, _ in model_seq.named_parameters()]
|
| 518 |
-
seq_params = list(model_seq.parameters())
|
| 519 |
-
|
| 520 |
-
for n, p in model_seq.named_parameters():
|
| 521 |
-
p.grad = dense_grads[n].clone()
|
| 522 |
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
expert_p_seq.grad = expert_grad.clone()
|
| 526 |
-
seq_params.append(expert_p_seq)
|
| 527 |
-
seq_names.append("experts.w1.weight")
|
| 528 |
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
"nesterov": True,
|
| 537 |
-
"ns_steps": 5,
|
| 538 |
-
"none_grad": False,
|
| 539 |
-
}]
|
| 540 |
-
optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
|
| 541 |
-
optim_seq.step()
|
| 542 |
|
| 543 |
# Correctness: parallel must match sequential exactly
|
| 544 |
-
|
| 545 |
-
for (n_par, p_par), (n_seq, p_seq) in zip(model.named_parameters(),
|
| 546 |
-
model_seq.named_parameters()):
|
| 547 |
-
par_data = p_par.data
|
| 548 |
-
if isinstance(par_data, DTensor):
|
| 549 |
-
par_data = par_data.full_tensor()
|
| 550 |
-
torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
|
| 551 |
-
|
| 552 |
-
# Expert params (stage 1 only)
|
| 553 |
-
if pp_rank == 1:
|
| 554 |
-
torch.testing.assert_close(muon_params[-1].data,
|
| 555 |
-
seq_params[-1].data,
|
| 556 |
-
atol=0,
|
| 557 |
-
rtol=0)
|
| 558 |
|
| 559 |
set_ns_compile(True)
|
| 560 |
logger.info(
|
|
|
|
| 404 |
uneven_dim, rank)
|
| 405 |
|
| 406 |
|
| 407 |
+
def test_pp_dp_replicate_moe_no_deadlock(init_dist, moe_inputs):
|
| 408 |
+
"""PP regression test using real torchtitan Llama4 MoE model.
|
| 409 |
+
|
| 410 |
+
PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the Llama4 MoE
|
| 411 |
+
model (4 layers, 8 experts) across 2 pipeline stages following the
|
| 412 |
+
torchtitan pattern. Uses torchtitan's ``parallelize_llama`` for
|
| 413 |
+
realistic FSDP application (same function as real training).
|
| 414 |
+
|
| 415 |
+
Each stage independently runs Muon optimizer with expert_keys and
|
| 416 |
+
the result is verified against a sequential baseline (atol=0, rtol=0).
|
| 417 |
+
|
| 418 |
+
Without use_local_synchronization=True in construct_shard_mesh(),
|
| 419 |
+
different stages would deadlock on dist.new_group().
|
| 420 |
"""
|
| 421 |
from optimizer.distributed.utils import _ranks_to_dist_cache
|
| 422 |
from optimizer.newton_schulz import set_ns_compile
|
| 423 |
+
from torchtitan.config import JobConfig
|
| 424 |
+
from torchtitan.distributed import ParallelDims as TTParallelDims
|
| 425 |
+
from torchtitan.models.llama4.infra.parallelize import parallelize_llama
|
| 426 |
|
| 427 |
rank = dist.get_rank()
|
| 428 |
+
assert dist.get_world_size() == 8
|
|
|
|
| 429 |
|
| 430 |
set_ns_compile(False)
|
|
|
|
|
|
|
| 431 |
_ranks_to_dist_cache.clear()
|
| 432 |
|
| 433 |
+
model_orig, grads_orig = moe_inputs
|
| 434 |
+
|
| 435 |
+
# Build name→grad mapping from original model
|
| 436 |
+
grad_dict = {
|
| 437 |
+
name: grad
|
| 438 |
+
for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
# torchtitan ParallelDims with PP=2 (same as real training config)
|
| 442 |
+
tt_dims = TTParallelDims(
|
| 443 |
+
dp_replicate=2,
|
| 444 |
+
dp_shard=2,
|
| 445 |
+
cp=1,
|
| 446 |
+
tp=1,
|
| 447 |
+
pp=2,
|
| 448 |
+
ep=1,
|
| 449 |
+
etp=1,
|
| 450 |
+
world_size=8,
|
| 451 |
)
|
| 452 |
|
| 453 |
+
# Accessing world_mesh triggers build_mesh() (lazy init).
|
| 454 |
+
# All ranks participate in init_device_mesh (collective).
|
| 455 |
+
pp_rank = tt_dims.world_mesh.get_local_rank("pp")
|
| 456 |
+
|
| 457 |
+
job_config = JobConfig()
|
| 458 |
+
job_config.training.mixed_precision_param = "float32"
|
| 459 |
+
job_config.activation_checkpoint.mode = "none"
|
| 460 |
+
job_config.compile.enable = False
|
| 461 |
+
job_config.parallelism.disable_loss_parallel = True
|
| 462 |
+
|
| 463 |
+
# -- Helpers ----------------------------------------------------------
|
| 464 |
+
def _split_llama4(model):
|
| 465 |
+
"""Split Llama4 MoE model per PP stage (torchtitan pattern).
|
| 466 |
+
|
| 467 |
+
Stage 0: tok_embeddings + layers["0"], ["1"]
|
| 468 |
+
Stage 1: layers["2"], ["3"] + norm + output
|
| 469 |
+
ModuleDict preserves keys → param names unchanged.
|
| 470 |
+
torchtitan model natively supports None modules in forward().
|
| 471 |
+
"""
|
| 472 |
+
if pp_rank == 0:
|
| 473 |
+
for key in ["2", "3"]:
|
| 474 |
+
if key in model.layers:
|
| 475 |
+
del model.layers[key]
|
| 476 |
+
model.norm = None
|
| 477 |
+
model.output = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
else:
|
| 479 |
+
for key in ["0", "1"]:
|
| 480 |
+
if key in model.layers:
|
| 481 |
+
del model.layers[key]
|
| 482 |
+
model.tok_embeddings = None
|
| 483 |
+
return model
|
| 484 |
+
|
| 485 |
+
def _stage_grads(model):
|
| 486 |
+
"""Build grads list aligned with stage model parameters."""
|
| 487 |
+
return [grad_dict[n] for n, _ in model.named_parameters()]
|
| 488 |
+
|
| 489 |
+
# -- Parallel path: split → parallelize_llama → Muon step -------------
|
| 490 |
+
par_model = _split_llama4(copy.deepcopy(model_orig).cuda())
|
| 491 |
+
parallelize_llama(par_model, tt_dims, job_config)
|
| 492 |
+
|
| 493 |
+
par_model, _ = apply_muon_step_moe(
|
| 494 |
+
model=par_model,
|
| 495 |
+
parallel_dims=None,
|
| 496 |
+
grads=_stage_grads(par_model),
|
| 497 |
+
warmup_step=5,
|
| 498 |
+
chunk_size=2,
|
| 499 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
+
# -- Sequential baseline: split → no parallelization → base Muon ------
|
| 502 |
+
seq_model = _split_llama4(copy.deepcopy(model_orig).cuda())
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
+
seq_model, _ = apply_muon_step_moe(
|
| 505 |
+
model=seq_model,
|
| 506 |
+
parallel_dims=None,
|
| 507 |
+
grads=_stage_grads(seq_model),
|
| 508 |
+
warmup_step=-1,
|
| 509 |
+
chunk_size=-1,
|
| 510 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
|
| 512 |
# Correctness: parallel must match sequential exactly
|
| 513 |
+
assert_params_equal(par_model, seq_model, atol=0, rtol=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
set_ns_compile(True)
|
| 516 |
logger.info(
|