| """Unit tests for KQCache.""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.attention.kq_cache import KQCache |
|
|
|
|
def test_kqc_append_peek():
    """Appended values become visible through peek(), oldest first."""
    cache = KQCache(8)

    cache.append(42)
    # With a single stored value, the default peek() yields exactly that value.
    newest = cache.peek().tolist()
    assert newest == [42]

    cache.append(99)
    last_two = cache.peek(2).tolist()
    assert last_two == [42, 99]

    print(" PASS test_kqc_append_peek")
|
|
|
|
def test_kqc_wrap():
    """Overfilling a capacity-4 cache keeps the most recent values."""
    cache = KQCache(4)
    # Six appends into four slots forces the write position to wrap around.
    for value in range(6):
        cache.append(value)

    tail = cache.peek(3).tolist()
    assert tail == [3, 4, 5], f"got {tail}"
    print(" PASS test_kqc_wrap")
|
|
|
|
def test_kqc_peek_order():
    """After wrapping, peek(capacity) returns the survivors oldest-to-newest."""
    cache = KQCache(4)
    for value in range(6):
        cache.append(value)

    # Values 0 and 1 were overwritten; 2..5 remain in insertion order.
    expected = [2, 3, 4, 5]
    window = cache.peek(4).tolist()
    assert window == expected, f"got {window}"
    print(" PASS test_kqc_peek_order")
|
|
|
|
def test_kqc_empty():
    """A freshly constructed cache reports no contents."""
    fresh = KQCache(8)
    # Asking for more items than are stored yields an empty result here.
    assert fresh.peek(3).numel() == 0
    assert fresh.size == 0
    print(" PASS test_kqc_empty")
|
|
|
|
def test_kqc_reset():
    """reset() empties the cache and leaves it usable for fresh appends."""
    kq = KQCache(8)
    kq.append(1)
    kq.append(2)
    kq.reset()

    assert kq.size == 0, f"size after reset: {kq.size}"
    assert kq.peek().numel() == 0

    # The cache must be reusable after reset, and the pre-reset entries (1, 2)
    # must not leak back in. peek(3) asks for more than is stored, so any stale
    # state surviving reset() would show up here.
    kq.append(7)
    after = kq.peek(3).tolist()
    assert after == [7], f"got {after}"
    print(" PASS test_kqc_reset")
|
|
|
|
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for case in (
        test_kqc_append_peek,
        test_kqc_wrap,
        test_kqc_peek_order,
        test_kqc_empty,
        test_kqc_reset,
    ):
        case()
    print("\nAll KQ cache tests PASS")
|
|