"""Unit tests for GPURingBuffer and KVLedger.""" import torch import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.attention.ring_buffer import GPURingBuffer from arbitor.attention.kv_ledger import KVLedger def test_rb_append_wrap(): rb = GPURingBuffer(4) for i in range(6): rb.append(i) assert rb.get_last_n(3).tolist() == [3, 4, 5], f"got {rb.get_last_n(3).tolist()}" print(" PASS test_rb_append_wrap") def test_rb_contiguous_no_wrap(): rb = GPURingBuffer(4) for i in range(3): rb.append(i) assert rb.get_last_n(3).tolist() == [0, 1, 2] print(" PASS test_rb_contiguous_no_wrap") def test_rb_empty(): rb = GPURingBuffer(4) assert rb.get_last_n(3).numel() == 0 print(" PASS test_rb_empty") def test_rb_reset(): rb = GPURingBuffer(4) for i in range(3): rb.append(i) rb.reset() assert rb.ptr == 0 and rb.size == 0 assert rb.get_last_n(3).numel() == 0 print(" PASS test_rb_reset") def test_rb_multi_dim(): rb = GPURingBuffer(3, dtype=torch.float32, dim=4) assert rb.buffer.shape == (3, 4) for i in range(3): rb.append(torch.ones(4) * i) last = rb.get_last_n(2) assert last.shape == (2, 4), f"shape {last.shape}" print(" PASS test_rb_multi_dim") def test_rb_get_all(): rb = GPURingBuffer(4) for i in range(6): rb.append(i) all_vals = rb.get_all() assert all_vals.tolist() == [2, 3, 4, 5], f"got {all_vals.tolist()}" print(" PASS test_rb_get_all") def test_rb_partial(): rb = GPURingBuffer(8) for i in range(5): rb.append(i) assert rb.get_all().tolist() == [0, 1, 2, 3, 4] assert rb.get_last_n(3).tolist() == [2, 3, 4] print(" PASS test_rb_partial") def test_kv_ledger_basic(): kv = KVLedger(256) for i in range(100): kv.append(i) assert len(kv) == 100 assert kv.get_sliding_window(5).tolist() == [95, 96, 97, 98, 99] print(" PASS test_kv_ledger_basic") def test_kv_ledger_sliding_window(): kv = KVLedger(32) for i in range(32): kv.append(i) last_5 = kv.get_sliding_window(5) assert last_5.tolist() == [27, 28, 29, 30, 31], f"got {last_5.tolist()}" print(" PASS test_kv_ledger_sliding_window") def test_kv_ledger_sparse(): kv = KVLedger(16) for i in range(10): kv.append(i) sparse = kv.get_sparse(stride=3) assert len(sparse) == 4, f"len={len(sparse)}" print(" PASS test_kv_ledger_sparse") def test_kv_ledger_reset(): kv = KVLedger(8) for i in range(5): kv.append(i) kv.reset() assert len(kv) == 0 print(" PASS test_kv_ledger_reset") def test_cuda_device_move(): if not torch.cuda.is_available(): print(" SKIP test_cuda_device_move (no cuda)") return rb = GPURingBuffer(4) rb = rb.to("cuda") assert rb.buffer.device.type == "cuda" rb.append(42) assert rb.get_last_n(1).tolist() == [42] print(" PASS test_cuda_device_move") if __name__ == "__main__": test_rb_append_wrap() test_rb_contiguous_no_wrap() test_rb_empty() test_rb_reset() test_rb_multi_dim() test_rb_get_all() test_rb_partial() test_kv_ledger_basic() test_kv_ledger_sliding_window() test_kv_ledger_sparse() test_kv_ledger_reset() test_cuda_device_move() print("\nAll ring buffer tests PASS")