"""Unit tests for KG edge co-occurrence learning (Phase 17 Plan 01).""" import torch import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.components import MoEGraph def _make_mg(): return MoEGraph(cb_dim=64, trigram_dim=512, num_experts=4, core_rank=32, shared_inter=128, max_iters=2, codebook_size=16, active_graph_max_nodes=32) def test_int_score_cooccurrence(): tg = _make_mg() vq1 = torch.tensor([[2, 4, 8, 15]]) tg.update_kg_edges(vq1) score_after_1 = tg.edge_score.clone() vq2 = torch.tensor([[2, 4, 9, 15]]) tg.update_kg_edges(vq2) score_after_2 = tg.edge_score.clone() assert not torch.equal(score_after_1, score_after_2), "second update should change int edge scores" print(" PASS test_int_score_cooccurrence") def test_ternary_quantize(): tg = _make_mg() tg.edge_score.fill_(0) tg.edge_score[:10] = 4 tg.edge_score[10:20] = -4 tg._steps_since_requant.fill_(50) tg.update_kg_edges(torch.tensor([[0, 1, 2, 3]])) assert (tg.edge_attr[:10] == 1).all(), "positive should be +1" assert (tg.edge_attr[10:20] == -1).all(), "negative should be -1" print(" PASS test_ternary_quantize") def test_batch_detection(): tg = _make_mg() vq_ids = torch.randint(0, 16, (2, 4)) old_score = tg.edge_score.clone() tg.update_kg_edges(vq_ids) assert not torch.equal(old_score, tg.edge_score), "batch update should change edge scores" print(" PASS test_batch_detection") def test_int_scores_stay_finite(): tg = _make_mg() for _ in range(10): vq = torch.randint(0, 16, (2, 4)) tg.update_kg_edges(vq) assert tg.edge_score.dtype == torch.int8 print(" PASS test_int_scores_stay_finite") def test_checkpoint_persistence(): tg1 = _make_mg() vq = torch.randint(0, 16, (2, 4)) tg1.update_kg_edges(vq) sd = tg1.state_dict() tg2 = _make_mg() tg2.load_state_dict(sd) assert torch.equal(tg1.edge_score, tg2.edge_score), "state_dict should preserve edge scores" print(" PASS test_checkpoint_persistence") if __name__ == "__main__": test_int_score_cooccurrence() test_ternary_quantize() test_batch_detection() test_int_scores_stay_finite() test_checkpoint_persistence() print("\nAll KG edge tests PASS")