| """Unit tests for ternary KV cache storage and budget verification.""" |
| import torch |
| import math |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.converters.convert_to_ternary8 import pack_ternary, unpack_ternary |
| from arbitor.config import ( |
| KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE, KQ_CACHE_SIZE, |
| MLA_SLIDE_DIM, MLA_FULL_DIM, MLA_N_LAYERS, MLA_QK_ROPE_HEAD_DIM, |
| MLA_CSA_DIM, MLA_HCA_DIM, |
| ) |
|
|
|
|
def test_ternary_pack_roundtrip():
    """Check that pack_ternary followed by unpack_ternary reproduces the input.

    A (16, 64) tensor of values in {-1, 0, +1} is built by thresholding
    Gaussian noise, packed, unpacked, and compared element-wise with the
    original.
    """
    # Seeded local generator: the input is identical on every run, so a
    # failure is reproducible, and the global RNG state is left untouched.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(16, 64, generator=gen)
    # Ternarize: keep the sign where |x| > 0.05, zero everywhere else.
    T = x.sign() * (x.abs() > 0.05).to(x.dtype)
    packed, shape, pad = pack_ternary(T)
    unpacked = unpack_ternary(packed, shape, pad)
    assert unpacked.shape == T.shape, f"shape {unpacked.shape} vs {T.shape}"
    assert (unpacked == T).all(), "ternary signs differ after roundtrip"
    print(" PASS test_ternary_pack_roundtrip")
|
|
|
|
def test_kv_cache_sliding_window_budget():
    """Print the sliding-window CSA cache size in MB.

    Report-only: the figure is computed and printed, nothing is asserted.
    """
    # NOTE(review): the product is divided by 2**20 directly, which presumably
    # assumes one byte per cached element -- confirm against the cache layout.
    bytes_per_mb = 1024 * 1024
    element_count = SLIDING_WINDOW_SIZE * MLA_N_LAYERS * MLA_CSA_DIM
    csa_mb = element_count / bytes_per_mb
    print(f" CSA cache (d={MLA_CSA_DIM}): {csa_mb:.1f} MB")
    print(" PASS test_kv_cache_sliding_window_budget")
|
|
|
|
def test_kv_cache_full_context_budget():
    """Assert that the combined CSA + HCA cache size stays below 100 MB.

    The HCA term covers KV_LEDGER_SIZE // 32 positions and the CSA term
    covers SLIDING_WINDOW_SIZE positions; each is multiplied by 4 and by
    its per-position dimension, then converted to MB.
    """
    # NOTE(review): 32 looks like the HCA position-compression ratio and 4
    # like a layer count (the sliding-window test uses MLA_N_LAYERS in the
    # same place) -- confirm, and consider replacing the literals with
    # config constants.
    hca_pos = KV_LEDGER_SIZE // 32
    hca_mb = hca_pos * 4 * MLA_HCA_DIM / (1024 * 1024)
    csa_mb = SLIDING_WINDOW_SIZE * 4 * MLA_CSA_DIM / (1024 * 1024)
    total_mb = csa_mb + hca_mb

    assert total_mb < 100, f"Total {total_mb:.1f} MB exceeds 100 MB budget"
    print(f" CSA: {csa_mb:.1f} MB + HCA: {hca_mb:.1f} MB = {total_mb:.1f} MB")
    print(" PASS test_kv_cache_full_context_budget")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every test in definition order, then print a
    # summary line. Any failing assertion aborts before the summary.
    for run_case in (
        test_ternary_pack_roundtrip,
        test_kv_cache_sliding_window_budget,
        test_kv_cache_full_context_budget,
    ):
        run_case()
    print("\nAll KV cache tests PASS")
|
|