"""Unit tests for ternary KV cache storage and budget verification."""

import torch
import math
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from arbitor.config import (
    KV_LEDGER_SIZE,
    SLIDING_WINDOW_SIZE,
    KQ_CACHE_SIZE,
    MLA_SLIDE_DIM,
    MLA_FULL_DIM,
    MLA_N_LAYERS,
    MLA_QK_ROPE_HEAD_DIM,
    MLA_CSA_DIM,
    MLA_HCA_DIM,
)

# Bytes per mebibyte, used for all size reporting below.
_MB = 1024 * 1024
# Upper bound (in MB) on the combined CSA + HCA cache size in the full-context test.
_KV_BUDGET_MB = 100
# One HCA entry is counted per this many ledger positions ("HCA stride").
_HCA_STRIDE = 32
# NOTE(review): multiplier applied to both the CSA and HCA sizes in the
# full-context budget. The sliding-window test uses MLA_N_LAYERS in the same
# position, so this may be a layer count or a bytes-per-element factor --
# confirm which, and whether it should track MLA_N_LAYERS.
_BUDGET_FACTOR = 4


def _to_mb(n_bytes):
    """Convert a byte count to mebibytes (float)."""
    return n_bytes / _MB


def test_ternary_pack_roundtrip():
    """pack_ternary followed by unpack_ternary must reproduce the ternary tensor.

    Builds {-1, 0, +1} tensors by taking the sign of Gaussian noise and zeroing
    entries with magnitude <= 0.05, then checks shape and values survive the
    roundtrip exactly.
    """
    # Local seeded generator: deterministic without touching global RNG state.
    gen = torch.Generator().manual_seed(0)
    # (16, 64) is the original case; (7, 13) has an odd element count so the
    # padding returned by pack_ternary is exercised too.
    for shape_in in ((16, 64), (7, 13)):
        x = torch.randn(*shape_in, generator=gen)
        T = x.sign() * (x.abs() > 0.05).to(x.dtype)
        packed, shape, pad = pack_ternary(T)
        unpacked = unpack_ternary(packed, shape, pad)
        assert unpacked.shape == T.shape, f"shape {unpacked.shape} vs {T.shape}"
        assert (unpacked == T).all(), "ternary signs differ after roundtrip"
    print(" PASS test_ternary_pack_roundtrip")


def test_kv_cache_sliding_window_budget():
    """Report the sliding-window CSA cache size and check it is a real cache.

    Size is SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM bytes. Previously
    this test asserted nothing; it now fails on zero/negative config values.
    """
    for name, value in (
        ("SLIDING_WINDOW_SIZE", SLIDING_WINDOW_SIZE),
        ("MLA_N_LAYERS", MLA_N_LAYERS),
        ("MLA_CSA_DIM", MLA_CSA_DIM),
    ):
        assert value > 0, f"{name} must be positive, got {value}"
    csa_mb = _to_mb(SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM)
    assert csa_mb > 0, f"CSA cache size must be positive, got {csa_mb:.1f} MB"
    print(f" CSA cache (d={MLA_CSA_DIM}): {csa_mb:.1f} MB")
    print(" PASS test_kv_cache_sliding_window_budget")


def test_kv_cache_full_context_budget():
    """Combined CSA (sliding window) + HCA (strided ledger) cache must fit the budget.

    CSA covers SLIDING_WINDOW_SIZE positions; HCA covers one position per
    _HCA_STRIDE entries of KV_LEDGER_SIZE. Both are scaled by _BUDGET_FACTOR.
    """
    hca_pos = KV_LEDGER_SIZE // _HCA_STRIDE
    hca_mb = _to_mb(hca_pos * _BUDGET_FACTOR * MLA_HCA_DIM)
    csa_mb = _to_mb(SLIDING_WINDOW_SIZE * _BUDGET_FACTOR * MLA_CSA_DIM)
    total_mb = csa_mb + hca_mb
    assert total_mb < _KV_BUDGET_MB, (
        f"Total {total_mb:.1f} MB exceeds {_KV_BUDGET_MB} MB budget"
    )
    print(f" CSA: {csa_mb:.1f} MB + HCA: {hca_mb:.1f} MB = {total_mb:.1f} MB")
    print(" PASS test_kv_cache_full_context_budget")


if __name__ == "__main__":
    test_ternary_pack_roundtrip()
    test_kv_cache_sliding_window_budget()
    test_kv_cache_full_context_budget()
    print("\nAll KV cache tests PASS")