| """Unit tests for KV ledger integration with composite motif IDs (Phase 17 Plan 02).""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.attention.kv_ledger import KVLedger |
| from arbitor.config import SHARED_VQ_SIZE |
|
|
|
|
def test_composite_ids_no_collision():
    """Base IDs and composite IDs can share one ledger without clobbering each other.

    Ten base IDs (0-9) are appended, followed by three composite IDs
    (100-102). The test then asserts that the ledger reports 13 entries and
    that a 13-wide ``get_sliding_window`` returns every value in insertion
    order.
    """
    base_ids = list(range(10))
    composite_ids = [100, 101, 102]

    ledger = KVLedger(max_size=256)
    for token_id in base_ids + composite_ids:
        ledger.append(token_id)

    size = len(ledger)
    assert size == 13, f"expected 13, got {size}"

    window = ledger.get_sliding_window(13).tolist()
    assert window == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 101, 102], f"got {window}"
    print(" PASS test_composite_ids_no_collision")
|
|
|
|
def test_composite_offset_non_overlapping():
    """Check that the composite-ID offset is strictly positive.

    The offset applied to composite IDs is ``SHARED_VQ_SIZE``. This test only
    asserts that the value is greater than zero; it prints the offset on
    success.
    """
    offset_value = SHARED_VQ_SIZE
    assert offset_value > 0
    print(f" PASS test_composite_offset_non_overlapping (offset={offset_value})")
|
|
|
|
def test_composite_ids_track_in_ledger():
    """Only non-negative composite IDs are recorded, shifted by the composite offset.

    ``composite_ids`` is a 2x3 tensor in which negative entries (here ``-1``,
    presumably padding - verify against the producer) must be skipped. Each
    remaining ID is appended to the ledger as ``offset + cid``. The test
    asserts that exactly four entries land in the ledger, in row-major order.
    """
    ledger = KVLedger(max_size=256)
    composite_ids = torch.tensor([[0, -1, 2], [3, 4, -1]])
    # Derive the offset from the shared config instead of hard-coding 4096,
    # so this test keeps exercising the real composite offset if
    # SHARED_VQ_SIZE changes.
    offset = SHARED_VQ_SIZE

    # Boolean-mask indexing flattens in row-major order (row 0 first, then
    # row 1), which matches the original nested batch/slot traversal.
    valid_ids = composite_ids[composite_ids >= 0].tolist()
    for cid in valid_ids:
        ledger.append(offset + cid)

    assert len(ledger) == 4, f"expected 4 valid composite IDs, got {len(ledger)}"
    vals = ledger.get_sliding_window(4)
    expected = [offset + cid for cid in (0, 2, 3, 4)]
    assert vals.tolist() == expected, f"got {vals.tolist()}"
    print(" PASS test_composite_ids_track_in_ledger")
|
|
|
|
if __name__ == "__main__":
    # Run each test in declaration order when executed as a script; any
    # failing assertion aborts the run before the summary line is printed.
    for _test in (
        test_composite_ids_no_collision,
        test_composite_offset_non_overlapping,
        test_composite_ids_track_in_ledger,
    ):
        _test()
    print("\nAll KV integration tests PASS")
|
|