File size: 2,354 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Unit tests for KG edge co-occurrence learning (Phase 17 Plan 01)."""
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.components import MoEGraph


def _make_mg():
    """Build a small ``MoEGraph`` configured for fast unit tests.

    The sizes are deliberately tiny. ``codebook_size=16`` is why the tests
    below draw ids from ``[0, 16)``.
    """
    small_config = dict(
        cb_dim=64,
        trigram_dim=512,
        num_experts=4,
        core_rank=32,
        shared_inter=128,
        max_iters=2,
        codebook_size=16,
        active_graph_max_nodes=32,
    )
    return MoEGraph(**small_config)

def test_int_score_cooccurrence():
    """Two successive updates with different id sequences must change ``edge_score``.

    The sequences share ids 2, 4 and 15 but differ in one position (8 vs 9).
    A snapshot is taken after each update, and the snapshots must differ.
    """
    graph = _make_mg()
    snapshots = []
    for ids in ([[2, 4, 8, 15]], [[2, 4, 9, 15]]):
        graph.update_kg_edges(torch.tensor(ids))
        snapshots.append(graph.edge_score.clone())

    first, second = snapshots
    assert not torch.equal(first, second), "second update should change int edge scores"
    print(" PASS test_int_score_cooccurrence")


def test_ternary_quantize():
    """Positive int edge scores should map to +1 and negative ones to -1 in ``edge_attr``.

    The test seeds ``edge_score`` with +4 on edges [0, 10) and -4 on edges
    [10, 20). It then sets ``_steps_since_requant`` to 50 before calling
    ``update_kg_edges``.
    NOTE(review): setting the counter to 50 presumably forces the next update
    to requantize ``edge_score`` into ``edge_attr``. Confirm the threshold
    in MoEGraph.
    """
    graph = _make_mg()
    positive, negative = slice(0, 10), slice(10, 20)

    graph.edge_score.fill_(0)
    graph.edge_score[positive] = 4
    graph.edge_score[negative] = -4
    graph._steps_since_requant.fill_(50)

    graph.update_kg_edges(torch.tensor([[0, 1, 2, 3]]))

    assert (graph.edge_attr[positive] == 1).all(), "positive should be +1"
    assert (graph.edge_attr[negative] == -1).all(), "negative should be -1"
    print(" PASS test_ternary_quantize")


def test_batch_detection():
    """A multi-row (batch of 2) id tensor passed to ``update_kg_edges`` must change ``edge_score``.

    The ids come from a fixed-seed local generator, so a failure can be
    reproduced exactly. Using a local generator leaves the global torch RNG
    untouched for other tests.
    """
    tg = _make_mg()
    gen = torch.Generator().manual_seed(0)
    # The upper bound 16 matches codebook_size=16 in _make_mg (presumably
    # these ids are codebook indices).
    vq_ids = torch.randint(0, 16, (2, 4), generator=gen)
    old_score = tg.edge_score.clone()
    tg.update_kg_edges(vq_ids)
    assert not torch.equal(old_score, tg.edge_score), "batch update should change edge scores"
    print(" PASS test_batch_detection")


def test_int_scores_stay_finite():
    """Repeated updates must keep ``edge_score`` stored as ``torch.int8``.

    int8 values are always finite. What this test really guards against is
    the buffer being promoted to a float or wider-int dtype after several
    updates. The inputs come from a fixed-seed local generator, so any
    failure is reproducible.
    """
    tg = _make_mg()
    gen = torch.Generator().manual_seed(0)
    for _ in range(10):
        vq = torch.randint(0, 16, (2, 4), generator=gen)
        tg.update_kg_edges(vq)
    assert tg.edge_score.dtype == torch.int8, (
        f"edge_score dtype changed to {tg.edge_score.dtype}"
    )
    print(" PASS test_int_scores_stay_finite")


def test_checkpoint_persistence():
    """``edge_score`` learned by one MoEGraph must survive a state_dict round-trip.

    The test updates ``tg1`` with a fixed-seed random batch, so the run is
    reproducible. It then loads ``tg1``'s ``state_dict`` into a fresh
    ``tg2`` and requires identical ``edge_score`` tensors.
    """
    tg1 = _make_mg()
    gen = torch.Generator().manual_seed(0)
    vq = torch.randint(0, 16, (2, 4), generator=gen)
    tg1.update_kg_edges(vq)
    sd = tg1.state_dict()

    tg2 = _make_mg()
    tg2.load_state_dict(sd)
    assert torch.equal(tg1.edge_score, tg2.edge_score), "state_dict should preserve edge scores"
    print(" PASS test_checkpoint_persistence")


if __name__ == "__main__":
    # Run the tests in declaration order without pytest. Any failing
    # assertion stops the run before the summary line is printed.
    for _test in (
        test_int_score_cooccurrence,
        test_ternary_quantize,
        test_batch_detection,
        test_int_scores_stay_finite,
        test_checkpoint_persistence,
    ):
        _test()
    print("\nAll KG edge tests PASS")