File size: 2,354 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | """Unit tests for KG edge co-occurrence learning (Phase 17 Plan 01)."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.components import MoEGraph
def _make_mg():
    """Return a freshly constructed ``MoEGraph`` used as the fixture by every test here.

    All constructor arguments are fixed; ``codebook_size=16`` is why the
    tests draw VQ ids from ``range(16)``.
    """
    config = dict(
        cb_dim=64,
        trigram_dim=512,
        num_experts=4,
        core_rank=32,
        shared_inter=128,
        max_iters=2,
        codebook_size=16,
        active_graph_max_nodes=32,
    )
    return MoEGraph(**config)
def test_int_score_cooccurrence():
    """Two updates with different VQ id patterns must leave different int scores.

    The second batch replaces id 8 with id 9, so the ``edge_score`` snapshot
    taken after the second ``update_kg_edges`` call must not equal the one
    taken after the first.
    """
    graph = _make_mg()
    snapshots = []
    for ids in ([[2, 4, 8, 15]], [[2, 4, 9, 15]]):
        graph.update_kg_edges(torch.tensor(ids))
        # clone() so the later in-place update cannot alter this snapshot.
        snapshots.append(graph.edge_score.clone())
    first, second = snapshots
    assert not torch.equal(first, second), "second update should change int edge scores"
    print(" PASS test_int_score_cooccurrence")
def test_ternary_quantize():
    """Requantization maps positive int edge scores to +1 and negative ones to -1.

    ``_steps_since_requant`` is pre-filled with 50 so that the next
    ``update_kg_edges`` call presumably triggers a requantization pass
    (assumes the requant interval is <= 50 -- confirm against MoEGraph).
    """
    tg = _make_mg()
    # Guard against a vacuous pass: with fewer than 20 edges the slices below
    # would be short or empty, and ``.all()`` on an empty tensor is True, so
    # the sign assertions would succeed without checking anything.
    num_edges = min(tg.edge_score.shape[0], tg.edge_attr.shape[0])
    assert num_edges >= 20, f"fixture needs at least 20 edges, got {num_edges}"
    tg.edge_score.fill_(0)
    tg.edge_score[:10] = 4
    tg.edge_score[10:20] = -4
    tg._steps_since_requant.fill_(50)
    tg.update_kg_edges(torch.tensor([[0, 1, 2, 3]]))
    assert (tg.edge_attr[:10] == 1).all(), "positive should be +1"
    assert (tg.edge_attr[10:20] == -1).all(), "negative should be -1"
    print(" PASS test_ternary_quantize")
def test_batch_detection():
    """A two-row batch of VQ ids passed to ``update_kg_edges`` must change edge scores."""
    tg = _make_mg()
    # Local seeded generator: the input is reproducible on failure and the
    # global torch RNG (shared with other tests in this process) is untouched.
    gen = torch.Generator().manual_seed(0)
    vq_ids = torch.randint(0, 16, (2, 4), generator=gen)
    old_score = tg.edge_score.clone()
    tg.update_kg_edges(vq_ids)
    assert not torch.equal(old_score, tg.edge_score), "batch update should change edge scores"
    print(" PASS test_batch_detection")
def test_int_scores_stay_finite():
    """Repeated updates must keep ``edge_score`` stored as ``torch.int8``.

    Runs ten ``update_kg_edges`` calls on random two-row batches and then
    checks the dtype has not changed (e.g. been promoted to a float type).
    """
    tg = _make_mg()
    # Local seeded generator for reproducible inputs without touching global RNG.
    gen = torch.Generator().manual_seed(1)
    for _ in range(10):
        vq = torch.randint(0, 16, (2, 4), generator=gen)
        tg.update_kg_edges(vq)
    assert tg.edge_score.dtype == torch.int8, (
        f"edge_score dtype changed to {tg.edge_score.dtype}, expected torch.int8"
    )
    print(" PASS test_int_scores_stay_finite")
def test_checkpoint_persistence():
    """Edge scores learned on one graph must survive a state_dict round-trip."""
    tg1 = _make_mg()
    # Local seeded generator for reproducible inputs without touching global RNG.
    gen = torch.Generator().manual_seed(2)
    vq = torch.randint(0, 16, (2, 4), generator=gen)
    tg1.update_kg_edges(vq)
    sd = tg1.state_dict()
    tg2 = _make_mg()
    # Non-vacuity guard: if a fresh graph already matched tg1, the equality
    # after loading would pass even if edge_score were not persisted at all.
    assert not torch.equal(tg1.edge_score, tg2.edge_score), (
        "updated graph should differ from a fresh one before loading"
    )
    tg2.load_state_dict(sd)
    assert torch.equal(tg1.edge_score, tg2.edge_score), "state_dict should preserve edge scores"
    print(" PASS test_checkpoint_persistence")
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for _test in (
        test_int_score_cooccurrence,
        test_ternary_quantize,
        test_batch_detection,
        test_int_scores_stay_finite,
        test_checkpoint_persistence,
    ):
        _test()
    print("\nAll KG edge tests PASS")
|