| """Unit tests for GPURingBuffer and KVLedger.""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.attention.ring_buffer import GPURingBuffer |
| from arbitor.attention.kv_ledger import KVLedger |
|
|
|
|
def test_rb_append_wrap():
    """Overflowing a capacity-4 buffer keeps only the newest entries.

    Six values (0..5) are pushed; the three most recent must come back
    oldest-first as [3, 4, 5].
    """
    ring = GPURingBuffer(4)
    for value in range(6):
        ring.append(value)
    got = ring.get_last_n(3).tolist()
    assert got == [3, 4, 5], f"got {got}"
    print(" PASS test_rb_append_wrap")
|
|
|
|
def test_rb_contiguous_no_wrap():
    """Below capacity, get_last_n returns entries in insertion order."""
    ring = GPURingBuffer(4)
    for value in (0, 1, 2):
        ring.append(value)
    window = ring.get_last_n(3)
    assert window.tolist() == [0, 1, 2]
    print(" PASS test_rb_contiguous_no_wrap")
|
|
|
|
def test_rb_empty():
    """A freshly constructed buffer yields an empty result from get_last_n."""
    fresh = GPURingBuffer(4)
    window = fresh.get_last_n(3)
    assert window.numel() == 0
    print(" PASS test_rb_empty")
|
|
|
|
def test_rb_reset():
    """reset() returns the write pointer and size to zero and empties reads."""
    ring = GPURingBuffer(4)
    for value in range(3):
        ring.append(value)
    ring.reset()
    # Both bookkeeping counters must be back at their initial value.
    assert ring.ptr == 0
    assert ring.size == 0
    assert ring.get_last_n(3).numel() == 0
    print(" PASS test_rb_reset")
|
|
|
|
def test_rb_multi_dim():
    """Vector-valued entries: storage is (capacity, dim) and rows read back in order.

    Three rows filled with 0, 1 and 2 are appended to a capacity-3 buffer.
    The last two rows returned must be the rows of ones and twos, oldest
    first (the same ordering the scalar tests in this file expect).
    """
    rb = GPURingBuffer(3, dtype=torch.float32, dim=4)
    assert rb.buffer.shape == (3, 4)
    for i in range(3):
        rb.append(torch.ones(4) * i)
    last = rb.get_last_n(2)
    assert last.shape == (2, 4), f"shape {last.shape}"
    # Shape alone would not catch stale, zeroed, or misordered rows; pin the values.
    expected = [[1.0] * 4, [2.0] * 4]
    assert last.tolist() == expected, f"got {last.tolist()}"
    print(" PASS test_rb_multi_dim")
|
|
|
|
def test_rb_get_all():
    """After wrapping, get_all returns the full buffer oldest-first."""
    ring = GPURingBuffer(4)
    for value in range(6):
        ring.append(value)
    contents = ring.get_all()
    assert contents.tolist() == [2, 3, 4, 5], f"got {contents.tolist()}"
    print(" PASS test_rb_get_all")
|
|
|
|
def test_rb_partial():
    """A partially filled buffer exposes exactly the stored values.

    get_all must contain only the five appended entries, and get_last_n
    must return the newest three.
    """
    ring = GPURingBuffer(8)
    for value in range(5):
        ring.append(value)
    assert ring.get_all().tolist() == list(range(5))
    assert ring.get_last_n(3).tolist() == list(range(2, 5))
    print(" PASS test_rb_partial")
|
|
|
|
def test_kv_ledger_basic():
    """The ledger reports its length and serves the most recent window."""
    ledger = KVLedger(256)
    for value in range(100):
        ledger.append(value)
    assert len(ledger) == 100
    recent = ledger.get_sliding_window(5).tolist()
    assert recent == list(range(95, 100))
    print(" PASS test_kv_ledger_basic")
|
|
|
|
def test_kv_ledger_sliding_window():
    """Filling the ledger exactly to capacity still yields the newest window."""
    ledger = KVLedger(32)
    for value in range(32):
        ledger.append(value)
    window = ledger.get_sliding_window(5)
    assert window.tolist() == [27, 28, 29, 30, 31], f"got {window.tolist()}"
    print(" PASS test_kv_ledger_sliding_window")
|
|
|
|
def test_kv_ledger_sparse():
    """Strided sampling of 10 entries with stride 3 yields 4 samples.

    Only the count is checked here. Presumably the samples are positions
    0, 3, 6, 9 (or the same stride anchored at the newest entry).
    TODO confirm against KVLedger.get_sparse before asserting exact values.
    """
    ledger = KVLedger(16)
    for value in range(10):
        ledger.append(value)
    sampled = ledger.get_sparse(stride=3)
    count = len(sampled)
    assert count == 4, f"len={count}"
    print(" PASS test_kv_ledger_sparse")
|
|
|
|
def test_kv_ledger_reset():
    """reset() empties the ledger so its length drops back to zero."""
    ledger = KVLedger(8)
    for value in range(5):
        ledger.append(value)
    ledger.reset()
    assert len(ledger) == 0
    print(" PASS test_kv_ledger_reset")
|
|
|
|
def test_cuda_device_move():
    """Moving the buffer to CUDA relocates its storage and keeps it usable.

    When no CUDA device is available the test prints a SKIP line and
    returns normally, so it does not fail.
    """
    if torch.cuda.is_available():
        ring = GPURingBuffer(4).to("cuda")
        assert ring.buffer.device.type == "cuda"
        ring.append(42)
        assert ring.get_last_n(1).tolist() == [42]
        print(" PASS test_cuda_device_move")
    else:
        print(" SKIP test_cuda_device_move (no cuda)")
|
|
|
|
if __name__ == "__main__":
    # Run every test in declaration order; the first failing assertion aborts the run.
    for run in (
        test_rb_append_wrap,
        test_rb_contiguous_no_wrap,
        test_rb_empty,
        test_rb_reset,
        test_rb_multi_dim,
        test_rb_get_all,
        test_rb_partial,
        test_kv_ledger_basic,
        test_kv_ledger_sliding_window,
        test_kv_ledger_sparse,
        test_kv_ledger_reset,
        test_cuda_device_move,
    ):
        run()
    print("\nAll ring buffer tests PASS")
|
|