| import logging |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from packaging import version |
| from transformers import AutoModelForCausalLM |
|
|
# Module-level logger for this conftest; basicConfig makes INFO-level
# messages visible when tests run under torchrun/pytest.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)




# Raise dynamo's recompile limit so the many parameter shapes exercised by
# these tests do not trip the default recompilation cap.
torch._dynamo.config.recompile_limit = 64


# Fixed RNG seed so model weights / random gradients are reproducible
# across ranks and across runs.
SEED = 0xdeadbeef
|
|
|
|
def pytest_addoption(parser):
    """Register the suite's custom command-line flags on the pytest parser.

    All three options are boolean ``store_true`` flags defaulting to False.
    """
    flag_specs = (
        ("--measure-perf",
         "Measure execution time and peak memory usage during optimizer step."),
        ("--do-profile",
         "Enable profiling during tests."),
        ("--skip-verify",
         "Skip verification of optimizer step correctness with sequential implementation.\n"
         "This can be useful when GPU memory is limited."),
    )
    for flag, help_text in flag_specs:
        parser.addoption(
            flag,
            action="store_true",
            default=False,
            help=help_text,
        )
|
|
|
|
def pytest_configure(config):
    """Validate flag combinations after option parsing.

    Profiling is only meaningful while perf is being measured, so
    ``--do-profile`` without ``--measure-perf`` is rejected up front.
    """
    profiling_requested = config.getoption("--do-profile")
    measuring_perf = config.getoption("--measure-perf")
    if profiling_requested and not measuring_perf:
        raise pytest.UsageError(
            "--do-profile requires --measure-perf. Please enable both flags.")
|
|
|
|
@pytest.fixture(scope="session")
def measure_perf(request):
    """Session-scoped flag: True when ``--measure-perf`` was passed."""
    flag = "--measure-perf"
    return request.config.getoption(flag)
|
|
|
|
@pytest.fixture(scope="session")
def do_profile(request):
    """Session-scoped flag: True when ``--do-profile`` was passed."""
    flag = "--do-profile"
    return request.config.getoption(flag)
|
|
|
|
@pytest.fixture(scope="session")
def skip_verify(request):
    """Session-scoped flag: True when ``--skip-verify`` was passed."""
    flag = "--skip-verify"
    return request.config.getoption(flag)
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Initialize torch.distributed (NCCL) once per session; tear down after.

    Skips the entire session when torch is older than 2.8, when the process
    group cannot be initialized, or when the world size is not exactly 8.

    Fixes vs. previous version:
    - removed the unreachable ``return`` after ``pytest.skip`` (skip raises
      ``Skipped``, so control never reaches the ``return``);
    - added the missing space between sentences in the world-size skip
      message (previously rendered as "...py`.To run ...").
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for parallel muon")

    try:
        dist.init_process_group(backend="nccl")
        # Pin each rank to a local GPU (round-robin over visible devices).
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() != 8:
        pytest.skip("Need 8 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=8 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`. "
                    "To run with less than 8 gpus, modify "
                    "the test cases accordingly.")

    yield
    # Session teardown: release the NCCL process group.
    dist.destroy_process_group()
|
|
|
|
@pytest.fixture(scope="session")
def inputs():
    """Load the Motif-2.6B test model and generate random test inputs.

    Returns:
        list: A 3-element list (note: a list, not a tuple) of
            - torch.nn.Module: The Motif-2.6B model (4-layer random weights).
            - list[torch.Tensor]: A random gradient per model parameter,
              matching each parameter's shape, device, and dtype.
            - dict[int, torch.Tensor]: Layer index -> random QK logits,
              one bfloat16 value per attention head.
    """
    model_name = "Motif-Technologies/Motif-2.6B-4layer-random"


    # Seed CPU (and all GPUs, if present) RNGs so the generated gradients
    # and QK logits are reproducible across runs and ranks.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)


    # trust_remote_code is required: the Motif model ships custom modeling code.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    logger.info(
        f"Loaded model {model_name}. ({len(list(model.parameters()))} parameters)"
    )


    # One random gradient tensor per parameter, same shape/device/dtype.
    grads: list[torch.Tensor] = []
    for param in model.parameters():
        grad = torch.randn_like(param, device=param.device, dtype=param.dtype)
        grads.append(grad)


    # Random per-head QK logits for every hidden layer.
    qk_logits: dict[int, torch.Tensor] = {
        i:
        torch.randn(model.config.num_attention_heads,
                    device=model.device,
                    dtype=torch.bfloat16)
        for i in range(model.config.num_hidden_layers)
    }


    return [model, grads, qk_logits]
|
|
|
|
def _create_moe_model(num_experts=8, top_k=2, n_layers=4):
    """Build a torchtitan Llama4 MoE Transformer plus random gradients.

    Args:
        num_experts: Number of routed experts per MoE layer.
        top_k: Number of experts selected per token.
        n_layers: Number of transformer layers.

    Returns:
        list: ``[model, grads]`` where ``grads`` holds one random tensor
        per model parameter (same shape/device/dtype).
    """
    from torchtitan.models.llama4.model.args import TransformerModelArgs
    from torchtitan.models.llama4.model.model import Transformer
    from torchtitan.models.moe import MoEArgs


    # Deterministic weight init and gradient sampling.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)


    model_args = TransformerModelArgs(
        dim=2048,
        n_layers=n_layers,
        n_heads=16,
        n_kv_heads=8,
        vocab_size=32000,
        norm_eps=1e-5,
        rope_theta=10000,
        max_seq_len=4096,
        moe_args=MoEArgs(
            num_experts=num_experts,
            num_shared_experts=1,
            top_k=top_k,
            score_func="sigmoid",
        ),
        interleave_moe_layer_step=1,
    )
    model = Transformer(model_args)
    model.init_weights()
    logger.info(f"Created torchtitan Llama4 MoE model "
                f"(num_experts={num_experts}, n_layers={n_layers}, "
                f"{len(list(model.parameters()))} parameters)")


    # One random gradient per parameter, matching shape/device/dtype.
    grads = []
    for p in model.parameters():
        grads.append(torch.randn_like(p, device=p.device, dtype=p.dtype))


    return [model, grads]
|
|
|
|
@pytest.fixture(scope="session")
def moe_inputs():
    """Session fixture: MoE model with 8 experts (standard config)."""
    n_experts, k = 8, 2
    return _create_moe_model(num_experts=n_experts, top_k=k)
|
|
|
|
@pytest.fixture(scope="session")
def moe_inputs_few_experts():
    """Session fixture: MoE model with 2 experts (triggers EFSDP Shard(1) mode)."""
    n_experts, k = 2, 1
    return _create_moe_model(num_experts=n_experts, top_k=k)
|
|