import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
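
# Tests for shared-expert support in MegaBlocksMoeMLPWithSharedExpert: weight
# creation via create_shared_expert_weights, single-process setup, and
# multi-process setup using torch.multiprocessing with a gloo process group.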


def run_distributed_shared_expert_test(rank, world_size):
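    """Worker launched by mp.spawn: set up shared-expert weights on one rank and verify them."""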
| os.environ["MASTER_ADDR"] = "localhost" |
| os.environ["MASTER_PORT"] = "12356" |
| os.environ["RANK"] = str(rank) |
| os.environ["WORLD_SIZE"] = str(world_size) |
|
|
| dist.init_process_group( |
| backend="gloo", |
| rank=rank, |
| world_size=world_size, |
| ) |
|
|
| model = MegaBlocksMoeMLPWithSharedExpert() |
|
|
| hidden_size = 128 |
| shared_expert_hidden_size = 192 |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| def simple_init(tensor): |
| torch.nn.init.xavier_uniform_(tensor) |
|
|
| shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights( |
| hidden_size=hidden_size, |
| shared_expert_hidden_size=shared_expert_hidden_size, |
| device=torch.device(device), |
| dtype=torch.float32, |
| init_method=simple_init |
| ) |
|
|
| model.set_shared_expert_weights( |
| up_proj_weight=shared_up_proj_weight, |
| down_proj_weight=shared_down_proj_weight, |
| up_proj_bias=shared_up_proj_bias, |
| down_proj_bias=shared_down_proj_bias, |
| weighted_sum=True, |
| activation_fn=torch.nn.functional.gelu |
| ) |
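
    # For orientation: with these weights the shared expert is presumably expected
    # to compute something like
    #   shared_out = gelu(x @ shared_up_proj_weight.T + up_bias) @ shared_down_proj_weight.T + down_bias
    # and combine it with the routed-expert output (blended when weighted_sum=True).
    # This is a sketch for readability, not necessarily the library's exact kernel.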

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum is True, f"Weighted sum not set correctly on rank {rank}"

    print(f"Rank {rank}: Shared expert setup test passed!")

    dist.destroy_process_group()


def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
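    """Worker launched by mp.spawn: verify weighted_sum=False and a custom activation on one rank."""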
| os.environ["MASTER_ADDR"] = "localhost" |
| os.environ["MASTER_PORT"] = "12357" |
| os.environ["RANK"] = str(rank) |
| os.environ["WORLD_SIZE"] = str(world_size) |
|
|
| dist.init_process_group( |
| backend="gloo", |
| rank=rank, |
| world_size=world_size, |
| ) |
|
|
| model = MegaBlocksMoeMLPWithSharedExpert() |
|
|
| hidden_size = 64 |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| def simple_init(tensor): |
| torch.nn.init.xavier_uniform_(tensor) |
|
|
| shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights( |
| hidden_size=hidden_size, |
| shared_expert_hidden_size=96, |
| device=torch.device(device), |
| dtype=torch.float32, |
| init_method=simple_init |
| ) |
|
|
| model.set_shared_expert_weights( |
| up_proj_weight=shared_up_proj_weight, |
| down_proj_weight=shared_down_proj_weight, |
| weighted_sum=False, |
| activation_fn=torch.nn.functional.relu |
| ) |
|
|
| assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}" |
| assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}" |
| assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}" |
| assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}" |
| |
| print(f"Rank {rank}: Weighted sum setup test passed!") |
|
|
| dist.destroy_process_group() |
|
|
|
|
| @pytest.mark.parametrize("world_size", [1, 2, 4, 8]) |
| def test_shared_expert_distributed_functionality(world_size): |
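    """Exercise shared-expert setup directly when world_size == 1, otherwise spawn one process per rank."""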
    if world_size == 1:
        # Single-process path: no process group is needed to exercise the setup.
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 128
        shared_expert_hidden_size = 192
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=shared_expert_hidden_size,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            weighted_sum=True,
            activation_fn=torch.nn.functional.gelu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum is True, "Weighted sum not set correctly"

        print("Single process shared expert setup test passed!")
    else:
        # Multi-process path: spawn one worker per rank and let each verify its own setup.
        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert test completed successfully!")


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_weighted_sum(world_size):
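    """Exercise the weighted_sum=False setup directly when world_size == 1, otherwise spawn one process per rank."""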
    if world_size == 1:
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 64
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=96,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            weighted_sum=False,
            activation_fn=torch.nn.functional.relu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum is False, "Weighted sum not set correctly"
        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"

        print("Single process weighted sum setup test passed!")
    else:
        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert weighted sum test completed successfully!")


def test_shared_expert_single_process():
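    """A freshly constructed model should have no shared-expert weights but expose the setter."""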
    model = MegaBlocksMoeMLPWithSharedExpert()

    assert model.shared_up_proj_weight is None
    assert model.shared_down_proj_weight is None
    assert hasattr(model, 'set_shared_expert_weights')

    print("Single process shared expert basic test passed!")


if __name__ == "__main__":
    test_shared_expert_single_process()
    print("Single process test passed!")

    # The distributed tests are parametrized under pytest; when running this
    # file as a script, pass world_size explicitly (each spawned worker sets
    # its own WORLD_SIZE environment variable).
    test_shared_expert_distributed_functionality(world_size=2)
    print("Distributed functionality test passed!")

    test_shared_expert_distributed_weighted_sum(world_size=2)
    print("Distributed weighted sum test passed!")