File size: 1,932 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Unit tests for ternary KV cache storage and budget verification."""
import torch
import math
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from arbitor.config import (
    KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE, KQ_CACHE_SIZE,
    MLA_SLIDE_DIM, MLA_FULL_DIM, MLA_N_LAYERS, MLA_QK_ROPE_HEAD_DIM,
    MLA_CSA_DIM, MLA_HCA_DIM,
)


def test_ternary_pack_roundtrip():
    """Check that unpack_ternary(pack_ternary(T)) reproduces T exactly.

    T is a 16x64 tensor built from random normal samples: entries whose
    magnitude is at most 0.05 become zero, the rest keep only their sign.
    The test passes when the unpacked tensor matches T in both shape and
    element values.
    """
    dense = torch.randn(16, 64)
    # Mask is 1 where the sample is kept and 0 where it is zeroed out.
    keep_mask = (dense.abs() > 0.05).to(dense.dtype)
    ternary = dense.sign() * keep_mask

    packed, shape, pad = pack_ternary(ternary)
    restored = unpack_ternary(packed, shape, pad)

    assert restored.shape == ternary.shape, f"shape {restored.shape} vs {ternary.shape}"
    assert (restored == ternary).all(), "ternary signs differ after roundtrip"
    print(" PASS test_ternary_pack_roundtrip")


def test_kv_cache_sliding_window_budget():
    """Print the CSA sliding-window cache size in MB.

    The size is SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM divided by
    1024 * 1024 (presumably one byte per cached element -- TODO confirm).

    NOTE(review): this test only reports the figure; it contains no
    assertion, so it cannot fail on a budget overrun.
    """
    bytes_per_mb = 1024 * 1024
    csa_mb = SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM / bytes_per_mb
    print(f"  CSA cache (d={MLA_CSA_DIM}): {csa_mb:.1f} MB")
    print(" PASS test_kv_cache_sliding_window_budget")


def test_kv_cache_full_context_budget():
    """Assert that the combined CSA + HCA cache estimate stays under 100 MB.

    HCA: one position per 32 ledger entries (KV_LEDGER_SIZE // 32), times 4,
    times MLA_HCA_DIM.
    CSA: SLIDING_WINDOW_SIZE, times 4, times MLA_CSA_DIM.
    Both are divided by 1024 * 1024 to express them in MB (presumably one
    byte per cached element -- TODO confirm).

    Raises:
        AssertionError: if the CSA + HCA total is 100 MB or more.
    """
    bytes_per_mb = 1024 * 1024

    # NOTE(review): the multiplier 4 is hard-coded here, whereas
    # test_kv_cache_sliding_window_budget multiplies by MLA_N_LAYERS --
    # confirm which one is intended.
    hca_pos = KV_LEDGER_SIZE // 32  # HCA stride
    hca_mb = hca_pos * 4 * MLA_HCA_DIM / bytes_per_mb
    csa_mb = SLIDING_WINDOW_SIZE * 4 * MLA_CSA_DIM / bytes_per_mb
    total_mb = csa_mb + hca_mb

    assert total_mb < 100, f"Total {total_mb:.1f} MB exceeds 100 MB budget"
    print(f"  CSA: {csa_mb:.1f} MB + HCA: {hca_mb:.1f} MB = {total_mb:.1f} MB")
    print(" PASS test_kv_cache_full_context_budget")


if __name__ == "__main__":
    # Script entry point: run every test in declaration order, then report.
    for run_test in (
        test_ternary_pack_roundtrip,
        test_kv_cache_sliding_window_budget,
        test_kv_cache_full_context_budget,
    ):
        run_test()
    print("\nAll KV cache tests PASS")