# ARBS / testing/attention/test_mla.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908, verified).
"""Unit tests for MultiHeadLatentAttention, ContextAttentionScheduler, RoPE utilities."""
import math
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.attention.mla import (
MultiHeadLatentAttention, apply_rotary_emb, precompute_freqs_cis,
)
from arbitor.attention.context_attention import ContextAttentionScheduler
from arbitor.attention.kv_ledger import KVLedger
def _default_mla():
    """Build the small MultiHeadLatentAttention instance shared by several tests.

    Configuration: dim=256, 4 heads, kv_lora_rank=16, query/key head split of
    24 (no-RoPE) + 8 (RoPE), and 24-wide value heads.
    """
    config = dict(
        dim=256,
        n_heads=4,
        kv_lora_rank=16,
        qk_nope_head_dim=24,
        qk_rope_head_dim=8,
        v_head_dim=24,
    )
    return MultiHeadLatentAttention(**config)
def test_mla_construct():
    """The constructor exposes its hyper-parameters; qk_head_dim is 24 + 8 = 32."""
    model = _default_mla()
    expected = {
        "dim": 256,
        "n_heads": 4,
        "kv_lora_rank": 16,
        "qk_head_dim": 32,
    }
    for attr_name, wanted in expected.items():
        assert getattr(model, attr_name) == wanted
    print(" PASS test_mla_construct")
def test_mla_shape():
    """A forward pass over 4 tokens and an 8-slot cache returns finite (1, 4, 256) output."""
    model = _default_mla()
    batch, seqlen, cache_len = 1, 4, 8
    x = torch.randn(batch, seqlen, 256)
    kv_cache = torch.randn(cache_len, 16)
    pe_cache = torch.randn(cache_len, 8)
    result = model(x, kv_cache, pe_cache)
    assert result.shape == (batch, seqlen, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_mla_shape")
def _get_wkv_b_eff(mla, n_heads, kv_lora_rank):
    """Get effective wkv_b weight from TernaryScaleTensor.

    Multiplies the layer's T and S factors (``_get_T()`` then ``_get_S()``)
    and reshapes the product to ``(n_heads, rows_per_head, kv_lora_rank)``.
    """
    layer = mla.wkv_b
    effective = layer._get_T() * layer._get_S()
    return effective.view(n_heads, -1, kv_lora_rank)
def test_mla_absorb_vs_naive():
    """The module's forward pass must match an explicit decompress-then-attend reference.

    For several seeds, the latent ``kv_cache`` is expanded into per-head keys
    (no-RoPE part + shared ``pe_cache``) and values via the effective
    ``wkv_b`` weight. These are then fed to ``_naive_attention``.
    """
    dim, n_heads, kv_lora_rank = 128, 2, 8
    qk_nope, qk_rope, v_dim = 16, 8, 16
    for seed in (42, 123, 256):
        torch.manual_seed(seed)
        mla = MultiHeadLatentAttention(
            dim=dim, n_heads=n_heads, kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope, qk_rope_head_dim=qk_rope, v_head_dim=v_dim,
        )
        x = torch.randn(1, 4, dim)
        kv_cache = torch.randn(8, kv_lora_rank)
        pe_cache = torch.randn(8, qk_rope)
        absorb_out = mla(x, kv_cache, pe_cache)

        # Rows [:qk_nope] of each head project to keys, the last v_dim rows to values.
        wkv_b = _get_wkv_b_eff(mla, n_heads, kv_lora_rank)
        w_key, w_val = wkv_b[:, :qk_nope], wkv_b[:, -v_dim:]
        key_nope = torch.einsum("hdc,tc->thd", w_key, kv_cache)
        key_rope = pe_cache.unsqueeze(1).expand(-1, n_heads, -1)  # same RoPE key for every head
        keys = torch.cat([key_nope, key_rope], dim=-1)
        values = torch.einsum("hdc,tc->thd", w_val, kv_cache)

        reference = _naive_attention(mla, x, keys, values, pe_cache)
        diff = (absorb_out - reference).abs().max().item()
        assert diff < 1e-4, f"seed={seed} diff={diff}"
    print(" PASS test_mla_absorb_vs_naive")
def _naive_attention(mla, x, kv_full_k, kv_full_v, pe_cache):
    """Reference (non-absorbed) attention used to check the module's forward pass.

    ``kv_full_k`` and ``kv_full_v`` are explicit per-head keys/values indexed
    as ``(kv_len, n_heads, head_dim)``. ``pe_cache`` is accepted for call-site
    symmetry but is not read: the caller in this file concatenates the RoPE
    part into ``kv_full_k`` before calling. Query position ``s`` may attend
    only to cache slots ``t <= s`` (slots counted from index 0).
    """
    batch, seqlen, _ = x.shape
    queries = mla.wq(mla.wq_norm(x)).view(batch, seqlen, mla.n_heads, mla.qk_head_dim)
    logits = mla.softmax_scale * torch.einsum("bshd,thd->bsht", queries, kv_full_k)
    if seqlen > 1:
        blocked = torch.full((seqlen, kv_full_k.shape[0]), float('-inf'), device=x.device)
        # -inf strictly above the diagonal, 0 elsewhere; broadcast over batch and heads.
        logits = logits + torch.triu(blocked, diagonal=1)[None, :, None, :]
    weights = logits.softmax(dim=-1, dtype=torch.float32)
    mixed = torch.einsum("bsht,thd->bshd", weights, kv_full_v).flatten(2)
    return mla.wo(mixed)
def test_mla_gradient_flow():
    """Backpropagating a scalar loss through the module yields a non-zero input gradient."""
    model = _default_mla()
    inputs = torch.randn(1, 4, 256, requires_grad=True)
    kv_cache, pe_cache = torch.randn(8, 16), torch.randn(8, 8)
    model(inputs, kv_cache, pe_cache).sum().backward()
    grad = inputs.grad
    assert grad is not None, "input grad is None"
    assert grad.abs().sum().item() > 0, "input grad is zero"
    print(" PASS test_mla_gradient_flow")
def test_mla_causal_mask():
    """Check that the default (mask=None) forward pass is causal over cache slots.

    This mirrors the mask in ``_naive_attention``, the reference that
    ``test_mla_absorb_vs_naive`` compares the module against. There, query
    position ``s`` may only attend to cache slots ``t <= s``. Therefore:
      * cache slots at index >= seqlen are invisible to every query, and
      * cache slots at index >= split are invisible to queries < split.
    The test replaces those slots with fresh noise and checks that the
    affected outputs do not move.
    """
    mla = _default_mla()
    seqlen, cache_len, split = 8, 12, 4
    x = torch.randn(1, seqlen, 256)
    kv_cache = torch.randn(cache_len, 16)
    pe_cache = torch.randn(cache_len, 8)

    # Copies where only slots no query can see are replaced.
    kv_tail, pe_tail = kv_cache.clone(), pe_cache.clone()
    kv_tail[seqlen:] = torch.randn(cache_len - seqlen, 16)
    pe_tail[seqlen:] = torch.randn(cache_len - seqlen, 8)

    # Copies where slots from `split` onward change; only queries >= split may see them.
    kv_split, pe_split = kv_cache.clone(), pe_cache.clone()
    kv_split[split:] = torch.randn(cache_len - split, 16)
    pe_split[split:] = torch.randn(cache_len - split, 8)

    out = mla(x, kv_cache, pe_cache, mask=None)
    assert out.shape == (1, seqlen, 256), f"shape {out.shape}"
    assert torch.isfinite(out).all()

    out_tail = mla(x, kv_tail, pe_tail, mask=None)
    assert torch.allclose(out, out_tail, atol=1e-6), \
        "cache slots beyond the last query leaked into the output"

    out_split = mla(x, kv_split, pe_split, mask=None)
    assert torch.allclose(out[:, :split], out_split[:, :split], atol=1e-6), \
        "future cache slots leaked into earlier query positions"
    print(" PASS test_mla_causal_mask")
def test_apply_rotary_emb():
    """apply_rotary_emb keeps the shape, changes the values and preserves vector norms.

    ``freqs_cis`` is built with ``torch.polar`` from unit magnitudes. Assuming
    apply_rotary_emb multiplies feature pairs by these complex factors (standard
    RoPE), each factor is a pure rotation, so the Euclidean norm of every
    head vector must be unchanged. Checking only "something changed" would
    also accept a buggy implementation that scales values.
    """
    x = torch.randn(1, 4, 2, 8)
    freqs_cis = torch.polar(
        torch.ones(4, 4),
        torch.linspace(0, math.pi, 4 * 4).reshape(4, 4),
    )
    out = apply_rotary_emb(x, freqs_cis)
    assert out.shape == (1, 4, 2, 8), f"shape {out.shape}"
    assert not torch.allclose(out, x), "rotation did nothing"
    assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5), \
        "rotation changed vector norms"
    print(" PASS test_apply_rotary_emb")
def test_precompute_freqs_cis():
    """precompute_freqs_cis(dim=32, end=100) yields a (100, 16) complex table with non-zero phases."""
    table = precompute_freqs_cis(dim=32, end=100)
    assert table.shape == (100, 16), f"shape {table.shape}"
    assert torch.is_complex(table)
    assert table.imag.abs().sum().item() > 0, "imag part is zero"
    print(" PASS test_precompute_freqs_cis")
def test_context_scheduler():
    """With 20 entries appended to KVLedger(32), the scheduler returns finite, shape-preserving output."""
    scheduler = ContextAttentionScheduler(dim=256)
    x = torch.randn(1, 4, 256)
    ledger = KVLedger(32)
    for entry in range(20):
        ledger.append(entry)
    result = scheduler(x, ledger)
    assert result.shape == (1, 4, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler")
def test_context_scheduler_empty_ledger():
    """A freshly constructed (empty) ledger still yields finite output of the input's shape."""
    scheduler = ContextAttentionScheduler(dim=256)
    x = torch.randn(1, 4, 256)
    result = scheduler(x, KVLedger(32))
    assert result.shape == (1, 4, 256)
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler_empty_ledger")
def test_context_scheduler_gate():
    """The scheduler's ``gate`` head yields one probability per batch element.

    The full forward pass is also checked for shape and finiteness, so the
    result of ``scheduler(x, ledger)`` is actually asserted on.
    """
    scheduler = ContextAttentionScheduler(dim=256)
    x = torch.randn(1, 4, 256)
    ledger = KVLedger(32)
    for i in range(20):
        ledger.append(i)
    out = scheduler(x, ledger)
    assert out.shape == (1, 4, 256), f"shape {out.shape}"
    assert torch.isfinite(out).all()
    # NOTE(review): sequence-mean pooling presumably mirrors how the scheduler
    # feeds its gate internally; confirm against ContextAttentionScheduler.
    gate_val = torch.sigmoid(scheduler.gate(x.mean(dim=1, keepdim=True)))
    assert gate_val.shape == (1, 1, 1)
    assert 0 < gate_val.item() < 1
    print(" PASS test_context_scheduler_gate")
def test_context_scheduler_hca_shape_mismatch_regression():
    """Regression: 160 entries in KVLedger(256) must not trigger a shape mismatch in the scheduler."""
    scheduler = ContextAttentionScheduler(dim=256)
    x = torch.randn(1, 4, 256)
    ledger = KVLedger(256)
    for entry in range(160):
        ledger.append(entry)
    result = scheduler(x, ledger)
    assert result.shape == (1, 4, 256), f"shape {result.shape}"
    assert torch.isfinite(result).all()
    print(" PASS test_context_scheduler_hca_shape_mismatch_regression")
if __name__ == "__main__":
    # Run every test in definition order when executed as a script.
    _TESTS = (
        test_mla_construct,
        test_mla_shape,
        test_mla_absorb_vs_naive,
        test_mla_gradient_flow,
        test_mla_causal_mask,
        test_apply_rotary_emb,
        test_precompute_freqs_cis,
        test_context_scheduler,
        test_context_scheduler_empty_ledger,
        test_context_scheduler_gate,
        test_context_scheduler_hca_shape_mismatch_regression,
    )
    for _run_test in _TESTS:
        _run_test()
    print("\nAll MLA + scheduler tests PASS")