# triton-moe / tests / test_triton_moe.py
# Author: drbh — "feat: fix backward op" (commit 399fb8c)
import torch
import torch.nn.functional as F
import torch.nn as nn
import time
from triton_moe.layers import MoE
# Copied from transformers.models.openai.modeling_openai
class OpenaiExperts(nn.Module):
    """Eager-PyTorch reference implementation of the fused expert MLPs.

    Copied from transformers; used as the ground truth that the custom
    Triton ``MoE`` layer is compared against in the tests below.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # Fused gate+up projection: a single matmul per expert produces
        # 2 * expert_dim features, later split in half by chunk(2, dim=-1).
        self.gate_up_proj = nn.Parameter(
            torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)
        )
        self.gate_up_proj_bias = nn.Parameter(
            torch.empty(self.num_experts, 2 * self.expert_dim)
        )
        self.down_proj = nn.Parameter(
            torch.empty((self.num_experts, self.expert_dim, self.hidden_size))
        )
        self.down_proj_bias = nn.Parameter(
            torch.empty(self.num_experts, self.hidden_size)
        )
        # Sigmoid scaling constant used in the SiLU-like gate below.
        self.alpha = 1.702

    def forward(
        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
    ) -> torch.Tensor:
        """
        When training it is more efficient to just loop over the experts and
        compute the output for each expert, as otherwise the memory would explode.
        For inference we can sacrifice some memory and compute the output for all
        experts at once, by repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k)
                expert ids selected for each token
            routing_weights (torch.Tensor): (batch_size * token_num, top_k)
                weight applied to each selected expert's output

        Returns:
            torch.Tensor: (batch_size * token_num, hidden_size) in training
            mode; (num_experts * batch_size * token_num, hidden_size) in the
            inference branch (all experts' outputs, unweighted).
        """
        if self.training:
            # Accumulate weighted expert outputs per token.
            next_states = torch.zeros_like(
                hidden_states, dtype=hidden_states.dtype, device=hidden_states.device
            )
            with torch.no_grad():
                # expert_mask[e, k, t] == 1 iff token t routed to expert e at slot k.
                expert_mask = torch.nn.functional.one_hot(
                    router_indices, num_classes=self.num_experts
                ).permute(2, 1, 0)
                # Only iterate over experts that received at least one token.
                expert_hitted = torch.greater(
                    expert_mask.sum(dim=(-1, -2)), 0
                ).nonzero()
            for expert_idx in expert_hitted:
                with torch.no_grad():
                    idx, top_x = torch.where(
                        expert_mask[expert_idx][0]
                    )  # idx: top-1/top-2 indicator, top_x: token indices
                current_state = hidden_states[top_x]  # (num_tokens, hidden_dim)
                # expert_idx is a 1-element tensor (from nonzero()), so this
                # matmul broadcasts to a (1, num_tokens, ...) result; the [0]
                # at the index_add_ below drops that leading dim again.
                gate_up = (
                    current_state @ self.gate_up_proj[expert_idx]
                    + self.gate_up_proj_bias[expert_idx]
                )  # (num_tokens, 2 * interm_dim)
                gate, up = gate_up.chunk(2, dim=-1)  # (num_tokens, interm_dim)
                # SiLU-style gate with sigmoid sharpened by alpha.
                glu = gate * torch.sigmoid(
                    gate * self.alpha
                )  # (num_tokens, interm_dim)
                gated_output = (up + 1) * glu  # (num_tokens, interm_dim)
                out = (
                    gated_output @ self.down_proj[expert_idx]
                    + self.down_proj_bias[expert_idx]
                )  # (num_tokens, hidden_dim)
                weighted_output = (
                    out * routing_weights[top_x, idx, None]
                )  # (num_tokens, hidden_dim)
                # Scatter-add each token's weighted expert output back in place.
                next_states.index_add_(
                    0, top_x, weighted_output.to(hidden_states.dtype)[0]
                )
        else:
            # Inference: run every expert over every token in one batched pass.
            hidden_states = hidden_states.repeat(self.num_experts, 1)
            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
            gate_up = (
                torch.bmm(hidden_states, self.gate_up_proj)
                + self.gate_up_proj_bias[..., None, :]
            )
            gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = (
                torch.bmm(((up + 1) * glu), self.down_proj)
                + self.down_proj_bias[..., None, :]
            )
            next_states = next_states.view(-1, self.hidden_size)
        return next_states
def test_moe_forward():
    """Compare forward outputs of the custom Triton MoE layer against the
    eager OpenaiExperts reference on identical weights and routing.

    This is a forward-only test, so both implementations run under
    ``torch.no_grad()`` and the input does not track gradients — the
    previous version built (and kept) a full autograd graph for nothing.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data (no requires_grad: this test never calls backward)
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)  # Normalize

    # Initialize parameters (shared by both implementations)
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    # Run both implementations without building an autograd graph.
    with torch.no_grad():
        old_output = ref_layer(hidden_states, router_idx, router_wt)
        output = layer(
            hidden_states,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )

    assert old_output.shape == output.shape, "Output shapes do not match"

    diff = (old_output - output).abs()
    avg_diff = diff.mean()
    print(f"Average difference: {avg_diff.item()}")
    # Average difference: 0.009219333529472351
    print(f"Max difference: {diff.max().item()}")
    # Max difference: 0.09375
    # TODO: Improve the precision. Tolerances are loose because fp32
    # accumulation order differs between the kernel and the eager reference.
    assert torch.allclose(
        old_output,
        output,
        rtol=1e-1,
        atol=1e-1,
    ), "Outputs do not match between the two implementations"
def test_moe_backward_grad():
    """Simple backward test comparing gradients between MoE and OpenaiExperts."""
    # Configuration for this test.
    num_experts, hidden_size, expert_dim = 128, 1024, 512
    batch_tokens, topk = 4096, 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Token activations that participate in autograd.
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )

    # Random routing: expert ids plus normalized top-k weights per token.
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)  # Normalize

    # Expert parameters, freshly initialized on the GPU.
    def _param(*shape):
        return nn.Parameter(torch.randn(*shape, device="cuda"))

    gate_up_proj = _param(num_experts, hidden_size, 2 * expert_dim)
    gate_up_proj_bias = _param(num_experts, 2 * expert_dim)
    down_proj = _param(num_experts, expert_dim, hidden_size)
    down_proj_bias = _param(num_experts, hidden_size)
    alpha = 1.702

    # Forward through the custom layer, then backprop a scalar loss.
    layer = MoE()
    output = layer(
        hidden_states,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    )
    output.sum().backward()

    # Report which gradients were produced.
    print(f"gate_up_proj.grad exists: {gate_up_proj.grad is not None}")
    print(f"gate_up_proj_bias.grad exists: {gate_up_proj_bias.grad is not None}")
    print(f"down_proj.grad exists: {down_proj.grad is not None}")
    print(f"down_proj_bias.grad exists: {down_proj_bias.grad is not None}")

    # If every gradient exists, the backward pass wired in all parameters.
    assert (
        gate_up_proj.grad is not None
    ), "gate_up_proj gradient is None - custom MoE not using this parameter"
    assert (
        gate_up_proj_bias.grad is not None
    ), "gate_up_proj_bias gradient is None - custom MoE not using this parameter"
    assert down_proj.grad is not None, "down_proj gradient is None"
    assert down_proj_bias.grad is not None, "down_proj_bias gradient is None"
def test_moe_backward():
    """Backward test comparing gradients between MoE and OpenaiExperts.

    The reference layer receives *cloned* parameters AND a cloned input so
    each implementation backpropagates into its own tensors. Previously both
    passes shared the same ``hidden_states`` leaf, so ``hidden_states.grad``
    silently accumulated contributions from both backward calls, and
    ``retain_graph=True`` was needed to keep the shared graph alive.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)  # Normalize

    # Initialize parameters (same values used by both implementations)
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Clone parameters for the reference layer so gradients accumulate into
    # separate tensors and the originals are not mutated.
    ref_layer.gate_up_proj = nn.Parameter(
        gate_up_proj.clone().detach().requires_grad_(True)
    )
    ref_layer.gate_up_proj_bias = nn.Parameter(
        gate_up_proj_bias.clone().detach().requires_grad_(True)
    )
    ref_layer.down_proj = nn.Parameter(down_proj.clone().detach().requires_grad_(True))
    ref_layer.down_proj_bias = nn.Parameter(
        down_proj_bias.clone().detach().requires_grad_(True)
    )
    # Clone the input too: a shared leaf would mix both backward passes'
    # gradients in hidden_states.grad.
    ref_hidden_states = hidden_states.clone().detach().requires_grad_(True)

    # Forward pass
    ref_output = ref_layer(ref_hidden_states, router_idx, router_wt)
    output = layer(
        hidden_states,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    )

    # Backward pass — the graphs are fully independent, so no retain_graph.
    ref_loss = ref_output.sum()
    loss = output.sum()
    ref_loss.backward()
    loss.backward()

    # Check that gradients exist
    print(f"gate_up_proj.grad exists: {gate_up_proj.grad is not None}")
    print(f"gate_up_proj_bias.grad exists: {gate_up_proj_bias.grad is not None}")
    print(f"down_proj.grad exists: {down_proj.grad is not None}")
    print(f"down_proj_bias.grad exists: {down_proj_bias.grad is not None}")
    print(f"hidden_states.grad exists: {hidden_states.grad is not None}")

    # Simple check: if gradients exist, the backward pass worked
    assert (
        gate_up_proj.grad is not None
    ), "gate_up_proj gradient is None - custom MoE not using this parameter"
    assert (
        gate_up_proj_bias.grad is not None
    ), "gate_up_proj_bias gradient is None - custom MoE not using this parameter"
    assert down_proj.grad is not None, "down_proj gradient is None"
    assert down_proj_bias.grad is not None, "down_proj_bias gradient is None"
    assert hidden_states.grad is not None, "hidden_states gradient is None"
    print("✓ Backward test passed - all parameters have gradients")

    # Sanity check on gate up gradients
    print("10 elements from gate_up_proj gradients:")
    print(gate_up_proj.grad.flatten()[:10])
    print("10 elements from ref_layer.gate_up_proj gradients:")
    print(ref_layer.gate_up_proj.grad.flatten()[:10])

    # Compare the values and ensure they are close enough
    assert torch.allclose(
        ref_layer.gate_up_proj.grad,
        gate_up_proj.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "gate_up_proj gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.gate_up_proj_bias.grad,
        gate_up_proj_bias.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "gate_up_proj_bias gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.down_proj.grad,
        down_proj.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "down_proj gradients do not match between implementations"
    assert torch.allclose(
        ref_layer.down_proj_bias.grad,
        down_proj_bias.grad,
        rtol=1e-1,
        atol=1e-1,
    ), "down_proj_bias gradients do not match between implementations"
def test_moe_backward_benchmark():
    """Benchmark backward pass performance between MoE and OpenaiExperts.

    The timing loop and summary statistics were duplicated verbatim for the
    two implementations; they are now factored into nested helpers so both
    sides are guaranteed to be measured identically.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_warmup = 5
    num_runs = 20

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)  # Normalize

    # Initialize parameters
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Set reference layer parameters (shared with the custom layer)
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def _clear_grads():
        """Zero the shared parameter gradients between measured iterations."""
        for param in (gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
            if param.grad is not None:
                param.grad.zero_()

    def benchmark_ref_backward():
        """One reference forward+backward on a fresh input leaf."""
        hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
        _clear_grads()
        output = ref_layer(hidden_states_copy, router_idx, router_wt)
        loss = output.sum()
        loss.backward()
        return loss.item()

    def benchmark_custom_backward():
        """One custom-layer forward+backward on a fresh input leaf."""
        hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
        _clear_grads()
        output = layer(
            hidden_states_copy,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )
        loss = output.sum()
        loss.backward()
        return loss.item()

    def _time_runs(fn, label):
        """Time `fn` num_runs times (ms each), syncing CUDA around each run."""
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"Benchmarking {label} implementation ({num_runs} runs)...")
        times = []
        for i in range(num_runs):
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            times.append((end_time - start_time) * 1000)  # Convert to ms
            if (i + 1) % 5 == 0:
                print(f" Completed {i + 1}/{num_runs} runs")
        return times

    def _stats(times):
        """Return (mean, population std, min, max) of the timings in ms."""
        mean = sum(times) / len(times)
        std = (sum((t - mean) ** 2 for t in times) / len(times)) ** 0.5
        return mean, std, min(times), max(times)

    # Warmup
    print("Warming up...")
    for _ in range(num_warmup):
        benchmark_ref_backward()
        benchmark_custom_backward()

    ref_times = _time_runs(benchmark_ref_backward, "reference")
    custom_times = _time_runs(benchmark_custom_backward, "custom")

    # Calculate statistics
    ref_mean, ref_std, ref_min, ref_max = _stats(ref_times)
    custom_mean, custom_std, custom_min, custom_max = _stats(custom_times)
    speedup = ref_mean / custom_mean

    # Print results
    print("\n" + "=" * 80)
    print("BACKWARD PASS BENCHMARK RESULTS")
    print("=" * 80)
    print("Configuration:")
    print(f"  - Experts: {num_experts}")
    print(f"  - Hidden size: {hidden_size}")
    print(f"  - Expert dim: {expert_dim}")
    print(f"  - Batch tokens: {batch_tokens}")
    print(f"  - Top-k: {topk}")
    print(f"  - Runs: {num_runs}")
    print()
    print("Reference Implementation (OpenaiExperts):")
    print(f"  - Mean: {ref_mean:.3f} ms")
    print(f"  - Std: {ref_std:.3f} ms")
    print(f"  - Min: {ref_min:.3f} ms")
    print(f"  - Max: {ref_max:.3f} ms")
    print()
    print("Custom Implementation (MoE):")
    print(f"  - Mean: {custom_mean:.3f} ms")
    print(f"  - Std: {custom_std:.3f} ms")
    print(f"  - Min: {custom_min:.3f} ms")
    print(f"  - Max: {custom_max:.3f} ms")
    print()
    print(f"Speedup: {speedup:.2f}x")
    if speedup > 1.0:
        print(f"✓ Custom implementation is {speedup:.2f}x faster")
    else:
        print(f"✗ Custom implementation is {1/speedup:.2f}x slower")
    print("=" * 80)

    # Optional: Save detailed timing data
    print("\nDetailed timings (ms):")
    print(f"Reference: {ref_times}")
    print(f"Custom: {custom_times}")
def test_moe_backward_benchmark_memory():
    """Benchmark memory usage during backward pass.

    Between the two measurements we drop the reference pass's activations and
    the shared parameters' ``.grad`` buffers — previously they stayed alive
    into the second measurement, so the custom implementation was measured
    against an inflated baseline.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
        requires_grad=True,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Initialize parameters
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Set reference layer parameters (shared with the custom layer)
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    # Measure memory for reference implementation
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
    output = ref_layer(hidden_states_copy, router_idx, router_wt)
    loss = output.sum()
    loss.backward()
    ref_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    # Free the reference pass's tensors and gradient buffers so the custom
    # measurement starts from the same allocation baseline.
    del hidden_states_copy, output, loss
    for param in (gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        param.grad = None

    # Measure memory for custom implementation
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    hidden_states_copy = hidden_states.clone().detach().requires_grad_(True)
    output = layer(
        hidden_states_copy,
        router_idx,
        router_wt,
        alpha,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
    )
    loss = output.sum()
    loss.backward()
    custom_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    print("\n" + "=" * 60)
    print("MEMORY USAGE BENCHMARK")
    print("=" * 60)
    print(f"Reference implementation: {ref_memory:.3f} GB")
    print(f"Custom implementation: {custom_memory:.3f} GB")
    print(f"Memory ratio: {custom_memory/ref_memory:.3f}x")
    if custom_memory < ref_memory:
        print(f"✓ Custom uses {(1 - custom_memory/ref_memory)*100:.1f}% less memory")
    else:
        print(f"✗ Custom uses {(custom_memory/ref_memory - 1)*100:.1f}% more memory")
    print("=" * 60)
#################
def test_moe_forward_benchmark():
    """Benchmark forward pass performance between MoE and OpenaiExperts.

    The timing loop and summary statistics were duplicated verbatim for the
    two implementations; they are now factored into nested helpers so both
    sides are guaranteed to be measured identically.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_warmup = 5
    num_runs = 50

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data (forward-only: no gradients needed)
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)  # Normalize

    # Initialize parameters
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Set reference layer parameters (shared with the custom layer)
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def benchmark_ref_forward():
        """One reference forward pass under no_grad."""
        with torch.no_grad():
            return ref_layer(hidden_states, router_idx, router_wt)

    def benchmark_custom_forward():
        """One custom-layer forward pass under no_grad."""
        with torch.no_grad():
            return layer(
                hidden_states,
                router_idx,
                router_wt,
                alpha,
                gate_up_proj,
                gate_up_proj_bias,
                down_proj,
                down_proj_bias,
            )

    def _time_runs(fn, label):
        """Time `fn` num_runs times (ms each), syncing CUDA around each run."""
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"Benchmarking {label} implementation ({num_runs} runs)...")
        times = []
        for i in range(num_runs):
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            times.append((end_time - start_time) * 1000)  # Convert to ms
            if (i + 1) % 10 == 0:
                print(f" Completed {i + 1}/{num_runs} runs")
        return times

    def _stats(times):
        """Return (mean, population std, min, max) of the timings in ms."""
        mean = sum(times) / len(times)
        std = (sum((t - mean) ** 2 for t in times) / len(times)) ** 0.5
        return mean, std, min(times), max(times)

    # Warmup
    print("Warming up...")
    for _ in range(num_warmup):
        benchmark_ref_forward()
        benchmark_custom_forward()

    ref_times = _time_runs(benchmark_ref_forward, "reference")
    custom_times = _time_runs(benchmark_custom_forward, "custom")

    # Calculate statistics
    ref_mean, ref_std, ref_min, ref_max = _stats(ref_times)
    custom_mean, custom_std, custom_min, custom_max = _stats(custom_times)
    speedup = ref_mean / custom_mean

    # Print results
    print("\n" + "=" * 80)
    print("FORWARD PASS BENCHMARK RESULTS")
    print("=" * 80)
    print("Configuration:")
    print(f"  - Experts: {num_experts}")
    print(f"  - Hidden size: {hidden_size}")
    print(f"  - Expert dim: {expert_dim}")
    print(f"  - Batch tokens: {batch_tokens}")
    print(f"  - Top-k: {topk}")
    print(f"  - Runs: {num_runs}")
    print()
    print("Reference Implementation (OpenaiExperts):")
    print(f"  - Mean: {ref_mean:.3f} ms")
    print(f"  - Std: {ref_std:.3f} ms")
    print(f"  - Min: {ref_min:.3f} ms")
    print(f"  - Max: {ref_max:.3f} ms")
    print()
    print("Custom Implementation (MoE):")
    print(f"  - Mean: {custom_mean:.3f} ms")
    print(f"  - Std: {custom_std:.3f} ms")
    print(f"  - Min: {custom_min:.3f} ms")
    print(f"  - Max: {custom_max:.3f} ms")
    print()
    print(f"Speedup: {speedup:.2f}x")
    if speedup > 1.0:
        print(f"✓ Custom implementation is {speedup:.2f}x faster")
    else:
        print(f"✗ Custom implementation is {1/speedup:.2f}x slower")
    print("=" * 80)

    # Optional: Save detailed timing data
    print("\nDetailed timings (ms):")
    print(f"Reference: {ref_times}")
    print(f"Custom: {custom_times}")
def test_moe_forward_benchmark_memory():
    """Benchmark memory usage during forward pass.

    The reference output tensor is dropped before the custom-layer
    measurement — previously it stayed bound to ``output`` through the second
    measurement, inflating the custom implementation's allocation baseline.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data (forward-only: no gradients needed)
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Initialize parameters
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Set reference layer parameters (shared with the custom layer)
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    # Measure memory for reference implementation
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        output = ref_layer(hidden_states, router_idx, router_wt)
    ref_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    # Drop the reference output so both measurements share a baseline.
    del output

    # Measure memory for custom implementation
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        output = layer(
            hidden_states,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )
    custom_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    print("\n" + "=" * 60)
    print("FORWARD MEMORY USAGE BENCHMARK")
    print("=" * 60)
    print(f"Reference implementation: {ref_memory:.3f} GB")
    print(f"Custom implementation: {custom_memory:.3f} GB")
    print(f"Memory ratio: {custom_memory/ref_memory:.3f}x")
    if custom_memory < ref_memory:
        print(f"✓ Custom uses {(1 - custom_memory/ref_memory)*100:.1f}% less memory")
    else:
        print(f"✗ Custom uses {(custom_memory/ref_memory - 1)*100:.1f}% more memory")
    print("=" * 60)
def test_moe_forward_benchmark_throughput():
    """Benchmark throughput (tokens/second) for forward pass.

    The two identical timed loops are factored into a nested helper so the
    reference and custom layers are measured under the same conditions.
    """
    # Test configuration
    num_experts = 128
    hidden_size = 1024
    expert_dim = 512
    batch_tokens = 4096
    topk = 2
    num_runs = 100

    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)

    # Generate test data (forward-only: no gradients needed)
    hidden_states = torch.randn(
        batch_tokens,
        hidden_size,
        device="cuda",
        dtype=torch.float32,
    )
    router_idx = torch.randint(0, num_experts, (batch_tokens, topk), device="cuda")
    router_wt = torch.rand(batch_tokens, topk, device="cuda")
    router_wt = router_wt / router_wt.sum(dim=-1, keepdim=True)

    # Initialize parameters
    gate_up_proj = nn.Parameter(
        torch.randn(num_experts, hidden_size, 2 * expert_dim, device="cuda")
    )
    gate_up_proj_bias = nn.Parameter(
        torch.randn(num_experts, 2 * expert_dim, device="cuda")
    )
    down_proj = nn.Parameter(
        torch.randn(num_experts, expert_dim, hidden_size, device="cuda")
    )
    down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size, device="cuda"))
    alpha = 1.702

    # Create Layers
    layer = MoE()
    ref_layer = OpenaiExperts(
        config=type(
            "Config",
            (object,),
            {
                "num_local_experts": num_experts,
                "intermediate_size": expert_dim,
                "hidden_size": hidden_size,
            },
        )
    )
    # Set reference layer parameters (shared with the custom layer)
    ref_layer.gate_up_proj = gate_up_proj
    ref_layer.gate_up_proj_bias = gate_up_proj_bias
    ref_layer.down_proj = down_proj
    ref_layer.down_proj_bias = down_proj_bias

    def _run_ref():
        ref_layer(hidden_states, router_idx, router_wt)

    def _run_custom():
        layer(
            hidden_states,
            router_idx,
            router_wt,
            alpha,
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )

    def _timed_loop(fn):
        """Run `fn` num_runs times under no_grad; return total elapsed seconds."""
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.no_grad():
                fn()
        torch.cuda.synchronize()
        return time.perf_counter() - start_time

    # Warmup
    print("Warming up for throughput test...")
    for _ in range(10):
        with torch.no_grad():
            _run_ref()
            _run_custom()

    # Benchmark both implementations
    ref_total_time = _timed_loop(_run_ref)
    custom_total_time = _timed_loop(_run_custom)

    # Calculate throughput
    total_tokens_processed = batch_tokens * num_runs
    ref_throughput = total_tokens_processed / ref_total_time
    custom_throughput = total_tokens_processed / custom_total_time

    print("\n" + "=" * 70)
    print("FORWARD THROUGHPUT BENCHMARK")
    print("=" * 70)
    print(
        f"Configuration: {batch_tokens} tokens/batch × {num_runs} runs = {total_tokens_processed:,} tokens"
    )
    print()
    print("Reference Implementation:")
    print(f"  - Total time: {ref_total_time:.3f} seconds")
    print(f"  - Throughput: {ref_throughput:,.0f} tokens/second")
    print()
    print("Custom Implementation:")
    print(f"  - Total time: {custom_total_time:.3f} seconds")
    print(f"  - Throughput: {custom_throughput:,.0f} tokens/second")
    print()
    print(f"Throughput improvement: {custom_throughput/ref_throughput:.2f}x")
    if custom_throughput > ref_throughput:
        print(
            f"✓ Custom processes {((custom_throughput/ref_throughput - 1)*100):.1f}% more tokens/second"
        )
    else:
        print(
            f"✗ Custom processes {((1 - custom_throughput/ref_throughput)*100):.1f}% fewer tokens/second"
        )
    print("=" * 70)