ARBS / testing /kg /test_kv_integration.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Unit tests for KV ledger integration with composite motif IDs (Phase 17 Plan 02)."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.attention.kv_ledger import KVLedger
from arbitor.config import SHARED_VQ_SIZE
def test_composite_ids_no_collision():
    """Check that plain and composite IDs coexist in one ledger.

    Appends the ids 0..9 followed by 100, 101 and 102 to a ``KVLedger``
    built with ``max_size=256``, then verifies that the ledger reports
    13 entries and that a sliding window of 13 returns every id in
    insertion order.
    """
    base_ids = list(range(10))
    composite_ids = [100, 101, 102]
    expected = base_ids + composite_ids

    ledger = KVLedger(max_size=256)
    for entry in expected:
        ledger.append(entry)

    assert len(ledger) == 13, f"expected 13, got {len(ledger)}"

    # Ask for exactly as many entries as were appended.
    all_vals = ledger.get_sliding_window(13).tolist()
    assert all_vals == expected, f"got {all_vals}"
    print(" PASS test_composite_ids_no_collision")
def test_composite_offset_non_overlapping():
    """Check that the composite-ID offset is strictly positive.

    The offset is read from ``SHARED_VQ_SIZE``; the only property
    asserted here is that it is greater than zero.
    """
    composite_offset = SHARED_VQ_SIZE
    assert composite_offset > 0
    print(
        f" PASS test_composite_offset_non_overlapping (offset={composite_offset})"
    )
def test_composite_ids_track_in_ledger():
    """Check that only non-negative composite IDs reach the ledger.

    Walks a 2x3 tensor of composite ids in row-major order, skips the
    negative entries, and appends ``offset + id`` for each remaining
    one. The ledger must then hold exactly four entries, in traversal
    order.
    """
    # NOTE(review): 4096 is hard-coded here; presumably it mirrors
    # SHARED_VQ_SIZE -- confirm against arbitor.config.
    offset = 4096
    expected = [4096, 4098, 4099, 4100]

    composite_ids = torch.tensor([[0, -1, 2], [3, 4, -1]])
    ledger = KVLedger(max_size=256)

    # flatten() yields the elements row by row, matching a nested
    # (row, column) index loop over the tensor.
    for raw_id in composite_ids.flatten().tolist():
        if raw_id < 0:
            continue  # negative ids are not tracked
        ledger.append(offset + raw_id)

    assert len(ledger) == 4, f"expected 4 valid composite IDs, got {len(ledger)}"
    vals = ledger.get_sliding_window(4)
    assert vals.tolist() == expected, f"got {vals.tolist()}"
    print(" PASS test_composite_ids_track_in_ledger")
if __name__ == "__main__":
    # Run every test in definition order when executed as a script.
    for run_test in (
        test_composite_ids_no_collision,
        test_composite_offset_non_overlapping,
        test_composite_ids_track_in_ledger,
    ):
        run_test()
    print("\nAll KV integration tests PASS")