"""Unit tests for KV ledger integration with composite motif IDs (Phase 17 Plan 02).

Composite IDs are written to the ledger shifted by ``SHARED_VQ_SIZE`` (the
composite offset), so that they cannot collide with the plain IDs stored
below that offset.
"""
import os
import sys

import torch

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.attention.kv_ledger import KVLedger
from arbitor.config import SHARED_VQ_SIZE


def test_composite_ids_no_collision():
    """Plain IDs and composite IDs appended in sequence are all retained, in order."""
    ledger = KVLedger(max_size=256)
    for i in range(10):
        ledger.append(i)
    for cid in [100, 101, 102]:
        ledger.append(cid)
    assert len(ledger) == 13, f"expected 13, got {len(ledger)}"
    all_vals = ledger.get_sliding_window(13).tolist()
    assert all_vals == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 101, 102], f"got {all_vals}"
    print("  PASS test_composite_ids_no_collision")


def test_composite_offset_non_overlapping():
    """The composite offset is positive and the IDs on either side of it stay distinct.

    ``offset - 1`` is the largest ID below the offset; ``offset + 0`` is the
    smallest composite ID. Both are round-tripped through a ledger to check that
    neither is aliased onto the other.
    """
    composite_offset = SHARED_VQ_SIZE
    assert composite_offset > 0

    ledger = KVLedger(max_size=256)
    # NOTE(review): assumes plain (non-composite) IDs occupy [0, SHARED_VQ_SIZE) — confirm.
    ledger.append(composite_offset - 1)
    ledger.append(composite_offset)  # composite local id 0
    vals = ledger.get_sliding_window(2).tolist()
    assert vals == [composite_offset - 1, composite_offset], f"got {vals}"
    assert vals[0] != vals[1], f"boundary IDs collided: {vals}"
    print(f"  PASS test_composite_offset_non_overlapping (offset={composite_offset})")


def test_composite_ids_track_in_ledger():
    """Valid composite IDs are appended with the offset; ``-1`` padding entries are skipped."""
    ledger = KVLedger(max_size=256)
    composite_ids = torch.tensor([[0, -1, 2], [3, 4, -1]])
    # Use the same offset as test_composite_offset_non_overlapping rather than a
    # hard-coded literal, so the two tests cannot silently disagree.
    offset = SHARED_VQ_SIZE
    # Boolean-mask selection yields the kept entries in row-major order, which
    # matches a nested row-then-column loop over the tensor.
    valid_ids = composite_ids[composite_ids >= 0].tolist()
    for cid in valid_ids:
        ledger.append(offset + cid)
    assert len(ledger) == 4, f"expected 4 valid composite IDs, got {len(ledger)}"
    vals = ledger.get_sliding_window(4)
    expected = [offset + cid for cid in (0, 2, 3, 4)]
    assert vals.tolist() == expected, f"got {vals.tolist()}"
    print("  PASS test_composite_ids_track_in_ledger")


if __name__ == "__main__":
    test_composite_ids_no_collision()
    test_composite_offset_non_overlapping()
    test_composite_ids_track_in_ledger()
    print("\nAll KV integration tests PASS")