| """Unit tests for MultiHeadLatentAttention, ContextAttentionScheduler, RoPE utilities.""" |
| import math |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.attention.mla import ( |
| MultiHeadLatentAttention, apply_rotary_emb, precompute_freqs_cis, |
| ) |
| from arbitor.attention.context_attention import ContextAttentionScheduler |
| from arbitor.attention.kv_ledger import KVLedger |
|
|
|
|
def _default_mla():
    """Build the MultiHeadLatentAttention instance reused by the MLA tests.

    Constructed with dim=256, n_heads=4, kv_lora_rank=16,
    qk_nope_head_dim=24, qk_rope_head_dim=8 and v_head_dim=24.
    """
    config = dict(
        dim=256,
        n_heads=4,
        kv_lora_rank=16,
        qk_nope_head_dim=24,
        qk_rope_head_dim=8,
        v_head_dim=24,
    )
    return MultiHeadLatentAttention(**config)
|
|
|
|
def test_mla_construct():
    """Check the configuration attributes exposed by the default fixture."""
    module = _default_mla()
    # Checked in this order; qk_head_dim is presumably 24 + 8 from the
    # fixture's nope/rope head widths -- confirm against the module.
    expected = {
        "dim": 256,
        "n_heads": 4,
        "kv_lora_rank": 16,
        "qk_head_dim": 32,
    }
    for attr, value in expected.items():
        assert getattr(module, attr) == value
    print(" PASS test_mla_construct")
|
|
|
|
def test_mla_shape():
    """Run one forward pass and check the output shape and finiteness."""
    attn = _default_mla()
    hidden = torch.randn(1, 4, 256)
    latent_cache = torch.randn(8, 16)
    rope_cache = torch.randn(8, 8)

    result = attn(hidden, latent_cache, rope_cache)

    # The output is expected to have the same shape as the input.
    assert result.shape == (1, 4, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_mla_shape")
|
|
|
|
def _get_wkv_b_eff(mla, n_heads, kv_lora_rank):
    """Return the effective wkv_b weight from TernaryScaleTensor, split per head.

    The tensors returned by ``mla.wkv_b._get_T()`` and ``mla.wkv_b._get_S()``
    are multiplied with ``*`` and the product is viewed as
    ``(n_heads, -1, kv_lora_rank)``.
    """
    projection = mla.wkv_b
    ternary = projection._get_T()
    scale = projection._get_S()
    effective = ternary * scale
    return effective.view(n_heads, -1, kv_lora_rank)
|
|
|
|
def test_mla_absorb_vs_naive():
    """Compare the module's forward output with the naive reference per seed.

    For each seed a fresh module and random inputs are created, keys and
    values are expanded from the cache with the effective ``wkv_b`` weight,
    and the maximum absolute difference between the two outputs must stay
    below 1e-4.
    """
    # Sizes are the same for every seed, so they are set once up front.
    dim, n_heads, kv_lora_rank = 128, 2, 8
    qk_nope, qk_rope, v_dim = 16, 8, 16

    for seed in (42, 123, 256):
        torch.manual_seed(seed)

        # Creation order (module, x, kv_cache, pe_cache) is kept fixed so the
        # seeded random stream is consumed the same way on every run.
        mla = MultiHeadLatentAttention(
            dim=dim,
            n_heads=n_heads,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope,
            qk_rope_head_dim=qk_rope,
            v_head_dim=v_dim,
        )
        x = torch.randn(1, 4, dim)
        kv_cache = torch.randn(8, kv_lora_rank)
        pe_cache = torch.randn(8, qk_rope)
        absorb_out = mla(x, kv_cache, pe_cache)

        # Expand the cache into per-head keys/values: the first qk_nope rows
        # of each head's weight give the key part, the last v_dim rows the
        # value part.
        wkv_b = _get_wkv_b_eff(mla, n_heads, kv_lora_rank)
        key_weight = wkv_b[:, :qk_nope]
        value_weight = wkv_b[:, -v_dim:]
        kv_nope = torch.einsum("hdc,tc->thd", key_weight, kv_cache)
        pe_per_head = pe_cache[:, None, :].expand(-1, n_heads, -1)
        kv_full_k = torch.cat([kv_nope, pe_per_head], dim=-1)
        kv_full_v = torch.einsum("hdc,tc->thd", value_weight, kv_cache)
        naive = _naive_attention(mla, x, kv_full_k, kv_full_v, pe_cache)

        delta = absorb_out - naive
        diff = delta.abs().max().item()
        assert diff < 1e-4, f"seed={seed} diff={diff}"
    print(" PASS test_mla_absorb_vs_naive")
|
|
|
|
def _naive_attention(mla, x, kv_full_k, kv_full_v, pe_cache):
    """Reference attention over already expanded keys and values.

    ``x`` goes through ``mla.wq_norm`` then ``mla.wq`` and is viewed as
    ``(batch, seq, mla.n_heads, mla.qk_head_dim)``. Scores against
    ``kv_full_k`` are scaled by ``mla.softmax_scale``; when there is more
    than one query position, an upper-triangular ``-inf`` mask
    (``diagonal=1``) is added. The float32 softmax weights are applied to
    ``kv_full_v``, heads are flattened, and ``mla.wo`` produces the result.

    ``pe_cache`` is accepted but not used by this function.
    """
    batch, n_query, _ = x.shape
    n_cached = kv_full_k.shape[0]

    queries = mla.wq(mla.wq_norm(x))
    queries = queries.view(batch, n_query, mla.n_heads, mla.qk_head_dim)

    logits = torch.einsum("bshd,thd->bsht", queries, kv_full_k)
    logits = logits * mla.softmax_scale

    if n_query > 1:
        blocked = torch.full(
            (n_query, n_cached), float('-inf'), device=x.device
        )
        blocked = torch.triu(blocked, diagonal=1)
        # Broadcast the (seq, cache) mask over the batch and head axes.
        logits = logits + blocked[None, :, None, :]

    weights = torch.softmax(logits, dim=-1, dtype=torch.float32)

    mixed = torch.einsum("bsht,thd->bshd", weights, kv_full_v)
    merged_heads = torch.flatten(mixed, start_dim=2)
    return mla.wo(merged_heads)
|
|
|
|
def test_mla_gradient_flow():
    """Backpropagate a summed output and check the input receives gradient."""
    attn = _default_mla()
    hidden = torch.randn(1, 4, 256, requires_grad=True)
    latent_cache = torch.randn(8, 16)
    rope_cache = torch.randn(8, 8)

    # Sum the forward output into a scalar so backward() can be called.
    attn(hidden, latent_cache, rope_cache).sum().backward()

    grad = hidden.grad
    assert grad is not None, "input grad is None"
    assert grad.abs().sum().item() > 0, "input grad is zero"
    print(" PASS test_mla_gradient_flow")
|
|
|
|
def test_mla_causal_mask():
    """Forward 8 query positions against 12 cache entries with ``mask=None``.

    Two independently constructed modules are run on the same inputs: the
    first result is checked for shape, the second for finiteness.

    NOTE(review): despite the name, nothing here asserts that masking
    actually hides any cache position -- only shape and finiteness are
    verified.
    """
    first = _default_mla()
    hidden = torch.randn(1, 8, 256)
    latent_cache = torch.randn(12, 16)
    rope_cache = torch.randn(12, 8)

    first_out = first(hidden, latent_cache, rope_cache, mask=None)
    assert first_out.shape == (1, 8, 256)

    second = _default_mla()
    second_out = second(hidden, latent_cache, rope_cache, mask=None)
    assert torch.isfinite(second_out).all()
    print(" PASS test_mla_causal_mask")
|
|
|
|
def test_apply_rotary_emb():
    """Check that apply_rotary_emb keeps the shape and changes the values."""
    x = torch.randn(1, 4, 2, 8)

    # Unit-magnitude complex factors whose angles sweep 0..pi over a 4x4 grid.
    magnitudes = torch.ones(4, 4)
    angles = torch.linspace(0, math.pi, 4 * 4).reshape(4, 4)
    freqs_cis = torch.polar(magnitudes, angles)

    rotated = apply_rotary_emb(x, freqs_cis)

    assert rotated.shape == (1, 4, 2, 8), f"shape {rotated.shape}"
    assert not torch.allclose(rotated, x), "rotation did nothing"
    print(" PASS test_apply_rotary_emb")
|
|
|
|
def test_precompute_freqs_cis():
    """Check shape, complex dtype and non-zero imaginary part of the table."""
    table = precompute_freqs_cis(dim=32, end=100)

    # Expected shape is (end, dim // 2) for dim=32, end=100.
    assert table.shape == (100, 16), f"shape {table.shape}"
    assert torch.is_complex(table)
    imag_total = table.imag.abs().sum().item()
    assert imag_total > 0, "imag part is zero"
    print(" PASS test_precompute_freqs_cis")
|
|
|
|
def test_context_scheduler():
    """Run the scheduler with a ledger holding 20 appended entries."""
    scheduler = ContextAttentionScheduler(dim=256)
    hidden = torch.randn(1, 4, 256)

    ledger = KVLedger(32)
    for entry in range(20):
        ledger.append(entry)

    result = scheduler(hidden, ledger)

    # The output must keep the input's shape and contain only finite values.
    assert result.shape == (1, 4, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler")
|
|
|
|
def test_context_scheduler_empty_ledger():
    """Run the scheduler with a ledger that has had nothing appended."""
    scheduler = ContextAttentionScheduler(dim=256)
    hidden = torch.randn(1, 4, 256)
    empty_ledger = KVLedger(32)

    result = scheduler(hidden, empty_ledger)

    assert result.shape == (1, 4, 256)
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler_empty_ledger")
|
|
|
|
def test_context_scheduler_gate():
    """Check the sigmoid of ``scheduler.gate`` on the sequence-mean input.

    The gate is evaluated on ``x`` averaged over dim 1 (keepdim) and must
    have shape (1, 1, 1) with a value strictly between 0 and 1.
    """
    scheduler = ContextAttentionScheduler(dim=256)
    x = torch.randn(1, 4, 256)
    ledger = KVLedger(32)
    for entry in range(20):
        ledger.append(entry)

    # The forward result is not inspected; the call is kept so the test
    # still runs a full forward pass before probing the gate directly.
    scheduler(x, ledger)

    pooled = x.mean(dim=1, keepdim=True)
    gate_val = torch.sigmoid(scheduler.gate(pooled))
    assert gate_val.shape == (1, 1, 1)
    assert 0 < gate_val.item() < 1
    print(" PASS test_context_scheduler_gate")
|
|
|
|
def test_context_scheduler_hca_shape_mismatch_regression():
    """Run the scheduler with ``KVLedger(256)`` after 160 appended entries.

    NOTE(review): the name points at an earlier shape mismatch in an "HCA"
    path; the exact triggering condition is not visible from this test --
    confirm against the scheduler implementation.
    """
    scheduler = ContextAttentionScheduler(dim=256)
    hidden = torch.randn(1, 4, 256)

    ledger = KVLedger(256)
    for entry in range(160):
        ledger.append(entry)

    result = scheduler(hidden, ledger)

    assert result.shape == (1, 4, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler_hca_shape_mismatch_regression")
|
|
|
|
if __name__ == "__main__":
    # Run every test in declaration order; the first failing assert aborts
    # the script before the summary line is printed.
    for run_test in (
        test_mla_construct,
        test_mla_shape,
        test_mla_absorb_vs_naive,
        test_mla_gradient_flow,
        test_mla_causal_mask,
        test_apply_rotary_emb,
        test_precompute_freqs_cis,
        test_context_scheduler,
        test_context_scheduler_empty_ledger,
        test_context_scheduler_gate,
        test_context_scheduler_hca_shape_mismatch_regression,
    ):
        run_test()
    print("\nAll MLA + scheduler tests PASS")
|
|