# File size: 1,322 Bytes | d8bc908
"""Unit tests for KQCache."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.attention.kq_cache import KQCache
def test_kqc_append_peek():
    """Check that appended values are visible through ``peek``.

    After a single append, ``peek()`` with no argument is expected to
    yield just that value; after a second append, ``peek(2)`` is expected
    to yield both values in insertion order.
    """
    cache = KQCache(8)

    cache.append(42)
    after_first = cache.peek().tolist()
    assert after_first == [42]

    cache.append(99)
    after_second = cache.peek(2).tolist()
    assert after_second == [42, 99]

    print(" PASS test_kqc_append_peek")
def test_kqc_wrap():
    """Check ``peek`` after more appends than the cache was built with.

    A cache constructed with 4 receives the values 0..5; peeking 3 is
    expected to return the three most recently appended values,
    oldest first.
    """
    cache = KQCache(4)
    for value in range(6):
        cache.append(value)

    # Capture the result once so the failure message shows exactly the
    # list that was compared.
    newest_three = cache.peek(3).tolist()
    assert newest_three == [3, 4, 5], f"got {newest_three}"

    print(" PASS test_kqc_wrap")
def test_kqc_peek_order():
    """Check the ordering of a full-width ``peek`` after overflow.

    Six values (0..5) are appended to a cache constructed with 4; peeking
    4 is expected to return the last four values, oldest first.
    """
    capacity = 4
    cache = KQCache(capacity)
    for value in range(6):
        cache.append(value)

    window = cache.peek(capacity).tolist()
    assert window == [2, 3, 4, 5], f"got {window}"

    print(" PASS test_kqc_peek_order")
def test_kqc_empty():
    """Check that a freshly built cache peeks as empty and reports size 0."""
    fresh = KQCache(8)

    # Asking for more entries than exist must still give an empty result.
    peeked = fresh.peek(3)
    assert peeked.numel() == 0
    assert fresh.size == 0

    print(" PASS test_kqc_empty")
def test_kqc_reset():
    """Check that ``reset`` returns a populated cache to the empty state."""
    cache = KQCache(8)
    for value in (1, 2):
        cache.append(value)

    cache.reset()

    assert cache.size == 0
    remaining = cache.peek()
    assert remaining.numel() == 0

    print(" PASS test_kqc_reset")
if __name__ == "__main__":
    # Script entry point: run each test in declaration order. A failing
    # assert propagates and stops the run before the summary line.
    for run_test in (
        test_kqc_append_peek,
        test_kqc_wrap,
        test_kqc_peek_order,
        test_kqc_empty,
        test_kqc_reset,
    ):
        run_test()
    print("\nAll KQ cache tests PASS")
|