File size: 7,407 Bytes
"""Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style).
Projections use ternary-scaled weights. The KV cache stores the compressed latent at several levels:
- Base: MLA latent (d = kv_lora_rank, typically 64 or 32)
- CSA: secondary compression (d_csa, e.g. 16) — 4x compression of the cache
- HCA: heavy compression (d_hca, e.g. 8) — 8x compression, stored at a wider stride
Scores = (q_nope_absorbed @ decompress(kv_cache)^T + q_pe @ pe_cache^T) * softmax_scale
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
# Module-level MLA defaults; most are referenced as defaults by the functions
# and the attention class in this module.
MLA_N_HEADS: int = 32             # number of attention heads
MLA_QK_NOPE_HEAD_DIM: int = 96    # per-head query/key width without rotary embedding
MLA_QK_ROPE_HEAD_DIM: int = 32    # per-head query/key width that receives rotary embedding
MLA_V_HEAD_DIM: int = 96          # per-head value width
MLA_ROPE_THETA: float = 1e4       # rotary base frequency (same value as 10000.0)
MLA_SLIDE_DIM: int = 64           # default kv_lora_rank (base latent width)
MLA_FULL_DIM: int = 32            # not referenced in this module's visible code; may be imported elsewhere — confirm
def apply_rotary_emb(x, freqs_cis):
    """Rotate adjacent channel pairs of ``x`` by the complex factors in ``freqs_cis``.

    The last axis of ``x`` is read as interleaved (real, imag) pairs, multiplied
    by ``freqs_cis`` and flattened back to the original width. A leading axis and
    an axis at position 2 are inserted into ``freqs_cis`` before broadcasting.
    For example, a (seqlen, d // 2) table lines up with an ``x`` shaped
    (batch, seqlen, heads, d); that is the layout the attention forward pass
    appears to use — confirm for other callers.

    The arithmetic runs in float32 and the result is cast back to ``x.dtype``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)
    rotated = torch.view_as_complex(pairs) * rotation
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)
def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
    """Build the complex rotary table for positions ``0 .. end - 1``.

    Returns a complex tensor of shape (end, dim // 2). Entry [p, k] equals
    exp(i * p * theta ** (-2k / dim)). The table is allocated on torch's
    default device, because no device is passed to ``torch.arange``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
class MultiHeadLatentAttention(nn.Module):
    """Multi-head latent attention over a latent KV cache with optional CSA/HCA levels.

    Queries come from ``wq(wq_norm(x))`` and are split per head into two parts:
    a non-positional part (``qk_nope_head_dim``) and a rotary part
    (``qk_rope_head_dim``). The non-positional part is absorbed through the key
    slice of ``wkv_b`` into the ``kv_lora_rank`` latent space, so scores are
    computed directly against the latent cache. The attended latent is then
    expanded through the value slice of ``wkv_b`` and projected by ``wo``.

    Optional compression levels:
      * CSA: ``csa_compress`` / ``csa_decompress`` map kv_lora_rank <-> csa_dim.
        It is enabled only when ``csa_dim < kv_lora_rank``. When enabled, a
        ``csa_cache`` passed to ``forward`` replaces ``kv_cache`` for the base
        attention.
      * HCA: ``hca_compress`` / ``hca_decompress`` map kv_lora_rank <-> hca_dim.
        It is enabled only when hca_dim is narrower than the cache in use (the
        CSA width if CSA is enabled, otherwise kv_lora_rank). When enabled, an
        ``hca_cache`` passed to ``forward`` adds a strided attention whose
        output is summed with the base attention output.
    """

    def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                 qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                 v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
                 csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
                 tscale_type=TScaleType.T32):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.max_seq_len = max_seq_len
        self.csa_dim = csa_dim
        self.hca_dim = hca_dim

        self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
        self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)
        # forward views this weight as (n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank).
        # The first qk_nope_head_dim rows per head absorb queries; the last v_head_dim rows expand values.
        combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
        self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
        self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)

        # CSA: secondary compression (kv_lora_rank -> csa_dim), only if it actually shrinks.
        if csa_dim and csa_dim < kv_lora_rank:
            self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
            self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.csa_compress = None
            self.csa_decompress = None

        # HCA: heavy compression (kv_lora_rank -> hca_dim). It must be narrower than
        # the cache width actually in use. A csa_dim that disabled CSA (>= kv_lora_rank)
        # must not let hca_dim exceed the base latent width.
        cache_width = csa_dim if self.csa_compress is not None else kv_lora_rank
        if hca_dim and hca_dim < cache_width:
            self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
            self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.hca_compress = None
            self.hca_decompress = None

    def _compress(self, kv_cache, compress_proj):
        """Compress kv_cache from kv_lora_rank to a smaller width with ``compress_proj``."""
        return compress_proj(kv_cache)

    def _decompress(self, cache, decompress_proj):
        """Decompress ``cache`` back to kv_lora_rank with ``decompress_proj``."""
        return decompress_proj(cache)

    def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                        start_pos, seqlen, mask, key_stride=1):
        """Return scaled, masked pre-softmax scores of shape (bsz, seqlen, n_heads, n_keys).

        Shared by the base, CSA and HCA paths.

        Args:
            q_nope_absorbed: (bsz, seqlen, n_heads, c) queries already projected
                into the latent space.
            q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim) rotary queries.
            kv_flat: (n, c) latent keys, shared across the batch.
            pe_flat: (n, qk_rope_head_dim) rotary keys. Both key tensors are
                truncated to their common length ``n_keys``.
            start_pos: absolute position of the first query.
            seqlen: number of queries.
            mask: optional additive mask shaped (seqlen, n_keys), (n_keys,) or
                (bsz, seqlen, n_keys). A head axis is inserted before it is
                added. If None and seqlen > 1, a causal mask is applied instead.
            key_stride: absolute-position distance between consecutive keys.
                Key j is treated as position ``j * key_stride`` by the causal mask.
        """
        n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
        if n_keys == 0:
            return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)
        kv_flat = kv_flat[:n_keys]
        pe_flat = pe_flat[:n_keys]
        scores = (
            torch.einsum("bshc,tc->bsht", q_nope_absorbed, kv_flat)
            + torch.einsum("bshr,tr->bsht", q_pe, pe_flat)
        ) * self.softmax_scale
        if mask is not None:
            # unsqueeze(-2) inserts the head axis: (seqlen, n_keys) -> (seqlen, 1, n_keys).
            # That broadcasts against (bsz, seqlen, n_heads, n_keys) with seqlen
            # aligned to seqlen, not to the head axis.
            scores = scores + mask.unsqueeze(-2)
        elif seqlen > 1:
            q_pos = torch.arange(start_pos, start_pos + seqlen, device=scores.device)
            k_pos = torch.arange(n_keys, device=scores.device) * key_stride
            future = k_pos.unsqueeze(0) > q_pos.unsqueeze(1)  # (seqlen, n_keys)
            scores = scores.masked_fill(future[None, :, None, :], float('-inf'))
        return scores

    def _attend(self, q_nope_absorbed, q_pe, kv, pe, start_pos, seqlen, mask,
                key_stride=1):
        """Softmax-attend over one latent cache level; return (bsz, seqlen, n_heads, c)."""
        scores = self._compute_scores(q_nope_absorbed, q_pe, kv, pe,
                                      start_pos, seqlen, mask, key_stride=key_stride)
        # Softmax in float32 for stability, then match the cache dtype.
        # torch.einsum does not promote mixed dtypes, so this cast is required.
        probs = scores.softmax(dim=-1, dtype=torch.float32).type_as(kv)
        return torch.einsum("bsht,tc->bshc", probs, kv[:probs.shape[-1]])

    def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
                csa_cache=None, hca_cache=None, hca_pe_cache=None):
        """Run latent attention for ``x`` against the provided caches.

        Args:
            x: (bsz, seqlen, dim) input.
            kv_cache: (n, kv_lora_rank) latent cache, shared across the batch.
            pe_cache: (n, qk_rope_head_dim) rotary key cache. Queries are rotated
                here; keys are used as given.
            start_pos: absolute position of ``x[:, 0]``.
            freqs_cis: optional rotary table indexed by absolute position.
            mask: optional additive mask for the base attention (see
                ``_compute_scores``). None means causal when seqlen > 1.
            csa_cache: optional compressed cache. When CSA is enabled, it is
                decompressed and replaces ``kv_cache``.
            hca_cache: optional heavily compressed cache. When HCA is enabled,
                it is attended in addition to the base cache.
            hca_pe_cache: rotary keys for ``hca_cache``. Defaults to
                ``pe_cache[::MLA_HCA_STRIDE]``.

        Returns:
            Output of ``wo`` over the concatenated per-head values
            (presumably (bsz, seqlen, dim)).
        """
        bsz, seqlen, _ = x.size()
        q = self.wq(self.wq_norm(x)).view(bsz, seqlen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        if freqs_cis is not None:
            q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])

        # Materialize wkv_b once and absorb the key slice into the non-positional query.
        wkv_b = (self.wkv_b._get_T() * self.wkv_b._get_S()).view(
            self.n_heads, -1, self.kv_lora_rank)
        q_nope_absorbed = torch.einsum(
            "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])

        n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
        kv_flat = kv_cache[:n_cache]
        pe_flat = pe_cache[:n_cache]
        # The decompressed CSA cache replaces the base latent cache when available.
        if csa_cache is not None and self.csa_decompress is not None:
            n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
            kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
            pe_flat = pe_flat[:n_csa]

        attn_out = self._attend(q_nope_absorbed, q_pe, kv_flat, pe_flat,
                                start_pos, seqlen, mask)

        # HCA long-range attention (heavily compressed, strided), summed with the base output.
        if hca_cache is not None and self.hca_decompress is not None:
            hca_kv = self._decompress(hca_cache, self.hca_decompress)
            hca_pe = pe_cache[::MLA_HCA_STRIDE] if hca_pe_cache is None else hca_pe_cache
            n_hca = min(hca_kv.shape[0], hca_pe.shape[0])
            # HCA entry j is paired with the rotary key of position j * MLA_HCA_STRIDE,
            # so the causal mask compares that position, not the raw index j.
            # NOTE(review): assumes a caller-supplied hca_pe_cache uses the same stride.
            attn_out = attn_out + self._attend(
                q_nope_absorbed, q_pe, hca_kv[:n_hca], hca_pe[:n_hca],
                start_pos, seqlen, None, key_stride=MLA_HCA_STRIDE,
            )

        # Expand the attended latent through the value slice of wkv_b.
        attn_unproj = torch.einsum(
            "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])
        return self.wo(attn_unproj.flatten(2))
|