# Source: ARBS / arbitor / attention / mla.py
# Author: CLIWorks — "Upload folder using huggingface_hub" (commit d8bc908, verified)
"""Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style).
Ternary-weighted. KV cache stores compressed latent at multiple levels:
- Base: MLA latent (d=kv_lora_rank, typically 64/32)
- CSA: Secondary compression (d_csa, e.g. 16) — 4x compression on cache
- HCA: Heavily compressed (d_hca, e.g. 8) — 8x compression, wider stride
Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
# Default hyperparameters for this module: the first six are used as default
# arguments of MultiHeadLatentAttention.__init__ / precompute_freqs_cis below.
MLA_N_HEADS = 32  # number of attention heads
MLA_QK_NOPE_HEAD_DIM = 96  # per-head query/key channels that do NOT get rotary embedding
MLA_QK_ROPE_HEAD_DIM = 32  # per-head query/key channels that get rotary embedding
MLA_V_HEAD_DIM = 96  # per-head value channels produced from the latent via wkv_b
MLA_ROPE_THETA = 10000.0  # RoPE base (default `theta` of precompute_freqs_cis)
MLA_SLIDE_DIM = 64  # default kv_lora_rank, i.e. width of the latent KV cache
MLA_FULL_DIM = 32  # NOTE(review): not referenced in this file; presumably the alternative 32-wide latent ("64/32" in the module docstring) — confirm against callers
def apply_rotary_emb(x, freqs_cis):
    """Rotate adjacent channel pairs of ``x`` by the complex phases in ``freqs_cis``.

    The last axis of ``x`` is read as interleaved (real, imag) pairs, which are
    multiplied by ``freqs_cis`` broadcast as ``freqs_cis[None, :, None]``.
    The first axis of ``freqs_cis`` therefore lines up with axis 1 of ``x``,
    and the table is shared over axes 0 and 2.

    Args:
        x: real tensor with an even last dimension; as used in this file it is
            ``(batch, seq, heads, rope_dim)``.
        freqs_cis: complex tensor, assumed ``(seq, rope_dim // 2)``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Rotation is done in float32 complex arithmetic regardless of the input dtype.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs_cis[None, :, None]
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)
def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
    """Build a rotary-embedding table of unit-magnitude complex numbers.

    Entry ``[t, k]`` is ``exp(1j * t * theta ** (-2k / dim))`` for positions
    ``t in range(end)`` and frequency index ``k in range(dim // 2)``.

    Args:
        dim: number of real channels to rotate (pairs -> ``dim // 2`` columns).
        end: number of positions to precompute.
        theta: RoPE base.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` on the default (CPU) device.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / theta ** exponents
    positions = torch.arange(end, device=inv_freq.device)
    # Broadcast product == outer product of positions and inverse frequencies.
    angles = positions[:, None] * inv_freq[None, :]
    return torch.polar(angles.new_ones(angles.shape), angles)
class MultiHeadLatentAttention(nn.Module):
    """Ternary-weighted multi-head latent attention over caller-supplied caches.

    Only the query path is computed here: ``x`` is normalised and projected by
    ``wq``. Keys/values are read from caches passed in by the caller (this
    module never writes them):

    * ``kv_cache`` - latent KV rows of width ``kv_lora_rank``.
    * ``pe_cache`` - rotary key parts of width ``qk_rope_head_dim``.
    * ``csa_cache`` (optional) - rows of width ``csa_dim``. When given and CSA
      is enabled, the decompressed rows replace ``kv_cache`` in base attention.
    * ``hca_cache`` (optional) - rows of width ``hca_dim``. These are attended
      separately and the result is added to the base attention output. Row
      ``t`` is treated as position ``t * MLA_HCA_STRIDE``, matching the default
      rotary keys ``pe_cache[::MLA_HCA_STRIDE]``.

    The caches carry no batch axis and are broadcast across the batch.
    ``wkv_b`` is absorbed into the query for scoring. It is also applied to the
    latent attention output to produce per-head values, so scores are computed
    directly against the latent cache.
    """

    def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                 qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                 v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
                 csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
                 tscale_type=TScaleType.T32):
        """Create the projections.

        Args:
            dim: model width of ``x``.
            n_heads: number of attention heads.
            kv_lora_rank: width of the latent KV cache.
            qk_nope_head_dim: per-head query channels without rotary embedding.
            qk_rope_head_dim: per-head query channels with rotary embedding.
            v_head_dim: per-head value channels.
            max_seq_len: stored on the module; not used by this class itself.
            csa_dim: CSA width. CSA layers are built only when it is truthy and
                smaller than ``kv_lora_rank``.
            hca_dim: HCA width. HCA layers are built only when it is truthy and
                smaller than the active CSA width (or ``kv_lora_rank`` when CSA
                is disabled).
            tscale_type: forwarded to every ternary layer.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.max_seq_len = max_seq_len
        self.csa_dim = csa_dim
        self.hca_dim = hca_dim
        self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
        self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)
        # wkv_b maps the latent to [k_nope | v] for every head.
        combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
        self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
        self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)
        # CSA: secondary compression (kv_lora_rank -> csa_dim)
        if csa_dim and csa_dim < kv_lora_rank:
            self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
            self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.csa_compress = None
            self.csa_decompress = None
        # HCA: heavily compressed (kv_lora_rank -> hca_dim). Compare against the
        # *active* CSA width only; when CSA is disabled because csa_dim >=
        # kv_lora_rank, comparing with csa_dim could build an expanding layer.
        hca_limit = csa_dim if self.csa_compress is not None else kv_lora_rank
        if hca_dim and hca_dim < hca_limit:
            self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
            self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.hca_compress = None
            self.hca_decompress = None

    def _compress(self, kv_cache, compress_proj):
        """Compress kv_cache from kv_lora_rank to a smaller width via ``compress_proj``."""
        return compress_proj(kv_cache)

    def _decompress(self, cache, decompress_proj):
        """Decompress ``cache`` back to kv_lora_rank via ``decompress_proj``."""
        return decompress_proj(cache)

    def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                        start_pos, seqlen, mask, key_stride=1):
        """Return pre-softmax scores of shape (bsz, seqlen, n_heads, n_keys).

        Args:
            q_nope_absorbed: (bsz, seqlen, n_heads, kv_lora_rank) queries with
                the key half of ``wkv_b`` absorbed.
            q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim) rotary query parts.
            kv_flat: (T, kv_lora_rank) latent keys, broadcast over the batch.
            pe_flat: (T', qk_rope_head_dim) rotary keys. Both key tensors are
                trimmed to ``n_keys = min(T, T')`` rows.
            start_pos: absolute position of the first query.
            seqlen: number of queries.
            mask: optional additive mask. A 2-D ``(seqlen, n_keys)`` mask is
                broadcast over batch and heads. Other ranks are added as-is and
                must broadcast to the score shape. When None and ``seqlen > 1``,
                a causal mask is applied instead.
            key_stride: absolute-position spacing of the keys. Key ``t`` is
                treated as position ``t * key_stride`` by the causal mask.
        """
        n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
        if n_keys == 0:
            return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)
        kv_flat = kv_flat[:n_keys]
        pe_flat = pe_flat[:n_keys]
        scores = (
            torch.einsum("bshc,btc->bsht", q_nope_absorbed, kv_flat.unsqueeze(0))
            + torch.einsum("bshr,btr->bsht", q_pe, pe_flat.unsqueeze(0))
        ) * self.softmax_scale
        if mask is not None:
            if mask.dim() == 2:
                # (seqlen, n_keys) -> (seqlen, 1, n_keys): aligns with the
                # (bsz, seqlen, n_heads, n_keys) score layout.
                mask = mask.unsqueeze(1)
            scores = scores + mask
        elif seqlen > 1:
            # Query i sits at start_pos + i; key t sits at t * key_stride.
            # Hide keys strictly in the future of each query.
            device = scores.device
            query_pos = torch.arange(start_pos, start_pos + seqlen, device=device)
            key_pos = torch.arange(n_keys, device=device) * key_stride
            future = key_pos.unsqueeze(0) > query_pos.unsqueeze(1)  # (seqlen, n_keys)
            scores = scores.masked_fill(future[None, :, None, :], float("-inf"))
        return scores

    def _attend(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                start_pos, seqlen, mask, key_stride=1):
        """Softmax-attend over the latent keys.

        Returns the latent output (bsz, seqlen, n_heads, kv_lora_rank).
        """
        n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
        kv_flat = kv_flat[:n_keys]
        pe_flat = pe_flat[:n_keys]
        scores = self._compute_scores(
            q_nope_absorbed, q_pe, kv_flat, pe_flat,
            start_pos, seqlen, mask, key_stride=key_stride,
        )
        # Softmax in float32 for stability, then match the cache dtype so the
        # einsum below does not see mixed dtypes (e.g. float32 vs bfloat16).
        probs = scores.softmax(dim=-1, dtype=torch.float32).to(kv_flat.dtype)
        return torch.einsum("bsht,btc->bshc", probs, kv_flat.unsqueeze(0))

    def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
                csa_cache=None, hca_cache=None, hca_pe_cache=None):
        """Attend from ``x`` to the supplied caches.

        Args:
            x: (bsz, seqlen, dim) input activations.
            kv_cache: (T, kv_lora_rank) latent keys/values, shared by the batch.
            pe_cache: (T, qk_rope_head_dim) rotary key parts. Only the first
                ``min(len(kv_cache), len(pe_cache))`` rows are used for base
                attention.
            start_pos: absolute position of the first query. It offsets the
                rotary slice and the causal mask.
            freqs_cis: complex RoPE table indexed by position (for example from
                ``precompute_freqs_cis``). When None, queries are not rotated.
            mask: optional additive mask for base attention. See
                ``_compute_scores``. When None, base attention is causal for
                ``seqlen > 1``.
            csa_cache: optional (T', csa_dim) compressed cache that replaces
                ``kv_cache`` when CSA is enabled.
            hca_cache: optional (T'', hca_dim) heavily compressed cache. Its rows
                are treated as positions ``t * MLA_HCA_STRIDE``.
            hca_pe_cache: optional rotary keys for ``hca_cache``. Defaults to
                ``pe_cache[::MLA_HCA_STRIDE]``.

        Returns:
            ``wo`` applied to the flattened (bsz, seqlen, n_heads * v_head_dim)
            attention result.
        """
        bsz, seqlen, _ = x.size()
        q = self.wq(self.wq_norm(x)).view(bsz, seqlen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        if freqs_cis is not None:
            q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])

        # Dense wkv_b, viewed per head as [k_nope rows | v rows] x kv_lora_rank.
        wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S()
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)
        # Absorb the key projection into the query so scores use the latent cache.
        q_nope_absorbed = torch.einsum(
            "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])

        n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
        kv_flat = kv_cache[:n_cache]
        pe_flat = pe_cache[:n_cache]
        # CSA cache, when provided and enabled, replaces the base latent cache.
        if csa_cache is not None and self.csa_decompress is not None:
            n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
            kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
            pe_flat = pe_flat[:n_csa]

        attn_out = self._attend(
            q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask)

        # HCA long-range attention (heavily compressed, strided). It always uses
        # its own stride-aware causal mask and ignores the caller's ``mask``.
        if hca_cache is not None and self.hca_decompress is not None:
            hca_kv = self._decompress(hca_cache, self.hca_decompress)
            hca_pe = pe_cache[::MLA_HCA_STRIDE] if hca_pe_cache is None else hca_pe_cache
            # NOTE(review): if an HCA row summarises a whole stride window rather
            # than position t * stride, prefill can still leak within a window.
            attn_out = attn_out + self._attend(
                q_nope_absorbed, q_pe, hca_kv, hca_pe, start_pos, seqlen,
                mask=None, key_stride=MLA_HCA_STRIDE,
            )

        # Expand the latent output to per-head values with the v half of wkv_b.
        attn_unproj = torch.einsum(
            "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])
        return self.wo(attn_unproj.flatten(2))