"""Unit tests for MultiHeadLatentAttention, ContextAttentionScheduler, RoPE utilities.""" import math import torch import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.attention.mla import ( MultiHeadLatentAttention, apply_rotary_emb, precompute_freqs_cis, ) from arbitor.attention.context_attention import ContextAttentionScheduler from arbitor.attention.kv_ledger import KVLedger def _default_mla(): return MultiHeadLatentAttention( dim=256, n_heads=4, kv_lora_rank=16, qk_nope_head_dim=24, qk_rope_head_dim=8, v_head_dim=24, ) def test_mla_construct(): mla = _default_mla() assert mla.dim == 256 assert mla.n_heads == 4 assert mla.kv_lora_rank == 16 assert mla.qk_head_dim == 32 print(" PASS test_mla_construct") def test_mla_shape(): mla = _default_mla() x = torch.randn(1, 4, 256) kv_cache = torch.randn(8, 16) pe_cache = torch.randn(8, 8) out = mla(x, kv_cache, pe_cache) assert out.shape == (1, 4, 256), f"shape {out.shape}" assert torch.isfinite(out).all() print(" PASS test_mla_shape") def _get_wkv_b_eff(mla, n_heads, kv_lora_rank): """Get effective wkv_b weight from TernaryScaleTensor.""" T = mla.wkv_b._get_T() S = mla.wkv_b._get_S() W = (T * S).view(n_heads, -1, kv_lora_rank) return W def test_mla_absorb_vs_naive(): for seed in [42, 123, 256]: torch.manual_seed(seed) dim = 128 n_heads = 2 kv_lora_rank = 8 qk_nope = 16 qk_rope = 8 v_dim = 16 mla = MultiHeadLatentAttention( dim=dim, n_heads=n_heads, kv_lora_rank=kv_lora_rank, qk_nope_head_dim=qk_nope, qk_rope_head_dim=qk_rope, v_head_dim=v_dim, ) x = torch.randn(1, 4, dim) kv_cache = torch.randn(8, kv_lora_rank) pe_cache = torch.randn(8, qk_rope) absorb_out = mla(x, kv_cache, pe_cache) wkv_b = _get_wkv_b_eff(mla, n_heads, kv_lora_rank) kv_nope = torch.einsum("hdc,tc->thd", wkv_b[:, :qk_nope], kv_cache) kv_full_k = torch.cat([kv_nope, pe_cache.unsqueeze(1).expand(-1, n_heads, -1)], dim=-1) kv_full_v = torch.einsum("hdc,tc->thd", wkv_b[:, -v_dim:], kv_cache) naive = _naive_attention(mla, x, kv_full_k, kv_full_v, pe_cache) diff = (absorb_out - naive).abs().max().item() assert diff < 1e-4, f"seed={seed} diff={diff}" print(" PASS test_mla_absorb_vs_naive") def _naive_attention(mla, x, kv_full_k, kv_full_v, pe_cache): bsz, seqlen, _ = x.shape q = mla.wq(mla.wq_norm(x)) q = q.view(bsz, seqlen, mla.n_heads, mla.qk_head_dim) scores = torch.einsum("bshd,thd->bsht", q, kv_full_k) * mla.softmax_scale if seqlen > 1: causal = torch.triu( torch.full((seqlen, kv_full_k.shape[0]), float('-inf'), device=x.device), diagonal=1 ) scores = scores + causal.unsqueeze(0).unsqueeze(2) scores = scores.softmax(dim=-1, dtype=torch.float32) attn = torch.einsum("bsht,thd->bshd", scores, kv_full_v) attn = attn.flatten(2) return mla.wo(attn) def test_mla_gradient_flow(): mla = _default_mla() x = torch.randn(1, 4, 256, requires_grad=True) kv_cache = torch.randn(8, 16) pe_cache = torch.randn(8, 8) out = mla(x, kv_cache, pe_cache) loss = out.sum() loss.backward() assert x.grad is not None, "input grad is None" assert x.grad.abs().sum().item() > 0, "input grad is zero" print(" PASS test_mla_gradient_flow") def test_mla_causal_mask(): mla = _default_mla() x = torch.randn(1, 8, 256) kv_cache = torch.randn(12, 16) pe_cache = torch.randn(12, 8) out = mla(x, kv_cache, pe_cache, mask=None) assert out.shape == (1, 8, 256) mla2 = _default_mla() out2 = mla2(x, kv_cache, pe_cache, mask=None) assert torch.isfinite(out2).all() print(" PASS test_mla_causal_mask") def test_apply_rotary_emb(): x = torch.randn(1, 4, 2, 8) freqs_cis = torch.polar( torch.ones(4, 4), torch.linspace(0, math.pi, 4 * 4).reshape(4, 4), ) out = apply_rotary_emb(x, freqs_cis) assert out.shape == (1, 4, 2, 8), f"shape {out.shape}" assert not torch.allclose(out, x), "rotation did nothing" print(" PASS test_apply_rotary_emb") def test_precompute_freqs_cis(): freqs = precompute_freqs_cis(dim=32, end=100) assert freqs.shape == (100, 16), f"shape {freqs.shape}" assert torch.is_complex(freqs) assert freqs.imag.abs().sum().item() > 0, "imag part is zero" print(" PASS test_precompute_freqs_cis") def test_context_scheduler(): scheduler = ContextAttentionScheduler(dim=256) x = torch.randn(1, 4, 256) ledger = KVLedger(32) for i in range(20): ledger.append(i) out = scheduler(x, ledger) assert out.shape == (1, 4, 256), f"shape {out.shape}" assert torch.isfinite(out).all() print(" PASS test_context_scheduler") def test_context_scheduler_empty_ledger(): scheduler = ContextAttentionScheduler(dim=256) x = torch.randn(1, 4, 256) ledger = KVLedger(32) out = scheduler(x, ledger) assert out.shape == (1, 4, 256) assert torch.isfinite(out).all() print(" PASS test_context_scheduler_empty_ledger") def test_context_scheduler_gate(): scheduler = ContextAttentionScheduler(dim=256) x = torch.randn(1, 4, 256) ledger = KVLedger(32) for i in range(20): ledger.append(i) out = scheduler(x, ledger) gate_val = torch.sigmoid(scheduler.gate(x.mean(dim=1, keepdim=True))) assert gate_val.shape == (1, 1, 1) assert 0 < gate_val.item() < 1 print(" PASS test_context_scheduler_gate") def test_context_scheduler_hca_shape_mismatch_regression(): scheduler = ContextAttentionScheduler(dim=256) x = torch.randn(1, 4, 256) ledger = KVLedger(256) for i in range(160): ledger.append(i) out = scheduler(x, ledger) assert out.shape == (1, 4, 256), f"shape {out.shape}" assert torch.isfinite(out).all() print(" PASS test_context_scheduler_hca_shape_mismatch_regression") if __name__ == "__main__": test_mla_construct() test_mla_shape() test_mla_absorb_vs_naive() test_mla_gradient_flow() test_mla_causal_mask() test_apply_rotary_emb() test_precompute_freqs_cis() test_context_scheduler() test_context_scheduler_empty_ledger() test_context_scheduler_gate() test_context_scheduler_hca_shape_mismatch_regression() print("\nAll MLA + scheduler tests PASS")