| """Unit tests for KG edge co-occurrence learning (Phase 17 Plan 01).""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.components import MoEGraph |
|
|
|
|
def _make_mg():
    """Build the small MoEGraph instance shared by every test in this module.

    All constructor arguments are passed by keyword from one local config
    mapping; the codebook has 16 entries, so valid vq ids are 0..15.
    """
    config = dict(
        cb_dim=64,
        trigram_dim=512,
        num_experts=4,
        core_rank=32,
        shared_inter=128,
        max_iters=2,
        codebook_size=16,
        active_graph_max_nodes=32,
    )
    return MoEGraph(**config)
|
|
def test_int_score_cooccurrence():
    """A second batch with a different co-occurrence must move the int edge scores.

    The two sequences share ids 2, 4 and 15 but differ in the third position
    (8 vs 9). edge_score is snapshotted after each update_kg_edges call, and
    the two snapshots must not be equal.
    """
    graph = _make_mg()

    snapshots = []
    for ids in ([[2, 4, 8, 15]], [[2, 4, 9, 15]]):
        graph.update_kg_edges(torch.tensor(ids))
        # clone() so later in-place updates do not alias the snapshot
        snapshots.append(graph.edge_score.clone())

    first, second = snapshots
    assert not torch.equal(first, second), "second update should change int edge scores"
    print(" PASS test_int_score_cooccurrence")
|
|
|
|
def test_ternary_quantize():
    """Requantization maps positive int scores to +1 and negative ones to -1.

    Scores are written by hand: first 10 edges at +4, next 10 at -4, the rest
    at 0. The requant step counter is then set to 50, presumably so the next
    update_kg_edges call triggers requantization (NOTE(review): confirm the
    threshold against MoEGraph). Finally edge_attr is checked.
    """
    graph = _make_mg()
    positive, negative = slice(0, 10), slice(10, 20)

    graph.edge_score.fill_(0)
    graph.edge_score[positive] = 4
    graph.edge_score[negative] = -4
    graph._steps_since_requant.fill_(50)

    graph.update_kg_edges(torch.tensor([[0, 1, 2, 3]]))

    assert (graph.edge_attr[positive] == 1).all(), "positive should be +1"
    assert (graph.edge_attr[negative] == -1).all(), "negative should be -1"
    print(" PASS test_ternary_quantize")
|
|
|
|
def test_batch_detection():
    """update_kg_edges on a 2-row id tensor must change the edge scores.

    Ids are drawn from a seeded, private torch.Generator. This makes the test
    reproducible and leaves torch's global RNG state untouched.
    """
    gen = torch.Generator().manual_seed(0)
    tg = _make_mg()
    # upper bound 16 matches codebook_size in _make_mg
    vq_ids = torch.randint(0, 16, (2, 4), generator=gen)
    old_score = tg.edge_score.clone()
    tg.update_kg_edges(vq_ids)
    assert not torch.equal(old_score, tg.edge_score), "batch update should change edge scores"
    print(" PASS test_batch_detection")
|
|
|
|
def test_int_scores_stay_finite():
    """Repeated random updates must keep edge_score stored as int8.

    An int8 tensor cannot hold inf/NaN, so here "stays finite" reduces to the
    dtype never changing. The dtype is checked after every update, not only
    after the last one, so a transient change cannot slip through. Ids come
    from a seeded private generator for reproducibility.
    """
    gen = torch.Generator().manual_seed(0)
    tg = _make_mg()
    for step in range(10):
        # upper bound 16 matches codebook_size in _make_mg
        vq = torch.randint(0, 16, (2, 4), generator=gen)
        tg.update_kg_edges(vq)
        assert tg.edge_score.dtype == torch.int8, (
            f"edge_score dtype became {tg.edge_score.dtype} after update {step}"
        )
    print(" PASS test_int_scores_stay_finite")
|
|
|
|
def test_checkpoint_persistence():
    """edge_score must round-trip through state_dict / load_state_dict.

    Before loading, tg1's scores must differ from a freshly built model's.
    Otherwise the final equality would hold even if edge_score were not part
    of the state_dict, and the test would pass vacuously. Ids come from a
    seeded private generator for reproducibility.
    """
    gen = torch.Generator().manual_seed(0)
    tg1 = _make_mg()
    # upper bound 16 matches codebook_size in _make_mg
    tg1.update_kg_edges(torch.randint(0, 16, (2, 4), generator=gen))
    sd = tg1.state_dict()

    tg2 = _make_mg()
    assert not torch.equal(tg1.edge_score, tg2.edge_score), (
        "precondition: updated model should differ from a fresh one"
    )
    tg2.load_state_dict(sd)
    assert torch.equal(tg1.edge_score, tg2.edge_score), "state_dict should preserve edge scores"
    print(" PASS test_checkpoint_persistence")
|
|
|
|
if __name__ == "__main__":
    # Run every test in declaration order; the first failing assert aborts.
    for _case in (
        test_int_score_cooccurrence,
        test_ternary_quantize,
        test_batch_detection,
        test_int_scores_stay_finite,
        test_checkpoint_persistence,
    ):
        _case()
    print("\nAll KG edge tests PASS")
|
|