File size: 1,675 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | """Tests for MoEGraph top-k parallel expert routing."""
import os, torch
from arbitor.components import MoEGraph
from arbitor.config import HIDDEN_DIM
def test_moegraph_topk_params():
    """A MoEGraph built with top_k=8 exposes its routing parameters.

    Checks that the requested top_k is stored, that ``W_gate`` has 256
    entries, and that ``codebook_up`` is set.
    """
    graph = MoEGraph(top_k=8)
    assert graph.top_k == 8
    gate_weights = graph.W_gate
    assert len(gate_weights) == 256
    assert graph.codebook_up is not None
def test_moegraph_topk_forward_shape():
    """A top-8 MoEGraph forward pass returns a finite output and ponder term.

    ``mg(x, vq)`` is called with hidden states of shape ``(B, T, HIDDEN_DIM)``
    and integer VQ codes of shape ``(B, T)``. The output must keep the
    input's shape and be finite. The returned ponder value must be finite
    and non-negative.
    """
    # Seed so a failure reproduces with the same weights and inputs.
    torch.manual_seed(0)
    mg = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    # NOTE(review): codes are drawn from [0, 1000); assumes the VQ codebook
    # has at least 1000 entries. Confirm against MoEGraph.
    vq = torch.randint(0, 1000, (B, T))
    out, ponder = mg(x, vq)
    assert out.shape == (B, T, HIDDEN_DIM)
    assert torch.isfinite(out).all()
    # as_tensor + .all() works whether ponder is a Python float, a 0-d
    # tensor, or a multi-element tensor. Bare tensor truthiness raises for
    # more than one element, and torch.isfinite rejects plain floats.
    ponder_t = torch.as_tensor(ponder)
    assert torch.isfinite(ponder_t).all()
    assert (ponder_t >= 0).all()
def test_moegraph_top1_vs_topk_shape():
    """top_k=1 and top_k=8 routing both return outputs of the input's shape.

    Both graphs receive the same hidden states ``(B, T, HIDDEN_DIM)`` and VQ
    codes ``(B, T)``. Each output must have shape ``(B, T, HIDDEN_DIM)``.
    """
    # Seed so a failure reproduces with the same weights and inputs.
    torch.manual_seed(0)
    mg1 = MoEGraph(top_k=1)
    mg8 = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    # NOTE(review): assumes the VQ codebook has at least 1000 entries.
    vq = torch.randint(0, 1000, (B, T))
    out1, _ = mg1(x, vq)
    out8, _ = mg8(x, vq)
    # Comparing the two shapes alone would also pass if both were wrong,
    # so pin the expected shape explicitly.
    assert out1.shape == (B, T, HIDDEN_DIM)
    assert out8.shape == out1.shape
def test_moegraph_topk_no_dead_code():
    """ARBModel wires in a MoEGraph with the expected default configuration.

    Builds an ARBModel with image and audio disabled and VQ and graph
    enabled. Its ``moegraph`` must report top_k=4, 256 experts, and a
    codebook dimension of 768.
    """
    from arbitor.main import ARBModel

    arb = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
    )
    graph = arb.moegraph
    # Checked in the same order as the attributes are listed.
    expected = {"top_k": 4, "num_experts": 256, "cb_dim": 768}
    for attr, want in expected.items():
        assert getattr(graph, attr) == want
def test_moegraph_topk_routing_logic():
    """Top-k selection plus a temperature softmax yields valid routing weights.

    For each (batch, token) position, takes the k highest of E random scores
    and applies a softmax at temperature 0.1. Weights and indices must both
    have shape (B, T, k), and each weight row must sum to 1. Every row of
    indices must name k distinct experts.
    """
    torch.manual_seed(42)
    batch, tokens, n_experts, k = 2, 4, 256, 8
    temperature = 0.1
    logits = torch.randn(batch, tokens, n_experts)
    top_vals, expert_ids = torch.topk(logits, k, dim=-1)
    routing = (top_vals / temperature).softmax(dim=-1)
    assert routing.shape == (batch, tokens, k)
    assert expert_ids.shape == (batch, tokens, k)
    assert torch.allclose(routing.sum(dim=-1), torch.ones(batch, tokens))
    # A row of k ids is all-distinct iff no two neighbours match once sorted.
    ordered = expert_ids.sort(dim=-1).values
    assert (ordered[..., 1:] != ordered[..., :-1]).all()
|