File size: 2,382 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Verify LSTM removal — structural checks only, no ARBModel construction."""
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.components import LossWeights, LossComponents


def test_loss_cleanup():
    """Check that the removed loss terms no longer exist on the loss containers.

    Neither a default ``LossWeights()`` nor a ``LossComponents`` built from it
    may still expose the ``lstm_hidden_reg`` or ``conv_vq_commitment``
    attributes.
    """
    removed_terms = ("lstm_hidden_reg", "conv_vq_commitment")

    def _assert_terms_absent(obj):
        # Report which container still carries which stale term on failure.
        owner = type(obj).__name__
        for term in removed_terms:
            assert not hasattr(obj, term), f"{owner}.{term} should be removed"

    lw = LossWeights()
    # Check the weights first so a stale attribute is reported even if
    # constructing LossComponents would itself fail.
    _assert_terms_absent(lw)
    lc = LossComponents(weights=lw)
    _assert_terms_absent(lc)
    print(" PASS test_loss_cleanup")


def test_moe_clean():
    """Build a small MoEGraph and confirm the old LSTM router attributes are gone."""
    from arbitor.components import MoEGraph

    graph_config = dict(
        cb_dim=64,
        trigram_dim=128,
        num_experts=4,
        core_rank=32,
        shared_inter=64,
        max_iters=1,
        codebook_size=128,
        active_graph_max_nodes=0,
    )
    graph = MoEGraph(**graph_config)

    for stale_attr in ("router_h", "lstm_enabled"):
        assert not hasattr(graph, stale_attr), f"{stale_attr} should be removed"
    print(" PASS test_moe_clean")


def test_classes_removed():
    """Ensure the deleted classes are no longer exported by arbitor.components."""
    import arbitor.components as components

    deleted_classes = (
        "ConvVQCodebook",
        "FocusGate",
        "ConversationStack",
        "ConversationLSTM",
    )
    for class_name in deleted_classes:
        assert not hasattr(components, class_name), f"{class_name} should be removed"
    print(" PASS test_classes_removed")


def test_attention_imports():
    """Smoke-test that the attention package still exports its public names.

    Passing means every import below resolved; any missing name raises
    ImportError.
    """
    from arbitor.attention import (  # noqa: F401 - importing is the check
        GPURingBuffer,
        KVLedger,
        KQCache,
        MultiHeadLatentAttention,
        ContextAttentionScheduler,
    )
    from arbitor.attention.mla import (  # noqa: F401
        apply_rotary_emb,
        precompute_freqs_cis,
    )

    print(" PASS test_attention_imports")


def test_video_talker_imports():
    """Smoke-test that VideoHead and TalkerHead are importable from arbitor.decoders."""
    from arbitor.decoders import (  # noqa: F401 - importing is the check
        VideoHead,
        TalkerHead,
    )

    print(" PASS test_video_talker_imports")


def test_memory_budget():
    """Check that the KV ledger plus KQ cache stay under a fixed memory budget.

    The footprint is estimated as ``(KV_LEDGER_SIZE + KQ_CACHE_SIZE)`` entries
    times a fixed per-entry byte count, in units of 1024 * 1024 bytes.
    """
    from arbitor.config import KV_LEDGER_SIZE, KQ_CACHE_SIZE

    # NOTE(review): assumes each entry is 4 bytes (e.g. a float32/int32
    # element) — confirm against the dtype the ledger/cache actually allocate.
    bytes_per_entry = 4
    budget_mb = 100
    bytes_per_mb = 1024 * 1024

    ledger_bytes = KV_LEDGER_SIZE * bytes_per_entry
    kq_bytes = KQ_CACHE_SIZE * bytes_per_entry
    total_mb = (ledger_bytes + kq_bytes) / bytes_per_mb
    assert total_mb < budget_mb, (
        f"KV ledger + KQ cache use {total_mb:.2f} MB, "
        f"exceeding the {budget_mb} MB budget"
    )
    print(f" PASS test_memory_budget ({total_mb:.2f} MB)")


if __name__ == "__main__":
    # Script mode: run every check in declaration order; the first failing
    # assertion aborts the run before the summary line is printed.
    for check in (
        test_loss_cleanup,
        test_moe_clean,
        test_classes_removed,
        test_attention_imports,
        test_video_talker_imports,
        test_memory_budget,
    ):
        check()
    print("\nAll LSTM removal tests PASS")