triton moe
This repository contains Triton kernels for running Mixture of Experts (MoE) models.
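An MoE layer sends each token to a small subset of experts. The benchmarks below use 128 experts with top-k = 2. The plain-Python sketch that follows shows one common routing variant for background. It selects the top-k router logits and applies a softmax over only those logits; some models apply the softmax before the selection instead. This is an illustrative assumption, not the routing code these kernels implement. `route_top_k` and `moe_forward` are hypothetical names.

```python
import math


def route_top_k(router_logits, k):
    """Hypothetical helper: choose the k highest-scoring experts for one token.

    Returns (expert_ids, weights); the weights are a softmax over the selected
    logits only, so they sum to 1.
    """
    # Rank expert indices by router logit, highest first (ties keep index order)
    ranked = sorted(range(len(router_logits)), key=lambda e: router_logits[e], reverse=True)
    expert_ids = ranked[:k]

    # Numerically stable softmax over the selected logits
    selected = [router_logits[e] for e in expert_ids]
    peak = max(selected)
    exps = [math.exp(v - peak) for v in selected]
    total = sum(exps)
    return expert_ids, [v / total for v in exps]


def moe_forward(token, router_logits, experts, k=2):
    """Weighted sum of the outputs of the k selected experts for one token."""
    expert_ids, weights = route_top_k(router_logits, k)
    out = [0.0] * len(token)
    for e, w in zip(expert_ids, weights):
        y = experts[e](token)  # each toy expert maps a vector to a same-length vector
        out = [o + w * yi for o, yi in zip(out, y)]
    return out


# Usage: four toy "experts" that scale their input by 1, 2, 3 and 4
experts = [lambda v, s=s: [s * x for x in v] for s in (1.0, 2.0, 3.0, 4.0)]
ids, weights = route_top_k([0.5, -1.0, 2.0, 0.1], k=2)
print(ids)  # [2, 0]
print(moe_forward([1.0, 1.0], [0.5, -1.0, 2.0, 0.1], experts))
```

In a real MoE layer, the experts are MLPs, such as the `gate_up_proj`/`down_proj` weights that appear in the test output. Tokens are also grouped per expert so that each expert runs as one batched matmul instead of one call per token.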
```python
# /// script
# dependencies = [
# "kernels",
# "numpy",
# "torch",
# ]
# ///
import torch
from kernels import get_kernel

# Make the run reproducible
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Download the optimized kernels from the Hugging Face Hub
triton_moe = get_kernel("kernels-community/triton-moe")

# Random input tensor of shape (10, 10)
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")

# Build a (10, 10, 2) gate/up input by duplicating x along a new last axis,
# so the gate and up values are identical
gate_up_out = x.unsqueeze(-1).repeat(1, 1, 2)

# Run the fused GLU kernel; it halves the last dimension (2 -> 1)
out = triton_moe.fused_glu.fused_glu_triton(gate_up_out=gate_up_out, alpha=1.0)

# Check the output
print("Output shape:", out.shape)
print("Output sum:", out.sum().item())
# Output shape: torch.Size([10, 10, 1])
# Output sum: 62.875
```
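The example duplicates `x` into a gate/up pair, and the kernel reduces the last dimension from 2 to 1. The plain-Python reference below sketches what a gated linear unit of this kind computes. It assumes a SiLU-style gate, `gate * sigmoid(alpha * gate) * up`, with gate and up values interleaved along the last axis. The exact formula, memory layout, and any clamping used by `fused_glu_triton` are not documented here, so treat this as an approximation. `glu_reference` is a hypothetical name.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def glu_reference(gate_up, alpha=1.0):
    """Hypothetical reference GLU over the last axis of a nested list.

    Assumes the last axis holds interleaved [gate, up, gate, up, ...] values,
    so its length must be even; the output last axis is half as long.
    """
    if isinstance(gate_up[0], list):
        # Recurse through the leading (batch-like) dimensions
        return [glu_reference(row, alpha) for row in gate_up]
    if len(gate_up) % 2:
        raise ValueError("last dimension must be even (gate/up pairs)")
    gates, ups = gate_up[0::2], gate_up[1::2]
    return [g * sigmoid(alpha * g) * u for g, u in zip(gates, ups)]


# Mirror the README example on a tiny input: duplicate each value into a (gate, up) pair
x = [[0.0, 1.0], [-1.0, 2.0]]
gate_up = [[[v, v] for v in row] for row in x]  # shape (2, 2, 2)
out = glu_reference(gate_up, alpha=1.0)         # shape (2, 2, 1)
print(out[0][0])  # [0.0]
```

If this formula is correct, duplicating `x` makes each output element `x**2 * sigmoid(alpha * x)`. The fused kernel computes the same kind of elementwise result in a single GPU pass, so the activation and the gating product do not need separate kernel launches.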
Testing
```bash
nix develop -i -L .#test --command python -m pytest -s tests
```
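The test output below reports an "Average difference" and a "Max difference" between the custom and reference forward outputs. It also prints ten `gate_up_proj` gradient values from each implementation. These metrics plausibly mean the mean and the maximum absolute elementwise difference, but the test code is not shown here, so that reading is an assumption. The sketch computes both metrics on the first four gradient values from the log. It does not use the forward outputs behind the reported averages. `diff_stats` is a hypothetical name.

```python
def diff_stats(actual, expected):
    """Mean and max absolute elementwise difference between two equal-length sequences."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    return sum(diffs) / len(diffs), max(diffs)


# First four gate_up_proj gradient values printed by the backward test
custom = [179.9667, -852.7672, -3274.1992, -4076.2095]
reference = [179.9663, -852.7676, -3274.1995, -4076.2188]

avg_diff, max_diff = diff_stats(custom, reference)
print(f"Average difference: {avg_diff:.4f}")
print(f"Max difference: {max_diff:.4f}")
```

These gradients are on the order of 10^3, so a difference of about 0.009 is a relative error of roughly 2e-6. With real tensors, `torch.testing.assert_close(actual, expected, rtol=..., atol=...)` performs this kind of check with combined relative and absolute tolerances.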
Example output from running the tests in tests/test_triton_moe.py (benchmark timings and memory figures depend on the GPU and software environment):
warning: Git tree '/home/ubuntu/Projects/triton-moe' is dirty
evaluation warning: CUDA versions older than 12.0 will be removed in Nixpkgs 25.05; see the 24.11 release notes for more information
triton_moe-torch-ext> Running phase: unpackPhase
triton_moe-torch-ext> unpacking source archive /nix/store/5zm9aqzym4h6xx414sy17dynr1hjbwh8-source
triton_moe-torch-ext> source root is source
triton_moe-torch-ext> Running phase: patchPhase
triton_moe-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase
triton_moe-torch-ext> Running phase: configurePhase
triton_moe-torch-ext> no configure script, doing nothing
triton_moe-torch-ext> Running phase: installPhase
triton_moe-torch-ext> Running phase: fixupPhase
triton_moe-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
triton_moe-torch-ext> checking for references to /build/ in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext...
triton_moe-torch-ext> patching script interpreter paths in /nix/store/yrzl0pngg8xxpf4jpkya8rmvmakgn4cd-triton_moe-torch-ext
===================================== test session starts ======================================
platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/Projects/triton-moe
plugins: hypothesis-6.130.12
collected 8 items
tests/test_triton_moe.py Average difference: 0.009301766753196716
Max difference: 0.095703125
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
.gate_up_proj.grad exists: True
gate_up_proj_bias.grad exists: True
down_proj.grad exists: True
down_proj_bias.grad exists: True
hidden_states.grad exists: True
✓ Backward test passed - all parameters have gradients
10 elements from gate_up_proj gradients:
tensor([ 179.9667, -852.7672, -3274.1992, -4076.2095, -2571.6282, -296.6539,
1800.2004, 503.5397, 48.4640, 191.8257], device='cuda:0')
10 elements from ref_layer.gate_up_proj gradients:
tensor([ 179.9663, -852.7676, -3274.1995, -4076.2188, -2571.6238, -296.6619,
1800.1997, 503.5378, 48.4632, 191.8266], device='cuda:0')
.Warming up...
Benchmarking reference implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs
Benchmarking custom implementation (20 runs)...
Completed 5/20 runs
Completed 10/20 runs
Completed 15/20 runs
Completed 20/20 runs
================================================================================
BACKWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 20
Reference Implementation (OpenaiExperts):
- Mean: 1855.949 ms
- Std: 9.959 ms
- Min: 1851.829 ms
- Max: 1896.181 ms
Custom Implementation (MoE):
- Mean: 250.311 ms
- Std: 0.591 ms
- Min: 249.697 ms
- Max: 252.103 ms
Speedup: 7.41x
✓ Custom implementation is 7.41x faster
================================================================================
Detailed timings (ms):
Reference: [1869.8261399986222, 1852.2405139519833, 1853.5575779969804, 1852.6741919922642, 1853.342688002158, 1853.4869640134275, 1854.8076269798912, 1852.6032069930807, 1853.7065120181069, 1852.3418360273354, 1853.6653390037827, 1853.267052967567, 1853.421829000581, 1851.8838259624317, 1852.5341430213302, 1851.8291829968803, 1852.3553150007501, 1896.1806659935974, 1852.8080619871616, 1852.4555769981816]
Custom: [252.10309802787378, 251.01929501397535, 250.63456897623837, 250.02275395672768, 249.69729100121185, 249.89533895859495, 249.75963402539492, 249.9880829709582, 251.09507801244035, 250.78181497519836, 250.63572899671271, 250.75654400279745, 250.1296689733863, 249.93031000485644, 249.83260600129142, 250.06496504647657, 250.1023070071824, 250.03619497874752, 249.96405199635774, 249.77726401994005]
.
============================================================
MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 2.977 GB
Custom implementation: 2.737 GB
Memory ratio: 0.919x
✓ Custom uses 8.1% less memory
============================================================
.Warming up...
Benchmarking reference implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs
Benchmarking custom implementation (50 runs)...
Completed 10/50 runs
Completed 20/50 runs
Completed 30/50 runs
Completed 40/50 runs
Completed 50/50 runs
================================================================================
FORWARD PASS BENCHMARK RESULTS
================================================================================
Configuration:
- Experts: 128
- Hidden size: 1024
- Expert dim: 512
- Batch tokens: 4096
- Top-k: 2
- Runs: 50
Reference Implementation (OpenaiExperts):
- Mean: 45.218 ms
- Std: 0.643 ms
- Min: 44.657 ms
- Max: 49.252 ms
Custom Implementation (MoE):
- Mean: 45.092 ms
- Std: 0.382 ms
- Min: 44.630 ms
- Max: 45.988 ms
Speedup: 1.00x
✓ Custom implementation is 1.00x faster
================================================================================
Detailed timings (ms):
Reference: [49.2524920264259, 44.996229000389576, 44.75043900310993, 45.368854014668614, 45.28193100122735, 44.9596070102416, 45.46641802880913, 45.40711600566283, 44.76102895569056, 45.04138103220612, 44.8942250222899, 45.068661973346025, 45.07604299578816, 45.02651101211086, 44.988538953475654, 44.86601299140602, 45.13182397931814, 44.97082799207419, 44.656905985902995, 45.53678201045841, 45.52290996070951, 45.288041001185775, 45.18025700235739, 45.118414040189236, 44.841952971182764, 45.04251101752743, 45.11384398210794, 45.021480007562786, 44.88742502871901, 44.94614701252431, 45.80574203282595, 46.0561320069246, 45.724289026111364, 45.17030599527061, 45.18806800479069, 44.91667600814253, 45.08163296850398, 44.952887983527035, 45.330752967856824, 44.88741495879367, 44.882684014737606, 45.47502798959613, 45.74003902962431, 45.015780022367835, 45.045601029414684, 44.906274997629225, 44.9566770112142, 45.05321098258719, 45.05068197613582, 45.216178987175226]
Custom: [45.50682002445683, 44.99372898135334, 45.103324053343385, 45.18590698717162, 44.91502500604838, 45.08770297979936, 45.72224896401167, 45.902586018200964, 45.94191804062575, 45.69388699019328, 45.324411999899894, 45.01690098550171, 44.72995799733326, 44.99283799668774, 45.312902017030865, 45.248389011248946, 44.97709800489247, 44.99160900013521, 44.63031404884532, 44.66017603408545, 45.02782097551972, 44.99858903000131, 44.89475500304252, 44.744468992576, 44.88639399642125, 44.79811096098274, 44.76995003642514, 44.648215000052005, 44.673235970549285, 44.698296987917274, 44.77059002965689, 44.66089600464329, 44.90563599392772, 45.984469004906714, 45.526801026426256, 45.913067006040365, 45.988010009750724, 45.26515997713432, 45.067521976307034, 45.230168965645134, 44.975398981478065, 44.86092395382002, 45.25493999244645, 44.94089604122564, 44.82307197758928, 44.838363013695925, 45.01837998395786, 44.708857021760195, 44.941117987036705, 44.85234396997839]
.
============================================================
FORWARD MEMORY USAGE BENCHMARK
============================================================
Reference implementation: 0.822 GB
Custom implementation: 1.758 GB
Memory ratio: 2.140x
✗ Custom uses 114.0% more memory
============================================================
.Warming up for throughput test...
======================================================================
FORWARD THROUGHPUT BENCHMARK
======================================================================
Configuration: 4096 tokens/batch × 100 runs = 409,600 tokens
Reference Implementation:
- Total time: 4.510 seconds
- Throughput: 90,816 tokens/second
Custom Implementation:
- Total time: 4.510 seconds
- Throughput: 90,827 tokens/second
Throughput improvement: 1.00x
✓ Custom processes 0.0% more tokens/second
======================================================================
.
================================= 8 passed in 75.71s (0:01:15) =================================
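The benchmark summary lines are consistent with simple ratios of the printed figures:

- Speedup is the reference mean time divided by the custom mean time.
- The memory ratio is the custom memory divided by the reference memory.
- Throughput is the number of tokens processed divided by the total time.

The sketch below recomputes these values from the rounded numbers in the log. Because those inputs are already rounded, a few results differ slightly in the last digit. The forward memory ratio comes out as 2.139 instead of the logged 2.140, and the throughput lands between the two logged values.

```python
# Figures copied from the benchmark output above (already rounded by the test)
bwd_ref_ms, bwd_custom_ms = 1855.949, 250.311
fwd_ref_ms, fwd_custom_ms = 45.218, 45.092
bwd_mem_ref_gb, bwd_mem_custom_gb = 2.977, 2.737
fwd_mem_ref_gb, fwd_mem_custom_gb = 0.822, 1.758
tokens, total_s = 4096 * 100, 4.510


def speedup(ref_ms, custom_ms):
    """How many times faster the custom run is (> 1 means faster)."""
    return ref_ms / custom_ms


def memory_ratio(ref_gb, custom_gb):
    """Custom memory relative to reference memory (< 1 means less memory)."""
    return custom_gb / ref_gb


print(f"Backward speedup: {speedup(bwd_ref_ms, bwd_custom_ms):.2f}x")  # 7.41x
print(f"Forward speedup: {speedup(fwd_ref_ms, fwd_custom_ms):.2f}x")   # 1.00x

ratio = memory_ratio(bwd_mem_ref_gb, bwd_mem_custom_gb)
print(f"Backward memory ratio: {ratio:.3f}x ({(1 - ratio) * 100:.1f}% less)")  # 0.919x (8.1% less)

fwd_ratio = memory_ratio(fwd_mem_ref_gb, fwd_mem_custom_gb)
print(f"Forward memory ratio: {fwd_ratio:.3f}x ({(fwd_ratio - 1) * 100:.1f}% more)")

print(f"Throughput: {tokens / total_s:,.0f} tokens/second")
```

In summary, the custom kernels are about 7.4x faster and use slightly less memory in the backward pass. In the forward pass, speed is the same and memory use is roughly double.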
Tags: kernels, apache-2.0