"""Unit tests for KQCache."""

import os
import sys

import torch  # noqa: F401  (not referenced directly in this file; kept from the original)

# Make the repository root importable when this file is run directly as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.attention.kq_cache import KQCache  # noqa: E402  (must follow the sys.path tweak)


def test_kqc_append_peek():
    """Appended values are returned by peek() in insertion order."""
    kq = KQCache(8)
    kq.append(42)
    got = kq.peek().tolist()
    assert got == [42], f"got {got}"
    kq.append(99)
    got = kq.peek(2).tolist()
    assert got == [42, 99], f"got {got}"
    print(" PASS test_kqc_append_peek")


def test_kqc_wrap():
    """After more appends than capacity, peek(3) yields the three newest values."""
    kq = KQCache(4)
    for value in range(6):
        kq.append(value)
    got = kq.peek(3).tolist()
    assert got == [3, 4, 5], f"got {got}"
    print(" PASS test_kqc_wrap")


def test_kqc_peek_order():
    """Peeking the full capacity after wrap-around returns oldest-to-newest order."""
    kq = KQCache(4)
    for value in range(6):
        kq.append(value)
    got = kq.peek(4).tolist()
    # Values 0 and 1 were overwritten; the survivors must come back oldest first.
    assert got == [2, 3, 4, 5], f"got {got}"
    print(" PASS test_kqc_peek_order")


def test_kqc_empty():
    """A freshly built cache has size 0 and peek() returns no elements."""
    kq = KQCache(8)
    count = kq.peek(3).numel()
    assert count == 0, f"expected no elements, got {count}"
    assert kq.size == 0, f"got size {kq.size}"
    print(" PASS test_kqc_empty")


def test_kqc_reset():
    """reset() empties a populated cache."""
    kq = KQCache(8)
    kq.append(1)
    kq.append(2)
    kq.reset()
    assert kq.size == 0, f"got size {kq.size}"
    count = kq.peek().numel()
    assert count == 0, f"expected no elements, got {count}"
    print(" PASS test_kqc_reset")


if __name__ == "__main__":
    # Allow running without pytest: execute every test in declaration order.
    for test in (
        test_kqc_append_peek,
        test_kqc_wrap,
        test_kqc_peek_order,
        test_kqc_empty,
        test_kqc_reset,
    ):
        test()
    print("\nAll KQ cache tests PASS")