File size: 3,417 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Unit tests for GPURingBuffer and KVLedger."""
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.attention.ring_buffer import GPURingBuffer
from arbitor.attention.kv_ledger import KVLedger


def test_rb_append_wrap():
    """Overfill a capacity-4 buffer with 0..5 and check the newest three survive."""
    ring = GPURingBuffer(4)
    for value in range(6):
        ring.append(value)
    newest = ring.get_last_n(3).tolist()
    assert newest == [3, 4, 5], f"got {newest}"
    print(" PASS test_rb_append_wrap")


def test_rb_contiguous_no_wrap():
    """Three appends into a capacity-4 buffer (no wrap) read back in insertion order."""
    ring = GPURingBuffer(4)
    for value in range(3):
        ring.append(value)
    expected = list(range(3))
    assert ring.get_last_n(3).tolist() == expected
    print(" PASS test_rb_contiguous_no_wrap")


def test_rb_empty():
    """Reading from a never-written buffer yields a result with zero elements."""
    ring = GPURingBuffer(4)
    tail = ring.get_last_n(3)
    assert tail.numel() == 0
    print(" PASS test_rb_empty")


def test_rb_reset():
    """After reset(), ptr and size are zero and nothing is readable."""
    ring = GPURingBuffer(4)
    for value in range(3):
        ring.append(value)
    ring.reset()
    # Separate asserts so a failure points at the exact field that was not cleared.
    assert ring.ptr == 0
    assert ring.size == 0
    assert ring.get_last_n(3).numel() == 0
    print(" PASS test_rb_reset")


def test_rb_multi_dim():
    """Check a ``dim=4`` buffer stores width-4 rows and returns the newest rows.

    Rows ``0 * ones(4)``, ``1 * ones(4)`` and ``2 * ones(4)`` are appended to
    a capacity-3 buffer, so the last two rows read back must be the all-ones
    row followed by the all-twos row.
    """
    rb = GPURingBuffer(3, dtype=torch.float32, dim=4)
    assert rb.buffer.shape == (3, 4), f"buffer shape {rb.buffer.shape}"
    for i in range(3):
        rb.append(torch.ones(4) * i)
    last = rb.get_last_n(2)
    assert last.shape == (2, 4), f"shape {last.shape}"
    # A shape check alone passes even if rows are stored out of order or
    # never written. Pin the values too, expecting oldest-first ordering as
    # in the 1-D tests in this file.
    assert last.tolist() == [[1.0] * 4, [2.0] * 4], f"got {last.tolist()}"
    print(" PASS test_rb_multi_dim")


def test_rb_get_all():
    """get_all() on a wrapped capacity-4 buffer returns the four newest values, oldest first."""
    ring = GPURingBuffer(4)
    for value in range(6):
        ring.append(value)
    contents = ring.get_all().tolist()
    assert contents == [2, 3, 4, 5], f"got {contents}"
    print(" PASS test_rb_get_all")


def test_rb_partial():
    """In a partly filled buffer, get_all() returns only written items and get_last_n() the tail."""
    ring = GPURingBuffer(8)
    for value in range(5):
        ring.append(value)
    written = list(range(5))
    assert ring.get_all().tolist() == written
    assert ring.get_last_n(3).tolist() == written[-3:]
    print(" PASS test_rb_partial")


def test_kv_ledger_basic():
    """Append 100 items to a KVLedger(256); check its length and the 5-item window."""
    ledger = KVLedger(256)
    for value in range(100):
        ledger.append(value)
    assert len(ledger) == 100
    window = ledger.get_sliding_window(5).tolist()
    assert window == list(range(95, 100))
    print(" PASS test_kv_ledger_basic")


def test_kv_ledger_sliding_window():
    """Fill a KVLedger(32) exactly to capacity; the 5-item window is the last five values."""
    ledger = KVLedger(32)
    for value in range(32):
        ledger.append(value)
    window = ledger.get_sliding_window(5).tolist()
    assert window == list(range(27, 32)), f"got {window}"
    print(" PASS test_kv_ledger_sliding_window")


def test_kv_ledger_sparse():
    """Sampling 10 stored items with stride=3 is expected to yield 4 entries."""
    ledger = KVLedger(16)
    for value in range(10):
        ledger.append(value)
    picked = ledger.get_sparse(stride=3)
    count = len(picked)
    assert count == 4, f"len={count}"
    print(" PASS test_kv_ledger_sparse")


def test_kv_ledger_reset():
    """reset() on a partly filled KVLedger brings its length back to zero."""
    ledger = KVLedger(8)
    for value in range(5):
        ledger.append(value)
    ledger.reset()
    remaining = len(ledger)
    assert remaining == 0
    print(" PASS test_kv_ledger_reset")


def test_cuda_device_move():
    """Move a GPURingBuffer to CUDA with .to() and round-trip one value.

    Prints a SKIP line and does nothing else when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        ring = GPURingBuffer(4).to("cuda")
        assert ring.buffer.device.type == "cuda"
        ring.append(42)
        assert ring.get_last_n(1).tolist() == [42]
        print(" PASS test_cuda_device_move")
    else:
        print(" SKIP test_cuda_device_move (no cuda)")


if __name__ == "__main__":
    # Run every test in definition order; the first failing assert raises
    # and aborts the run before the summary line is printed.
    for case in (
        test_rb_append_wrap,
        test_rb_contiguous_no_wrap,
        test_rb_empty,
        test_rb_reset,
        test_rb_multi_dim,
        test_rb_get_all,
        test_rb_partial,
        test_kv_ledger_basic,
        test_kv_ledger_sliding_window,
        test_kv_ledger_sparse,
        test_kv_ledger_reset,
        test_cuda_device_move,
    ):
        case()
    print("\nAll ring buffer tests PASS")