| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| import time |
|
|
| from triton_moe.layers import MoE |
|
|
|
|
| |
class OpenaiExperts(nn.Module):
    """Reference (pure-PyTorch) implementation of the gated MoE expert block.

    Each expert is a gated MLP: a fused gate/up projection, a sigmoid-gated
    activation ``gate * sigmoid(alpha * gate)`` scaled by ``(up + 1)``, and a
    down projection back to ``hidden_size``. Used in the tests below as the
    ground truth against the custom Triton ``MoE`` layer.
    """

    def __init__(self, config):
        super().__init__()
        # config is any object exposing these three attributes (the tests pass
        # a dynamically created class with class-level attributes).
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # Fused gate+up projection: (E, H, 2*D). Uninitialized (torch.empty);
        # tests overwrite these parameters with their own tensors.
        self.gate_up_proj = nn.Parameter(
            torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)
        )
        self.gate_up_proj_bias = nn.Parameter(
            torch.empty(self.num_experts, 2 * self.expert_dim)
        )
        # Down projection back to hidden size: (E, D, H).
        self.down_proj = nn.Parameter(
            torch.empty((self.num_experts, self.expert_dim, self.hidden_size))
        )
        self.down_proj_bias = nn.Parameter(
            torch.empty(self.num_experts, self.hidden_size)
        )
        # Sigmoid sharpness used in the gated activation.
        self.alpha = 1.702

    def forward(
        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
    ) -> torch.Tensor:
        """Compute the MoE expert outputs.

        When training it is more efficient to just loop over the experts and
        compute the output for each expert, as otherwise the memory would
        explode. For inference we can sacrifice some memory and compute the
        output for all experts at once, by repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k)
                expert indices selected per token; only used in training mode.
            routing_weights (torch.Tensor): (batch_size * token_num, top_k)
                per-token weights for the selected experts; only used in
                training mode.
        Returns:
            torch.Tensor: in training mode, the routed/weighted combination of
            expert outputs with the same shape as ``hidden_states``; in eval
            mode, the raw per-expert outputs flattened to
            (num_experts * token_num, hidden_size) — routing weights are NOT
            applied here, presumably the caller combines them (confirm against
            call sites).
        """
        if self.training:
            next_states = torch.zeros_like(
                hidden_states, dtype=hidden_states.dtype, device=hidden_states.device
            )
            with torch.no_grad():
                # one_hot: (tokens, top_k, E) -> permute to (E, top_k, tokens)
                # so each expert row tells which (slot, token) pairs hit it.
                expert_mask = torch.nn.functional.one_hot(
                    router_indices, num_classes=self.num_experts
                ).permute(2, 1, 0)
                # Indices of experts that received at least one token; skips
                # the matmul entirely for unused experts.
                expert_hitted = torch.greater(
                    expert_mask.sum(dim=(-1, -2)), 0
                ).nonzero()
            for expert_idx in expert_hitted:
                with torch.no_grad():
                    # idx: top-k slot index, top_x: token index for this expert.
                    idx, top_x = torch.where(
                        expert_mask[expert_idx][0]
                    )
                current_state = hidden_states[top_x]
                # expert_idx is a (1,)-shaped tensor, so indexed weights keep a
                # leading broadcast dim of 1; it is stripped by the [0] below.
                gate_up = (
                    current_state @ self.gate_up_proj[expert_idx]
                    + self.gate_up_proj_bias[expert_idx]
                )
                gate, up = gate_up.chunk(2, dim=-1)
                # Gated activation: gate * sigmoid(alpha * gate).
                glu = gate * torch.sigmoid(
                    gate * self.alpha
                )
                gated_output = (up + 1) * glu
                out = (
                    gated_output @ self.down_proj[expert_idx]
                    + self.down_proj_bias[expert_idx]
                )
                weighted_output = (
                    out * routing_weights[top_x, idx, None]
                )
                # Scatter-accumulate each token's weighted expert output back
                # into its row; tokens routed to multiple experts sum up.
                next_states.index_add_(
                    0, top_x, weighted_output.to(hidden_states.dtype)[0]
                )
        else:
            # Inference path: replicate all tokens for every expert and batch
            # the matmuls. Memory-heavy but a single pair of bmm calls.
            hidden_states = hidden_states.repeat(self.num_experts, 1)
            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
            gate_up = (
                torch.bmm(hidden_states, self.gate_up_proj)
                + self.gate_up_proj_bias[..., None, :]
            )
            gate, up = gate_up.chunk(2, dim=-1)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = (
                torch.bmm(((up + 1) * glu), self.down_proj)
                + self.down_proj_bias[..., None, :]
            )
            next_states = next_states.view(-1, self.hidden_size)
        return next_states
|
|
|
|
def test_moe_forward():
    """Forward-pass parity check: custom Triton MoE vs. the reference experts.

    Runs both implementations on identical inputs and shared weights under
    ``no_grad`` and asserts the outputs agree within a loose tolerance.
    """
    # Problem size.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    # Deterministic inputs.
    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Expert weights, shared by both layers so they compute over identical
    # parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    # Minimal stand-in config object carrying the attributes the reference
    # layer reads.
    cfg = type(
        "Config",
        (object,),
        dict(
            num_local_experts=num_experts,
            intermediate_size=expert_dim,
            hidden_size=hidden_size,
        ),
    )
    ref_layer = OpenaiExperts(config=cfg)

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    with torch.no_grad():
        old_output = ref_layer(hidden_states, router_idx, router_wt)
        output = layer(
            hidden_states,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )

    assert old_output.shape == output.shape, "Output shapes do not match"

    diff = (old_output - output).abs()
    print(f"Average difference: {diff.mean().item()}")
    print(f"Max difference: {diff.max().item()}")

    assert torch.allclose(
        old_output,
        output,
        rtol=1e-1,
        atol=1e-1,
    ), "Outputs do not match between the two implementations"
|
|
|
|
def test_moe_backward_grad():
    """Smoke-test the custom MoE backward pass.

    Runs one forward/backward through the Triton layer and asserts that every
    expert parameter received a gradient (i.e. the kernel is differentiable
    and actually consumes all of its parameters).
    """
    # Problem size.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    # Deterministic inputs.
    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Expert parameters the backward pass must populate.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()

    # Forward, scalar loss, backward through the custom kernel only.
    loss = layer(
        hidden_states,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    ).sum()
    loss.backward()

    print(f"gate_up_proj.grad exists: {gate_up_proj.grad is not None}")
    print(f"gate_up_proj_bias.grad exists: {gate_up_proj_bias.grad is not None}")
    print(f"down_proj.grad exists: {down_proj.grad is not None}")
    print(f"down_proj_bias.grad exists: {down_proj_bias.grad is not None}")

    assert (
        gate_up_proj.grad is not None
    ), "gate_up_proj gradient is None - custom MoE not using this parameter"
    assert (
        gate_up_proj_bias.grad is not None
    ), "gate_up_proj_bias gradient is None - custom MoE not using this parameter"
    assert down_proj.grad is not None, "down_proj gradient is None"
    assert down_proj_bias.grad is not None, "down_proj_bias gradient is None"
|
|
|
|
def test_moe_backward():
    """Backward test comparing gradients between MoE and OpenaiExperts.

    Each implementation gets its OWN input leaf tensor: previously both
    forward passes shared one ``hidden_states`` leaf, so the two backward
    passes accumulated into the same ``.grad`` and the input gradient could
    not be attributed to (or compared between) the implementations. With
    separate leaves we can also assert the input gradients match.
    The reference layer uses cloned parameters for the same reason.
    """
    # Problem size.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    # Deterministic inputs.
    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    # Independent leaf for the reference path so gradients do not mix.
    ref_hidden_states = hidden_states.clone().detach().requires_grad_(True)

    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Expert parameters for the custom layer.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    # Cloned copies so reference gradients accumulate separately.
    ref_layer.gate_up_proj = nn.Parameter(
        gate_up_proj.clone().detach().requires_grad_(True)
    )
    ref_layer.gate_up_proj_bias = nn.Parameter(
        gate_up_proj_bias.clone().detach().requires_grad_(True)
    )
    ref_layer.down_proj = nn.Parameter(down_proj.clone().detach().requires_grad_(True))
    ref_layer.down_proj_bias = nn.Parameter(
        down_proj_bias.clone().detach().requires_grad_(True)
    )

    # Forward passes (separate input leaves, identical values).
    ref_output = ref_layer(ref_hidden_states, router_idx, router_wt)
    output = layer(
        hidden_states,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    )

    # Scalar losses and backward. The two autograd graphs are independent, so
    # no retain_graph is needed.
    ref_loss = ref_output.sum()
    loss = output.sum()

    ref_loss.backward()
    loss.backward()

    print(f"gate_up_proj.grad exists: {gate_up_proj.grad is not None}")
    print(f"gate_up_proj_bias.grad exists: {gate_up_proj_bias.grad is not None}")
    print(f"down_proj.grad exists: {down_proj.grad is not None}")
    print(f"down_proj_bias.grad exists: {down_proj_bias.grad is not None}")
    print(f"hidden_states.grad exists: {hidden_states.grad is not None}")

    assert (
        gate_up_proj.grad is not None
    ), "gate_up_proj gradient is None - custom MoE not using this parameter"
    assert (
        gate_up_proj_bias.grad is not None
    ), "gate_up_proj_bias gradient is None - custom MoE not using this parameter"
    assert down_proj.grad is not None, "down_proj gradient is None"
    assert down_proj_bias.grad is not None, "down_proj_bias gradient is None"
    assert hidden_states.grad is not None, "hidden_states gradient is None"

    print("✓ Backward test passed - all parameters have gradients")

    print("10 elements from gate_up_proj gradients:")
    print(gate_up_proj.grad.flatten()[:10])
    print("10 elements from ref_layer.gate_up_proj gradients:")
    print(ref_layer.gate_up_proj.grad.flatten()[:10])

    # Gradient parity between the two implementations.
    assert torch.allclose(
        ref_layer.gate_up_proj.grad,
        gate_up_proj.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "gate_up_proj gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.gate_up_proj_bias.grad,
        gate_up_proj_bias.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "gate_up_proj_bias gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.down_proj.grad,
        down_proj.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "down_proj gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.down_proj_bias.grad,
        down_proj_bias.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "down_proj_bias gradients do not match between implementations"
    # Input gradients must match too (possible only with separate leaves).
    assert torch.allclose(
        ref_hidden_states.grad,
        hidden_states.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "hidden_states gradients do not match between implementations"
|
|
|
|
def test_moe_backward_benchmark():
    """Benchmark backward pass performance between MoE and OpenaiExperts.

    Times ``num_runs`` forward+backward iterations of each implementation
    (after ``num_warmup`` warm-up iterations) and prints mean/std/min/max and
    the resulting speedup. The previously duplicated timing and statistics
    code is factored into the private helpers below; measurement order and
    printed output are unchanged.
    """
    # Problem size and benchmark configuration.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_warmup = 5
    num_runs = 20

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Shared expert parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def _reset_param_grads():
        """Zero stale gradients so every measured iteration does the same work."""
        for param in [gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias]:
            if param.grad is not None:
                param.grad.zero_()

    def benchmark_ref_backward():
        """One forward+backward of the reference implementation."""
        hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
        _reset_param_grads()
        output = ref_layer(hidden_states_copy, router_idx, router_wt)
        loss = output.sum()
        loss.backward()
        return loss.item()

    def benchmark_custom_backward():
        """One forward+backward of the custom Triton implementation."""
        hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
        _reset_param_grads()
        output = layer(
            hidden_states_copy,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )
        loss = output.sum()
        loss.backward()
        return loss.item()

    def _time_runs(fn):
        """Run `fn` num_runs times, synchronizing around each run; returns ms."""
        times = []
        for i in range(num_runs):
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            times.append((time.perf_counter() - start_time) * 1000)
            if (i + 1) % 5 == 0:
                print(f"  Completed {i + 1}/{num_runs} runs")
        return times

    def _stats(times):
        """Return (mean, population std, min, max) of a list of timings."""
        mean = sum(times) / len(times)
        std = (sum((t - mean) ** 2 for t in times) / len(times)) ** 0.5
        return mean, std, min(times), max(times)

    print("Warming up...")
    for _ in range(num_warmup):
        benchmark_ref_backward()
        benchmark_custom_backward()

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    print(f"Benchmarking reference implementation ({num_runs} runs)...")
    ref_times = _time_runs(benchmark_ref_backward)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    print(f"Benchmarking custom implementation ({num_runs} runs)...")
    custom_times = _time_runs(benchmark_custom_backward)

    ref_mean, ref_std, ref_min, ref_max = _stats(ref_times)
    custom_mean, custom_std, custom_min, custom_max = _stats(custom_times)
    speedup = ref_mean / custom_mean

    print("\n" + "=" * 80)
    print("BACKWARD PASS BENCHMARK RESULTS")
    print("=" * 80)
    print("Configuration:")
    print(f"  - Experts: {num_experts}")
    print(f"  - Hidden size: {hidden_size}")
    print(f"  - Expert dim: {expert_dim}")
    print(f"  - Batch tokens: {batch_tokens}")
    print(f"  - Top-k: {topk}")
    print(f"  - Runs: {num_runs}")
    print()

    print("Reference Implementation (OpenaiExperts):")
    print(f"  - Mean: {ref_mean:.3f} ms")
    print(f"  - Std: {ref_std:.3f} ms")
    print(f"  - Min: {ref_min:.3f} ms")
    print(f"  - Max: {ref_max:.3f} ms")
    print()

    print("Custom Implementation (MoE):")
    print(f"  - Mean: {custom_mean:.3f} ms")
    print(f"  - Std: {custom_std:.3f} ms")
    print(f"  - Min: {custom_min:.3f} ms")
    print(f"  - Max: {custom_max:.3f} ms")
    print()

    print(f"Speedup: {speedup:.2f}x")

    if speedup > 1.0:
        print(f"✓ Custom implementation is {speedup:.2f}x faster")
    else:
        print(f"✗ Custom implementation is {1/speedup:.2f}x slower")

    print("=" * 80)

    print("\nDetailed timings (ms):")
    print(f"Reference: {ref_times}")
    print(f"Custom: {custom_times}")
|
|
|
|
def test_moe_backward_benchmark_memory():
    """Benchmark peak memory usage during a forward+backward pass.

    Measures ``torch.cuda.max_memory_allocated`` for each implementation. The
    reference run's tensors (output, loss, input copy) and the parameter
    ``.grad`` buffers are explicitly freed before the custom run is measured;
    otherwise those leftover allocations would count toward the custom run's
    peak and skew the comparison.
    """
    # Problem size.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Shared expert parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    # --- reference implementation ---
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
    output = ref_layer(hidden_states_copy, router_idx, router_wt)
    loss = output.sum()
    loss.backward()

    ref_memory = torch.cuda.max_memory_allocated() / 1024**3

    # Free everything the reference run allocated (activations, autograd
    # graph, input copy, and parameter gradients) so the custom run is
    # measured from a clean baseline.
    del output, loss, hidden_states_copy
    for param in [gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias]:
        param.grad = None

    # --- custom implementation ---
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
    output = layer(
        hidden_states_copy,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    )
    loss = output.sum()
    loss.backward()

    custom_memory = torch.cuda.max_memory_allocated() / 1024**3

    print("\n" + "=" * 60)
    print("MEMORY USAGE BENCHMARK")
    print("=" * 60)
    print(f"Reference implementation: {ref_memory:.3f} GB")
    print(f"Custom implementation: {custom_memory:.3f} GB")
    print(f"Memory ratio: {custom_memory/ref_memory:.3f}x")

    if custom_memory < ref_memory:
        print(f"✓ Custom uses {(1 - custom_memory/ref_memory)*100:.1f}% less memory")
    else:
        print(f"✗ Custom uses {(custom_memory/ref_memory - 1)*100:.1f}% more memory")
    print("=" * 60)
|
|
|
|
| |
|
|
|
|
def test_moe_forward_benchmark():
    """Benchmark forward pass performance between MoE and OpenaiExperts.

    Times ``num_runs`` ``no_grad`` forward passes of each implementation
    (after ``num_warmup`` warm-up iterations) and prints mean/std/min/max and
    the resulting speedup. The previously duplicated timing and statistics
    code is factored into the private helpers below; measurement order and
    printed output are unchanged.
    """
    # Problem size and benchmark configuration.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_warmup = 5
    num_runs = 50

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Shared expert parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def benchmark_ref_forward():
        """One no-grad forward of the reference implementation."""
        with torch.no_grad():
            return ref_layer(hidden_states, router_idx, router_wt)

    def benchmark_custom_forward():
        """One no-grad forward of the custom Triton implementation."""
        with torch.no_grad():
            return layer(
                hidden_states,
                router_idx,
                router_wt,
                alpha,
                gate_up_proj,
                gate_up_proj_bias,
                down_proj,
                down_proj_bias,
            )

    def _time_runs(fn):
        """Run `fn` num_runs times, synchronizing around each run; returns ms."""
        times = []
        for i in range(num_runs):
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            times.append((time.perf_counter() - start_time) * 1000)
            if (i + 1) % 10 == 0:
                print(f"  Completed {i + 1}/{num_runs} runs")
        return times

    def _stats(times):
        """Return (mean, population std, min, max) of a list of timings."""
        mean = sum(times) / len(times)
        std = (sum((t - mean) ** 2 for t in times) / len(times)) ** 0.5
        return mean, std, min(times), max(times)

    print("Warming up...")
    for _ in range(num_warmup):
        benchmark_ref_forward()
        benchmark_custom_forward()

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    print(f"Benchmarking reference implementation ({num_runs} runs)...")
    ref_times = _time_runs(benchmark_ref_forward)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    print(f"Benchmarking custom implementation ({num_runs} runs)...")
    custom_times = _time_runs(benchmark_custom_forward)

    ref_mean, ref_std, ref_min, ref_max = _stats(ref_times)
    custom_mean, custom_std, custom_min, custom_max = _stats(custom_times)
    speedup = ref_mean / custom_mean

    print("\n" + "=" * 80)
    print("FORWARD PASS BENCHMARK RESULTS")
    print("=" * 80)
    print("Configuration:")
    print(f"  - Experts: {num_experts}")
    print(f"  - Hidden size: {hidden_size}")
    print(f"  - Expert dim: {expert_dim}")
    print(f"  - Batch tokens: {batch_tokens}")
    print(f"  - Top-k: {topk}")
    print(f"  - Runs: {num_runs}")
    print()

    print("Reference Implementation (OpenaiExperts):")
    print(f"  - Mean: {ref_mean:.3f} ms")
    print(f"  - Std: {ref_std:.3f} ms")
    print(f"  - Min: {ref_min:.3f} ms")
    print(f"  - Max: {ref_max:.3f} ms")
    print()

    print("Custom Implementation (MoE):")
    print(f"  - Mean: {custom_mean:.3f} ms")
    print(f"  - Std: {custom_std:.3f} ms")
    print(f"  - Min: {custom_min:.3f} ms")
    print(f"  - Max: {custom_max:.3f} ms")
    print()

    print(f"Speedup: {speedup:.2f}x")

    if speedup > 1.0:
        print(f"✓ Custom implementation is {speedup:.2f}x faster")
    else:
        print(f"✗ Custom implementation is {1/speedup:.2f}x slower")

    print("=" * 80)

    print("\nDetailed timings (ms):")
    print(f"Reference: {ref_times}")
    print(f"Custom: {custom_times}")
|
|
|
|
def test_moe_forward_benchmark_memory():
    """Benchmark peak memory usage during a no-grad forward pass.

    Measures ``torch.cuda.max_memory_allocated`` for each implementation. The
    reference run's output tensor is explicitly freed before the custom run
    is measured; otherwise that leftover allocation would count toward the
    custom run's peak and skew the comparison.
    """
    # Problem size.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Shared expert parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    # --- reference implementation ---
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        output = ref_layer(hidden_states, router_idx, router_wt)

    ref_memory = torch.cuda.max_memory_allocated() / 1024**3

    # Free the reference output so the custom run is measured from a clean
    # baseline.
    del output

    # --- custom implementation ---
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        output = layer(
            hidden_states,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )

    custom_memory = torch.cuda.max_memory_allocated() / 1024**3

    print("\n" + "=" * 60)
    print("FORWARD MEMORY USAGE BENCHMARK")
    print("=" * 60)
    print(f"Reference implementation: {ref_memory:.3f} GB")
    print(f"Custom implementation: {custom_memory:.3f} GB")
    print(f"Memory ratio: {custom_memory/ref_memory:.3f}x")

    if custom_memory < ref_memory:
        print(f"✓ Custom uses {(1 - custom_memory/ref_memory)*100:.1f}% less memory")
    else:
        print(f"✗ Custom uses {(custom_memory/ref_memory - 1)*100:.1f}% more memory")
    print("=" * 60)
|
|
|
|
def test_moe_forward_benchmark_throughput():
    """Benchmark forward-pass throughput (tokens/second) of both MoE layers."""
    # Problem size and benchmark configuration.
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_runs = 100

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Shared expert parameters.
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )

    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def _run_ref():
        """One no-grad forward of the reference implementation."""
        with torch.no_grad():
            ref_layer(hidden_states, router_idx, router_wt)

    def _run_custom():
        """One no-grad forward of the custom Triton implementation."""
        with torch.no_grad():
            layer(
                hidden_states,
                router_idx,
                router_wt,
                alpha,
                gate_up_proj,
                gate_up_proj_bias,
                down_proj,
                down_proj_bias,
            )

    def _total_seconds(fn):
        """Wall-clock seconds for num_runs calls of `fn`, fully synchronized."""
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_runs):
            fn()
        torch.cuda.synchronize()
        return time.perf_counter() - start

    print("Warming up for throughput test...")
    for _ in range(10):
        _run_ref()
        _run_custom()

    ref_total_time = _total_seconds(_run_ref)
    custom_total_time = _total_seconds(_run_custom)

    total_tokens_processed = batch_tokens * num_runs
    ref_throughput = total_tokens_processed / ref_total_time
    custom_throughput = total_tokens_processed / custom_total_time

    print("\n" + "=" * 70)
    print("FORWARD THROUGHPUT BENCHMARK")
    print("=" * 70)
    print(
        f"Configuration: {batch_tokens} tokens/batch × {num_runs} runs = {total_tokens_processed:,} tokens"
    )
    print()
    print(f"Reference Implementation:")
    print(f"  - Total time: {ref_total_time:.3f} seconds")
    print(f"  - Throughput: {ref_throughput:,.0f} tokens/second")
    print()
    print(f"Custom Implementation:")
    print(f"  - Total time: {custom_total_time:.3f} seconds")
    print(f"  - Throughput: {custom_throughput:,.0f} tokens/second")
    print()
    print(f"Throughput improvement: {custom_throughput/ref_throughput:.2f}x")

    if custom_throughput > ref_throughput:
        print(
            f"✓ Custom processes {((custom_throughput/ref_throughput - 1)*100):.1f}% more tokens/second"
        )
    else:
        print(
            f"✗ Custom processes {((1 - custom_throughput/ref_throughput)*100):.1f}% fewer tokens/second"
        )
    print("=" * 70)
|
|