File size: 1,675 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Tests for MoEGraph top-k parallel expert routing."""
import os, torch
from arbitor.components import MoEGraph
from arbitor.config import HIDDEN_DIM


def test_moegraph_topk_params():
    """MoEGraph(top_k=8) stores k and exposes its gate weights and codebook projection."""
    requested_k = 8
    graph = MoEGraph(top_k=requested_k)

    # The requested fan-out is kept verbatim on the module.
    assert graph.top_k == requested_k
    # 256 gate entries — presumably one per expert (the model-level test
    # expects num_experts == 256); confirm if the expert count changes.
    assert len(graph.W_gate) == 256
    assert graph.codebook_up is not None


def test_moegraph_topk_forward_shape():
    """A top-8 forward pass keeps (B, T, HIDDEN_DIM) and returns a finite, non-negative ponder value."""
    # Seed before building the module so both its weights and the inputs are
    # reproducible; otherwise a failure here cannot be replayed.
    torch.manual_seed(42)
    mg = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    # NOTE(review): assumes VQ codes in [0, 1000) are valid inputs for MoEGraph — confirm.
    vq = torch.randint(0, 1000, (B, T))
    out, ponder = mg(x, vq)
    assert out.shape == (B, T, HIDDEN_DIM)
    assert torch.isfinite(out).all()
    # as_tensor + .all() accepts a Python float, a 0-d tensor, or a
    # multi-element tensor; bare torch.isfinite(ponder) handles only a
    # single-element tensor.
    ponder_t = torch.as_tensor(ponder)
    assert torch.isfinite(ponder_t).all()
    assert (ponder_t >= 0).all()


def test_moegraph_top1_vs_topk_shape():
    """top_k=1 and top_k=8 routing both return outputs of shape (B, T, HIDDEN_DIM)."""
    # Seed so module weights and inputs are reproducible across runs.
    torch.manual_seed(42)
    mg1 = MoEGraph(top_k=1)
    mg8 = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    vq = torch.randint(0, 1000, (B, T))
    out1, _ = mg1(x, vq)
    out8, _ = mg8(x, vq)
    # Pin the absolute shape as well: comparing only out1 against out8 would
    # let two equally-wrong shapes pass.
    assert out1.shape == (B, T, HIDDEN_DIM)
    assert out8.shape == out1.shape


def test_moegraph_topk_no_dead_code():
    """The MoEGraph wired into ARBModel is configured for top-4 routing over 256 experts.

    Builds a text-only ARBModel (image/audio off, VQ and graph on) and checks
    the moegraph's top_k, num_experts and cb_dim.
    """
    from arbitor.main import ARBModel

    graph = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
    ).moegraph

    assert graph.top_k == 4
    assert graph.num_experts == 256
    assert graph.cb_dim == 768


def test_moegraph_topk_routing_logic():
    """Check the top-k + temperature-softmax routing recipe on random gate scores.

    Uses plain torch ops (not MoEGraph itself). For every (batch, token):
    the k selected experts are distinct and are the k highest-scoring ones,
    and their softmax weights sum to 1 in non-increasing order.
    """
    torch.manual_seed(42)
    B, T, E, k = 2, 4, 256, 8
    temperature = 0.1
    scores = torch.randn(B, T, E)
    scores_topk, idx = scores.topk(k=k, dim=-1)
    weights = torch.softmax(scores_topk / temperature, dim=-1)

    assert weights.shape == (B, T, k)
    assert idx.shape == (B, T, k)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(B, T))

    # Distinct experts per token: once sorted, no two neighbours are equal.
    # Vectorised instead of calling torch.unique per (b, t).
    sorted_idx = idx.sort(dim=-1).values
    assert (torch.diff(sorted_idx, dim=-1) != 0).all()

    # The returned indices point at exactly the returned scores ...
    assert torch.equal(scores.gather(-1, idx), scores_topk)
    # ... and no unselected expert outscores the weakest selected one.
    selected = torch.zeros(B, T, E, dtype=torch.bool)
    selected.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
    best_unselected = scores.masked_fill(selected, float("-inf")).amax(dim=-1)
    assert (scores_topk[..., -1] >= best_unselected).all()

    # topk returns scores in descending order and softmax is monotonic, so
    # the routing weights are non-increasing along the k axis.
    assert (weights[..., :-1] >= weights[..., 1:]).all()