"""Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style). Ternary-weighted. KV cache stores compressed latent at multiple levels: - Base: MLA latent (d=kv_lora_rank, typically 64/32) - CSA: Secondary compression (d_csa, e.g. 16) — 4x compression on cache - HCA: Heavily compressed (d_hca, e.g. 8) — 8x compression, wider stride Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache """ import torch import torch.nn as nn import torch.nn.functional as F from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType MLA_N_HEADS = 32 MLA_QK_NOPE_HEAD_DIM = 96 MLA_QK_ROPE_HEAD_DIM = 32 MLA_V_HEAD_DIM = 96 MLA_ROPE_THETA = 10000.0 MLA_SLIDE_DIM = 64 MLA_FULL_DIM = 32 def apply_rotary_emb(x, freqs_cis): x_complex = torch.view_as_complex( x.float().reshape(*x.shape[:-1], -1, 2) ) freqs = freqs_cis.unsqueeze(1).unsqueeze(0) return torch.view_as_real(x_complex * freqs).flatten(-2).to(x.dtype) def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs) return torch.polar(torch.ones_like(freqs), freqs) class MultiHeadLatentAttention(nn.Module): def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM, qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM, v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536, csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM, tscale_type=TScaleType.T32): super().__init__() self.dim = dim self.n_heads = n_heads self.kv_lora_rank = kv_lora_rank self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.v_head_dim = v_head_dim self.softmax_scale = self.qk_head_dim ** -0.5 self.max_seq_len = max_seq_len self.csa_dim = csa_dim self.hca_dim = hca_dim self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type) self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type) combined_out = n_heads * (qk_nope_head_dim + v_head_dim) self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type) self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type) # CSA: secondary compression (kv_lora_rank -> csa_dim) if csa_dim and csa_dim < kv_lora_rank: self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type) self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type) else: self.csa_compress = None self.csa_decompress = None # HCA: heavily compressed (kv_lora_rank -> hca_dim) if hca_dim and hca_dim < (csa_dim or kv_lora_rank): self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type) self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type) else: self.hca_compress = None self.hca_decompress = None def _compress(self, kv_cache, compress_proj): """Compress kv_cache from kv_lora_rank to smaller dim.""" return compress_proj(kv_cache) def _decompress(self, cache, decompress_proj): """Decompress cache back to kv_lora_rank.""" return decompress_proj(cache) def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask): """Shared score computation for base, CSA, and HCA attention.""" n_keys = min(kv_flat.shape[0], pe_flat.shape[0]) kv_flat = kv_flat[:n_keys] pe_flat = pe_flat[:n_keys] if n_keys == 0: return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0) scores = ( torch.einsum("bshc,btc->bsht", q_nope_absorbed, kv_flat.unsqueeze(0)) + torch.einsum("bshr,btr->bsht", q_pe, pe_flat.unsqueeze(0)) ) * self.softmax_scale if mask is not None: scores = scores + mask.unsqueeze(0).unsqueeze(0) if mask is None and seqlen > 1: causal = torch.triu( torch.full((seqlen, n_keys), float('-inf'), device=q_pe.device), diagonal=1 + start_pos ) scores = scores + causal.unsqueeze(0).unsqueeze(2) return scores def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None, csa_cache=None, hca_cache=None, hca_pe_cache=None): bsz, seqlen, _ = x.size() q = self.wq(self.wq_norm(x)) q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) if freqs_cis is not None: q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen]) wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S() wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank) q_nope_absorbed = torch.einsum( "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) n_cache = min(kv_cache.shape[0], pe_cache.shape[0]) kv_flat = kv_cache[:n_cache] pe_flat = pe_cache[:n_cache] # Decompress CSA cache if provided (replaces base kv_cache) if csa_cache is not None and self.csa_decompress is not None: n_csa = min(csa_cache.shape[0], pe_flat.shape[0]) kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress) pe_flat = pe_flat[:n_csa] # Base attention (exact, CSA-compressed if applicable) scores = self._compute_scores( q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask, ) scores = scores.softmax(dim=-1, dtype=torch.float32) attn_out = torch.einsum( "bsht,btc->bshc", scores, kv_flat.unsqueeze(0)) # HCA long-range attention (heavily compressed, strided) hca_out = None if hca_cache is not None and self.hca_decompress is not None: hca_kv = self._decompress(hca_cache, self.hca_decompress) if hca_pe_cache is None: hca_pe = pe_cache[::MLA_HCA_STRIDE] else: hca_pe = hca_pe_cache n_hca = min(hca_kv.shape[0], hca_pe.shape[0]) hca_kv = hca_kv[:n_hca] hca_pe = hca_pe[:n_hca] hca_scores = self._compute_scores( q_nope_absorbed, q_pe, hca_kv, hca_pe, start_pos, seqlen, mask=None, ) hca_scores = hca_scores.softmax(dim=-1, dtype=torch.float32) hca_out = torch.einsum( "bsht,btc->bshc", hca_scores, hca_kv.unsqueeze(0)) attn_out = attn_out + hca_out attn_unproj = torch.einsum( "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:]) return self.wo(attn_unproj.flatten(2))