| """Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style). |
| |
| Ternary-weighted. KV cache stores compressed latent at multiple levels: |
| - Base: MLA latent (d=kv_lora_rank, typically 64/32) |
| - CSA: Secondary compression (d_csa, e.g. 16) — 4x compression on cache |
| - HCA: Heavily compressed (d_hca, e.g. 8) — 8x compression, wider stride |
| |
| Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS |
| from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType |
|
|
# Module-level defaults for the attention layer and rotary helpers in this file.
# Default `n_heads` of MultiHeadLatentAttention.
| MLA_N_HEADS = 32 |
# Default `qk_nope_head_dim`: per-head query width that carries no rotary
# embedding (this slice of the query is projected through `wkv_b`).
| MLA_QK_NOPE_HEAD_DIM = 96 |
# Default `qk_rope_head_dim`: per-head query width that receives the rotary
# embedding and is scored against `pe_cache`.
| MLA_QK_ROPE_HEAD_DIM = 32 |
# Default `v_head_dim`: per-head width of the value rows taken from `wkv_b`.
| MLA_V_HEAD_DIM = 96 |
# Default `theta` base used by precompute_freqs_cis.
| MLA_ROPE_THETA = 10000.0 |
# Default `kv_lora_rank` (width of the latent KV cache) of MultiHeadLatentAttention.
| MLA_SLIDE_DIM = 64 |
# NOTE(review): not referenced anywhere in the visible code; presumably an
# alternative latent width for other layers — confirm against callers.
| MLA_FULL_DIM = 32 |
|
|
|
|
def apply_rotary_emb(x, freqs_cis):
    """Rotate consecutive feature pairs of ``x`` by the phases in ``freqs_cis``.

    The last dimension of ``x`` is read as (real, imag) pairs, multiplied by
    ``freqs_cis`` in complex arithmetic, and flattened back to its original
    width.

    Args:
        x: Real tensor whose last dimension is even. ``freqs_cis`` gets a new
            leading axis and a new axis at index 2 before the multiply, so a
            4-D ``x`` is expected with ``freqs_cis`` aligned to its second and
            last axes (assumed (batch, seq, heads, rope_dim) — confirm with
            callers).
        freqs_cis: Complex tensor of shape (x.shape[1], x.shape[-1] // 2), or
            anything that broadcasts after the two ``unsqueeze`` calls.

    Returns:
        Tensor with the same shape and dtype as ``x``. The arithmetic itself
        runs in float32 / complex64.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    # torch.view_as_complex rejects inputs whose storage offset or outer
    # strides are odd, which is what a slice such as the rope half of a split
    # query can be. torch.complex has no layout requirement and yields the
    # same values.
    x_complex = torch.complex(pairs[..., 0], pairs[..., 1])
    freqs = freqs_cis.unsqueeze(1).unsqueeze(0)
    return torch.view_as_real(x_complex * freqs).flatten(-2).to(x.dtype)
|
|
|
|
def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
    """Build the table of unit-magnitude complex rotary phases.

    Args:
        dim: Rotary feature width; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape (end, dim // 2) whose entry [p, k] has
        magnitude 1 and angle ``p / theta ** (2k / dim)``.
    """
    # Even indices 0, 2, 4, ... normalised by dim give the per-pair exponent.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product: one row of angles per position.
    angles = torch.outer(positions, inv_freq)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
|
|
|
class MultiHeadLatentAttention(nn.Module):
    """Multi-head latent attention over caller-supplied latent KV caches.

    Keys and values are never expanded per head. The non-rotary part of each
    query is projected through the key rows of ``wkv_b`` into the latent
    space, attention is computed against the latent cache there, and the value
    rows of ``wkv_b`` map the attended latent back to per-head values.

    Optional secondary caches:
      * CSA: when ``csa_cache`` is passed to ``forward`` and a CSA projection
        exists, the decompressed CSA cache replaces ``kv_cache`` as the
        key/value source.
      * HCA: when ``hca_cache`` is passed and an HCA projection exists, a
        second attention over the decompressed HCA cache is added to the
        primary result.

    This module does not write to any cache; callers own cache construction.
    """

    def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                 qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                 v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
                 csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
                 tscale_type=TScaleType.T32):
        """Create the projections.

        Args:
            dim: Model width of the input and output.
            n_heads: Number of attention heads.
            kv_lora_rank: Width of the latent KV cache.
            qk_nope_head_dim: Per-head query width without rotary embedding.
            qk_rope_head_dim: Per-head query width with rotary embedding.
            v_head_dim: Per-head value width.
            max_seq_len: Stored on the instance; not otherwise used here.
            csa_dim: CSA width. CSA projections are created only when this is
                truthy and smaller than ``kv_lora_rank``.
            hca_dim: HCA width. HCA projections are created only when this is
                truthy and smaller than ``csa_dim`` (or ``kv_lora_rank`` when
                ``csa_dim`` is falsy).
            tscale_type: Forwarded to every ternary layer.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.max_seq_len = max_seq_len
        self.csa_dim = csa_dim
        self.hca_dim = hca_dim

        # Query path: normalise, then project to per-head (nope + rope) features.
        self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
        self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)

        # wkv_b holds, per head, the key rows (qk_nope_head_dim) followed by
        # the value rows (v_head_dim); forward() slices it accordingly.
        combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
        self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
        self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)

        # CSA: kv_lora_rank <-> csa_dim projections.
        if csa_dim and csa_dim < kv_lora_rank:
            self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
            self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.csa_compress = None
            self.csa_decompress = None

        # HCA: kv_lora_rank <-> hca_dim projections.
        if hca_dim and hca_dim < (csa_dim or kv_lora_rank):
            self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
            self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.hca_compress = None
            self.hca_decompress = None

    def _compress(self, kv_cache, compress_proj):
        """Compress kv_cache from kv_lora_rank to smaller dim."""
        return compress_proj(kv_cache)

    def _decompress(self, cache, decompress_proj):
        """Decompress cache back to kv_lora_rank."""
        return decompress_proj(cache)

    def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                        start_pos, seqlen, mask):
        """Shared score computation for base, CSA, and HCA attention.

        Args:
            q_nope_absorbed: (bsz, seqlen, n_heads, kv_lora_rank) queries
                already projected into the latent space.
            q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim) rotary queries.
            kv_flat: (n_keys, kv_lora_rank) latent keys.
            pe_flat: (n_keys, qk_rope_head_dim) rotary keys.
            start_pos: Offset of the first query row, used only for the
                built-in causal mask.
            seqlen: Number of query positions.
            mask: Optional additive mask. A 2-D mask whose first dimension
                equals ``seqlen`` is treated as (seqlen, n_keys) and broadcast
                over batch and heads. Any other mask gets two leading axes
                prepended, as before. When ``None`` and ``seqlen > 1`` a
                causal mask is built.

        Returns:
            Pre-softmax scores of shape (bsz, seqlen, n_heads, n_keys), where
            ``n_keys`` is the shorter of the two caches.
        """
        n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
        kv_flat = kv_flat[:n_keys]
        pe_flat = pe_flat[:n_keys]
        if n_keys == 0:
            return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)

        latent_term = torch.einsum("bshc,btc->bsht", q_nope_absorbed, kv_flat.unsqueeze(0))
        rope_term = torch.einsum("bshr,btr->bsht", q_pe, pe_flat.unsqueeze(0))
        scores = (latent_term + rope_term) * self.softmax_scale

        if mask is not None:
            if mask.dim() == 2 and mask.shape[0] == seqlen:
                # Scores are laid out (b, s, h, t): a (seqlen, n_keys) mask
                # must sit on axes 1 and 3, exactly like the causal mask
                # below. Prepending two axes would line the query axis up
                # with the head axis.
                bias = mask.unsqueeze(0).unsqueeze(2)
            else:
                bias = mask.unsqueeze(0).unsqueeze(0)
            scores = scores + bias
        elif seqlen > 1:
            # Query row i may see keys 0 .. start_pos + i.
            causal = torch.triu(
                torch.full((seqlen, n_keys), float('-inf'), device=q_pe.device),
                diagonal=1 + start_pos
            )
            scores = scores + causal.unsqueeze(0).unsqueeze(2)
        return scores

    def _attend(self, q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask):
        """Score, softmax, and take the weighted sum over one latent cache.

        ``kv_flat`` and ``pe_flat`` must already have the same length.
        Returns the attended latent, shape (bsz, seqlen, n_heads, kv_lora_rank).
        """
        scores = self._compute_scores(
            q_nope_absorbed, q_pe, kv_flat, pe_flat,
            start_pos, seqlen, mask,
        )
        # Softmax runs in float32 for stability; einsum needs both operands in
        # one dtype, so cast back to the cache dtype (no-op for float32).
        probs = scores.softmax(dim=-1, dtype=torch.float32).to(kv_flat.dtype)
        return torch.einsum("bsht,btc->bshc", probs, kv_flat.unsqueeze(0))

    def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
                csa_cache=None, hca_cache=None, hca_pe_cache=None):
        """Attend ``x`` over the supplied caches.

        Args:
            x: (bsz, seqlen, dim) input.
            kv_cache: (n_keys, kv_lora_rank) latent cache.
            pe_cache: (n_keys, qk_rope_head_dim) rotary key cache.
            start_pos: Position of the first query token; selects the
                ``freqs_cis`` rows and offsets the built-in causal mask.
            freqs_cis: Optional rotary table indexed by position; when
                ``None`` the rotary query slice is used unrotated.
            mask: Optional additive mask for the primary attention (see
                ``_compute_scores``). It is not applied to the HCA attention.
            csa_cache: Optional CSA cache; when given and CSA projections
                exist, its decompressed form replaces ``kv_cache``.
            hca_cache: Optional HCA cache; when given and HCA projections
                exist, a second attention over it is added to the result.
            hca_pe_cache: Optional rotary keys for HCA; defaults to every
                ``MLA_HCA_STRIDE``-th row of ``pe_cache``.

        Returns:
            ``wo`` applied to the per-head outputs flattened to
            (bsz, seqlen, n_heads * v_head_dim); presumably (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.size()

        q = self.wq(self.wq_norm(x))
        q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        if freqs_cis is not None:
            q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])

        # Dense wkv_b weight (presumably ternary values times scales — confirm
        # against TernaryScaleTensor), viewed as
        # (n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank).
        wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S()
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)

        # Absorb the key projection into the query so scores can be taken
        # directly against the latent cache.
        q_nope_absorbed = torch.einsum(
            "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])

        n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
        kv_flat = kv_cache[:n_cache]
        pe_flat = pe_cache[:n_cache]

        # CSA, when available, replaces the base latent as key/value source.
        if csa_cache is not None and self.csa_decompress is not None:
            n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
            kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
            pe_flat = pe_flat[:n_csa]

        attn_out = self._attend(
            q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask)

        if hca_cache is not None and self.hca_decompress is not None:
            hca_kv = self._decompress(hca_cache, self.hca_decompress)
            hca_pe = pe_cache[::MLA_HCA_STRIDE] if hca_pe_cache is None else hca_pe_cache
            n_hca = min(hca_kv.shape[0], hca_pe.shape[0])
            # NOTE(review): with seqlen > 1 the built-in causal mask is applied
            # to HCA keys by cache index, not by index * MLA_HCA_STRIDE. If
            # HCA entry j stands for token position j * stride, queries can
            # see entries ahead of them. Confirm how the HCA cache is built.
            attn_out = attn_out + self._attend(
                q_nope_absorbed, q_pe, hca_kv[:n_hca], hca_pe[:n_hca],
                start_pos, seqlen, None)

        # Map the attended latent back to per-head values with the value rows.
        attn_unproj = torch.einsum(
            "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])

        return self.wo(attn_unproj.flatten(2))
|
|