File size: 3,417 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | """Unit tests for GPURingBuffer and KVLedger."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.attention.ring_buffer import GPURingBuffer
from arbitor.attention.kv_ledger import KVLedger
def test_rb_append_wrap():
    """Append more items than the capacity and read back the newest three.

    Six values go into a capacity-4 buffer, so the buffer must wrap; the last
    three values are expected back in insertion order.
    """
    capacity, total = 4, 6
    rb = GPURingBuffer(capacity)
    for value in range(total):
        rb.append(value)
    newest = rb.get_last_n(3).tolist()
    assert newest == [3, 4, 5], f"got {newest}"
    print(" PASS test_rb_append_wrap")
def test_rb_contiguous_no_wrap():
    """Read back items from a buffer that has not wrapped yet (3 of 4 slots used)."""
    expected = [0, 1, 2]
    rb = GPURingBuffer(4)
    for item in expected:
        rb.append(item)
    assert rb.get_last_n(3).tolist() == expected
    print(" PASS test_rb_contiguous_no_wrap")
def test_rb_empty():
    """A freshly constructed buffer returns zero elements from get_last_n."""
    fresh = GPURingBuffer(4)
    requested = fresh.get_last_n(3)
    assert requested.numel() == 0
    print(" PASS test_rb_empty")
def test_rb_reset():
    """reset() zeroes ptr/size and leaves nothing readable in the buffer."""
    rb = GPURingBuffer(4)
    for n in range(3):
        rb.append(n)
    rb.reset()
    # Check both the bookkeeping attributes and the observable contents.
    assert rb.ptr == 0
    assert rb.size == 0
    assert rb.get_last_n(3).numel() == 0
    print(" PASS test_rb_reset")
def test_rb_multi_dim():
    """Vector-valued entries: a dim=4 buffer stores rows and returns them in order.

    Rows filled with 0, 1 and 2 are appended to a capacity-3 float32 buffer.
    get_last_n(2) must return the last two rows with the correct shape AND the
    correct contents, oldest first (the same ordering the scalar tests check).
    """
    rb = GPURingBuffer(3, dtype=torch.float32, dim=4)
    assert rb.buffer.shape == (3, 4), f"buffer shape {rb.buffer.shape}"
    for i in range(3):
        rb.append(torch.ones(4) * i)
    last = rb.get_last_n(2)
    assert last.shape == (2, 4), f"shape {last.shape}"
    # A shape check alone would pass for zero-filled or misordered rows, so
    # pin the actual values (1.0 and 2.0 are exact in float32).
    expected = [[1.0] * 4, [2.0] * 4]
    got = last.tolist()
    assert got == expected, f"got {got}"
    print(" PASS test_rb_multi_dim")
def test_rb_get_all():
    """After 6 appends into 4 slots, get_all returns the newest four, oldest first."""
    rb = GPURingBuffer(4)
    for x in range(6):
        rb.append(x)
    contents = rb.get_all().tolist()
    assert contents == list(range(2, 6)), f"got {contents}"
    print(" PASS test_rb_get_all")
def test_rb_partial():
    """Partially filled buffer (5 of 8): reads see only the items written."""
    rb = GPURingBuffer(8)
    written = list(range(5))
    for v in written:
        rb.append(v)
    assert rb.get_all().tolist() == written
    assert rb.get_last_n(3).tolist() == written[-3:]
    print(" PASS test_rb_partial")
def test_kv_ledger_basic():
    """KVLedger reports its length and returns the most recent 5-item window."""
    ledger = KVLedger(256)
    for value in range(100):
        ledger.append(value)
    assert len(ledger) == 100
    assert ledger.get_sliding_window(5).tolist() == list(range(95, 100))
    print(" PASS test_kv_ledger_basic")
def test_kv_ledger_sliding_window():
    """Fill the ledger to exactly its capacity and check the last-5 window."""
    capacity = 32
    kv = KVLedger(capacity)
    for idx in range(capacity):
        kv.append(idx)
    window = kv.get_sliding_window(5).tolist()
    assert window == [27, 28, 29, 30, 31], f"got {window}"
    print(" PASS test_kv_ledger_sliding_window")
def test_kv_ledger_sparse():
    """get_sparse(stride=3) over 10 appended items yields 4 entries."""
    kv = KVLedger(16)
    for i in range(10):
        kv.append(i)
    picked = kv.get_sparse(stride=3)
    # NOTE(review): presumably every 3rd item (e.g. 0, 3, 6, 9); only the
    # count is checked here since the selection anchor isn't pinned down.
    count = len(picked)
    assert count == 4, f"len={count}"
    print(" PASS test_kv_ledger_sparse")
def test_kv_ledger_reset():
    """reset() empties the ledger so len() drops back to zero."""
    kv = KVLedger(8)
    for entry in range(5):
        kv.append(entry)
    kv.reset()
    remaining = len(kv)
    assert remaining == 0
    print(" PASS test_kv_ledger_reset")
def test_cuda_device_move():
    """Move a GPURingBuffer to CUDA and use it there; skipped when CUDA is absent."""
    if not torch.cuda.is_available():
        print(" SKIP test_cuda_device_move (no cuda)")
        return
    # NOTE(review): assumes .to() returns the (moved) buffer rather than None.
    rb = GPURingBuffer(4).to("cuda")
    assert rb.buffer.device.type == "cuda"
    rb.append(42)
    assert rb.get_last_n(1).tolist() == [42]
    print(" PASS test_cuda_device_move")
if __name__ == "__main__":
    # Discover tests from the module namespace instead of a hand-maintained
    # list, so a newly added test_* function is never silently skipped when
    # the file is run as a script. Module globals keep insertion order, so the
    # tests run in the order they are defined (top to bottom). Snapshot the
    # items first so the loop is unaffected by any later global assignments.
    _tests = [
        fn
        for name, fn in list(globals().items())
        if name.startswith("test_") and callable(fn)
    ]
    for _test in _tests:
        _test()
    print("\nAll ring buffer tests PASS")
|