| import torch |
| import megablocks |
| from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights |
|
|
|
|
def test_megablocks_moe_mlp_with_shared_expert_import():
    """A freshly built shared-expert MLP exposes its shared-weight slots and setter."""
    expected_members = (
        'shared_up_proj_weight',
        'shared_down_proj_weight',
        'set_shared_expert_weights',
    )
    layer = MegaBlocksMoeMLPWithSharedExpert()
    # Checked in the same order as the attribute listing above.
    for member in expected_members:
        assert hasattr(layer, member)
|
|
|
|
def test_set_shared_expert_weights():
    """set_shared_expert_weights stores every tensor and option it is given.

    Tensors are compared with ``torch.equal`` (same shape and values). The
    boolean flag and the activation callable are compared by identity,
    because the test passes the exact objects it expects to read back.
    """
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    # Shapes are (out_features, in_features): up projects hidden -> shared
    # hidden, down projects back. Presumably consumed like F.linear weights.
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    # Singletons and callables are checked by identity, not equality (PEP 8).
    assert mlp.shared_expert_weighted_sum is True
    assert mlp.shared_activation_fn is torch.nn.functional.gelu
|
|
|
|
def test_create_shared_expert_weights():
    """create_shared_expert_weights returns correctly shaped, placed and typed weights and no biases."""
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def xavier_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_w, down_w, up_b, down_b = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=xavier_init,
    )

    assert up_w.shape == (shared_expert_hidden_size, hidden_size)
    assert down_w.shape == (hidden_size, shared_expert_hidden_size)
    # Only the device *type* is compared, so 'cuda' matches 'cuda:0'.
    for weight in (up_w, down_w):
        assert weight.device.type == device.type
    for weight in (up_w, down_w):
        assert weight.dtype == dtype
    # No bias arguments were requested, so none should be created.
    assert up_b is None
    assert down_b is None
|
|
|
|
def test_shared_expert_weights_none_by_default():
    """Before set_shared_expert_weights is called, all shared-expert state is unset."""
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    # Compare against the False singleton by identity (PEP 8), not with ==.
    assert mlp.shared_expert_weighted_sum is False
    assert mlp.shared_activation_fn is None
|
|
|
|
def test_inheritance_from_megablocks_moe_mlp():
    """The shared-expert MLP is a MegaBlocksMoeMLP subclass and exposes forward."""
    from megablocks.layers import MegaBlocksMoeMLP as BaseMoeMLP

    shared_mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert isinstance(shared_mlp, BaseMoeMLP)
    assert hasattr(shared_mlp, 'forward')
|
|
|
|
def test_shared_expert_weights_custom_init():
    """Separate init callables are applied to the up and down projections.

    init_method fills the up projection and output_layer_init_method fills
    the down projection. Both are constant fills here, so the result can be
    checked exactly, even in float16.
    """
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init,
    )

    # The Python-float scalar is compared in the tensor's dtype (float16), so
    # 0.1 rounds identically on both sides and the equality is exact.
    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    # Passing a custom output init must not change shapes or introduce biases;
    # these checks mirror test_create_shared_expert_weights.
    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_bias is None
    assert down_proj_bias is None
|
|
|
|
def test_shared_expert_weights_dimensions():
    """Stored shared-expert weights compose over a (seq, batch, hidden) input.

    Applies the up and down projections stored on the module by hand (plain
    ``F.linear``, no bias or activation) and checks that the intermediate
    and final outputs have the expected shapes.
    """
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)

    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]

    # Use the weights as stored on the module, so the expected shapes are
    # actually verified against what the layer holds.
    up_out = torch.nn.functional.linear(x, mlp.shared_up_proj_weight)
    assert up_out.shape == expected_up_output_shape
    down_out = torch.nn.functional.linear(up_out, mlp.shared_down_proj_weight)
    assert down_out.shape == expected_down_output_shape