# ARBS / testing / attention / test_kv_cache.py
# Uploaded by CLIWorks: "Upload folder using huggingface_hub" (commit d8bc908, verified)
"""Unit tests for ternary KV cache storage and budget verification."""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from arbitor.config import (
KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE, KQ_CACHE_SIZE,
MLA_SLIDE_DIM, MLA_FULL_DIM, MLA_N_LAYERS, MLA_QK_ROPE_HEAD_DIM,
MLA_CSA_DIM, MLA_HCA_DIM,
)
def test_ternary_pack_roundtrip():
    """Check that ``pack_ternary`` / ``unpack_ternary`` reproduce a ternary tensor exactly.

    A tensor with values in {-1, 0, +1} is built from Gaussian noise by taking
    the sign and zeroing entries whose magnitude is <= 0.05. It is packed,
    unpacked with the returned ``shape`` and ``pad`` metadata, and compared
    element-wise with the input.

    A seeded local generator makes failures reproducible without touching the
    global RNG state. An odd-sized shape (21 elements) is checked alongside the
    original 16x64 case so that the padding path is exercised as well.
    """
    gen = torch.Generator().manual_seed(0)
    for shape in ((16, 64), (3, 7)):
        x = torch.randn(*shape, generator=gen)
        # sign() yields -1/0/+1; the mask forces near-zero entries to exactly 0
        # so all three ternary values can appear in the input.
        T = x.sign() * (x.abs() > 0.05).to(x.dtype)
        packed, packed_shape, pad = pack_ternary(T)
        unpacked = unpack_ternary(packed, packed_shape, pad)
        assert unpacked.shape == T.shape, f"shape {unpacked.shape} vs {T.shape}"
        assert (unpacked == T).all(), "ternary signs differ after roundtrip"
    print(" PASS test_ternary_pack_roundtrip")
def test_kv_cache_sliding_window_budget():
    """Check that the sliding-window (CSA) KV cache estimate fits the memory budget.

    The estimate is ``SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM`` divided
    by 1 MiB, i.e. it counts one byte per cached element.

    NOTE(review): confirm that one byte per element matches the real cache
    storage format.
    """
    # Same 100 MB ceiling that test_kv_cache_full_context_budget enforces on
    # the combined cache; a budget test must be able to fail.
    budget_mb = 100
    csa_mb = SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM / (1024 * 1024)
    print(f" CSA cache (d={MLA_CSA_DIM}): {csa_mb:.1f} MB")
    assert csa_mb < budget_mb, (
        f"CSA cache {csa_mb:.1f} MB exceeds {budget_mb} MB budget"
    )
    print(" PASS test_kv_cache_sliding_window_budget")
def test_kv_cache_full_context_budget():
    """Check that the combined CSA + HCA KV cache estimate fits in 100 MB.

    - CSA: ``SLIDING_WINDOW_SIZE`` cached positions.
    - HCA: one cached position per ``hca_stride`` positions of the
      ``KV_LEDGER_SIZE`` ledger.

    Each position costs ``factor * dim`` bytes (one byte per element), and the
    two sizes are summed and compared with the budget.
    """
    budget_mb = 100
    hca_stride = 32  # HCA stride: one cached position per 32 ledger positions
    # NOTE(review): hard-coded 4. test_kv_cache_sliding_window_budget puts
    # MLA_N_LAYERS in this same slot, so this is presumably a layer count;
    # confirm, and consider using MLA_N_LAYERS here for consistency.
    factor = 4
    mib = 1024 * 1024

    hca_pos = KV_LEDGER_SIZE // hca_stride
    hca_mb = hca_pos * factor * MLA_HCA_DIM / mib
    csa_mb = SLIDING_WINDOW_SIZE * factor * MLA_CSA_DIM / mib
    total_mb = csa_mb + hca_mb

    assert total_mb < budget_mb, f"Total {total_mb:.1f} MB exceeds {budget_mb} MB budget"
    print(f" CSA: {csa_mb:.1f} MB + HCA: {hca_mb:.1f} MB = {total_mb:.1f} MB")
    print(" PASS test_kv_cache_full_context_budget")
if __name__ == "__main__":
    # Run each test in declaration order when executed as a script; any
    # failing assertion propagates and stops the run before the summary line.
    for _run in (
        test_ternary_pack_roundtrip,
        test_kv_cache_sliding_window_budget,
        test_kv_cache_full_context_budget,
    ):
        _run()
    print("\nAll KV cache tests PASS")