File size: 7,407 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style).

Ternary-weighted. KV cache stores compressed latent at multiple levels:
- Base: MLA latent (d=kv_lora_rank, typically 64/32)
- CSA: Secondary compression (d_csa, e.g. 16) — 4x compression on cache
- HCA: Heavily compressed (d_hca, e.g. 8) — 8x compression, wider stride

Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType

# Default MLA hyper-parameters, used as constructor defaults by
# MultiHeadLatentAttention and precompute_freqs_cis below.
MLA_N_HEADS = 32  # default number of attention heads
MLA_QK_NOPE_HEAD_DIM = 96  # per-head query/key width that does NOT receive RoPE
MLA_QK_ROPE_HEAD_DIM = 32  # per-head query/key width that receives RoPE
MLA_V_HEAD_DIM = 96  # per-head value width
MLA_ROPE_THETA = 10000.0  # default RoPE base for precompute_freqs_cis
MLA_SLIDE_DIM = 64  # default kv_lora_rank, i.e. the width of latent KV cache entries
# NOTE(review): MLA_FULL_DIM is not referenced in the code shown here; presumably the
# alternative latent width (the module docstring mentions "typically 64/32") — confirm.
MLA_FULL_DIM = 32


def apply_rotary_emb(x, freqs_cis):
    """Rotate adjacent channel pairs of ``x`` by the complex phases in ``freqs_cis``.

    The last dimension of ``x`` is reinterpreted as ``last_dim // 2`` complex
    numbers (real/imag interleaved). ``freqs_cis`` gains a leading axis and an
    axis after its first dimension, so a ``(seq, last_dim // 2)`` table broadcasts
    as ``(1, seq, 1, last_dim // 2)``. The math is done in float32 and the result
    is cast back to ``x.dtype``.
    """
    as_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    phases = freqs_cis[None, :, None]
    rotated = torch.view_as_complex(as_pairs) * phases
    interleaved = torch.view_as_real(rotated).flatten(-2)
    return interleaved.to(x.dtype)


def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
    """Build the RoPE phase table as unit-magnitude complex numbers.

    Returns a complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, k]``
    equals ``exp(i * p / theta ** (2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Broadcast multiply == outer product: angles[p, k] = p * inv_freq[k]
    angles = positions.reshape(-1, 1) * inv_freq
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)


class MultiHeadLatentAttention(nn.Module):
    """Multi-head latent attention (MLA) with optional CSA / HCA cache decompression.

    Keys and values are never materialised per head. The no-RoPE query part is
    projected ("absorbed") into the ``kv_lora_rank`` latent space through the key
    half of ``wkv_b`` and scored directly against the latent ``kv_cache``. The
    attention-weighted latent is then mapped back through the value half of
    ``wkv_b`` before the output projection ``wo``. The RoPE query part is scored
    against ``pe_cache``.

    Optional secondary caches:

    * CSA (``csa_cache``): entries of width ``csa_dim``. When supplied (and the
      CSA projections exist) they are decompressed to ``kv_lora_rank`` and
      replace the base ``kv_cache``.
    * HCA (``hca_cache``): entries of width ``hca_dim``. They are attended in a
      second, independently softmax-normalised pass whose output is added to
      the base attention output. HCA entry ``t`` is treated as sitting at
      absolute position ``t * MLA_HCA_STRIDE``, matching the default pairing
      with ``pe_cache[::MLA_HCA_STRIDE]``.
    """

    def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                 qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                 v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
                 csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
                 tscale_type=TScaleType.T32):
        """
        Args:
            dim: model hidden size (input width of ``wq``, output width of ``wo``).
            n_heads: number of attention heads.
            kv_lora_rank: width of latent KV cache entries.
            qk_nope_head_dim: per-head query/key width without RoPE.
            qk_rope_head_dim: per-head query/key width that receives RoPE.
            v_head_dim: per-head value width.
            max_seq_len: stored on the module; not read elsewhere in this class.
            csa_dim: CSA cache width. CSA projections are built only if it is
                truthy and smaller than ``kv_lora_rank``.
            hca_dim: HCA cache width. HCA projections are built only if it is
                truthy and smaller than ``csa_dim`` (or ``kv_lora_rank`` when
                ``csa_dim`` is falsy).
            tscale_type: scale format forwarded to every ternary layer.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        # Scale by the full (nope + rope) head width even though the nope part is
        # scored in latent space: absorption leaves the dot products unchanged.
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.max_seq_len = max_seq_len
        self.csa_dim = csa_dim
        self.hca_dim = hca_dim

        self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
        self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)

        # One projection holds both the per-head key (nope) and value halves.
        combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
        self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
        self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)

        # CSA: secondary compression (kv_lora_rank -> csa_dim)
        if csa_dim and csa_dim < kv_lora_rank:
            self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
            self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.csa_compress = None
            self.csa_decompress = None

        # HCA: heavily compressed (kv_lora_rank -> hca_dim); must be narrower than CSA.
        if hca_dim and hca_dim < (csa_dim or kv_lora_rank):
            self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
            self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
        else:
            self.hca_compress = None
            self.hca_decompress = None

    def _compress(self, kv_cache, compress_proj):
        """Compress kv_cache from kv_lora_rank to smaller dim."""
        return compress_proj(kv_cache)

    def _decompress(self, cache, decompress_proj):
        """Decompress cache back to kv_lora_rank."""
        return decompress_proj(cache)

    def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                        start_pos, seqlen, mask, key_stride=1):
        """Compute scaled, masked pre-softmax logits against one latent cache.

        Shared by the base/CSA pass and the HCA pass.

        Args:
            q_nope_absorbed: ``(bsz, seqlen, n_heads, c)`` queries already
                projected into the latent space of ``kv_flat``.
            q_pe: ``(bsz, seqlen, n_heads, r)`` rotary query part.
            kv_flat: ``(T_kv, c)`` latent keys, broadcast across the batch.
            pe_flat: ``(T_pe, r)`` rotary keys, broadcast across the batch.
            start_pos: absolute position of the first query token.
            seqlen: number of query tokens.
            mask: optional additive mask of shape ``(seqlen, n_keys)`` or
                ``(bsz, seqlen, n_keys)``. Key columns beyond ``n_keys`` are
                dropped. When ``None`` and ``seqlen > 1`` a causal mask is built.
            key_stride: absolute-position spacing between consecutive keys
                (``1`` for the base cache, ``MLA_HCA_STRIDE`` for HCA). Only
                the auto-built causal mask uses it.

        Returns:
            ``(bsz, seqlen, n_heads, n_keys)`` logits, where
            ``n_keys = min(T_kv, T_pe)``.
        """
        n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
        if n_keys == 0:
            return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)
        kv_flat = kv_flat[:n_keys]
        pe_flat = pe_flat[:n_keys]
        scores = (
            torch.einsum("bshc,btc->bsht",
                         q_nope_absorbed, kv_flat.unsqueeze(0))
            + torch.einsum("bshr,btr->bsht",
                           q_pe, pe_flat.unsqueeze(0))
        ) * self.softmax_scale

        if mask is not None:
            # scores are (b, s, h, t): insert the head axis so a (s, t) or
            # (b, s, t) mask lines up with queries/keys rather than with heads.
            scores = scores + mask[..., :n_keys].unsqueeze(-2)
        elif seqlen > 1:
            # Query i sits at absolute position start_pos + i and key t at
            # t * key_stride. Hide every key that lies in the query's future.
            device = q_pe.device
            q_pos = torch.arange(seqlen, device=device).unsqueeze(1) + start_pos
            k_pos = torch.arange(n_keys, device=device).unsqueeze(0) * key_stride
            causal = torch.zeros(seqlen, n_keys, device=device).masked_fill(
                k_pos > q_pos, float('-inf'))
            scores = scores + causal.unsqueeze(0).unsqueeze(2)
        return scores

    def _attend(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
                start_pos, seqlen, mask, key_stride=1):
        """Softmax-attend over one latent cache; returns ``(bsz, seqlen, n_heads, c)``."""
        scores = self._compute_scores(
            q_nope_absorbed, q_pe, kv_flat, pe_flat,
            start_pos, seqlen, mask, key_stride=key_stride,
        )
        n_keys = scores.shape[-1]
        # Softmax in float32 for stability, then cast back so both einsum
        # operands share the cache dtype (einsum does not mix e.g. fp32/bf16).
        probs = scores.softmax(dim=-1, dtype=torch.float32).type_as(kv_flat)
        return torch.einsum(
            "bsht,btc->bshc", probs, kv_flat[:n_keys].unsqueeze(0))

    def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
                csa_cache=None, hca_cache=None, hca_pe_cache=None):
        """Attend from ``x`` over the latent caches and project back to ``dim``.

        Args:
            x: ``(bsz, seqlen, dim)`` input hidden states.
            kv_cache: ``(T, kv_lora_rank)`` latent cache, broadcast across the batch.
            pe_cache: ``(T, qk_rope_head_dim)`` rotary key cache.
            start_pos: absolute position of ``x[:, 0]``. Used to slice
                ``freqs_cis`` and to place the auto-built causal mask.
            freqs_cis: optional rotary phase table (e.g. from
                ``precompute_freqs_cis``) applied to the query rope part.
            mask: optional additive mask for the base/CSA pass (see
                ``_compute_scores``). It is not applied to the HCA pass.
            csa_cache: optional CSA cache. It replaces ``kv_cache`` when the
                CSA projections exist.
            hca_cache: optional strided HCA cache, attended in a separate pass
                when the HCA projections exist.
            hca_pe_cache: rotary keys for ``hca_cache``. Defaults to
                ``pe_cache[::MLA_HCA_STRIDE]``.

        Returns:
            ``(bsz, seqlen, dim)`` output of ``wo``.
        """
        bsz, seqlen, _ = x.size()

        q = self.wq(self.wq_norm(x))
        q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if freqs_cis is not None:
            q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])

        # Dense (n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank) view of wkv_b.
        # NOTE(review): assumes _get_T() * _get_S() yields the dequantised weight.
        wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S()
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)

        # Absorb the key half of wkv_b into the query: (b, s, h, kv_lora_rank).
        q_nope_absorbed = torch.einsum(
            "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])

        n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
        kv_flat = kv_cache[:n_cache]
        pe_flat = pe_cache[:n_cache]

        # Decompress CSA cache if provided (replaces base kv_cache)
        if csa_cache is not None and self.csa_decompress is not None:
            n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
            kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
            pe_flat = pe_flat[:n_csa]

        # Base attention (exact, or CSA-decompressed if applicable)
        attn_out = self._attend(
            q_nope_absorbed, q_pe, kv_flat, pe_flat, start_pos, seqlen, mask)

        # HCA long-range attention: strided keys, so the causal mask must
        # compare absolute positions (key t lives at t * MLA_HCA_STRIDE).
        if hca_cache is not None and self.hca_decompress is not None:
            hca_kv = self._decompress(hca_cache, self.hca_decompress)
            hca_pe = pe_cache[::MLA_HCA_STRIDE] if hca_pe_cache is None else hca_pe_cache
            attn_out = attn_out + self._attend(
                q_nope_absorbed, q_pe, hca_kv, hca_pe, start_pos, seqlen,
                mask=None, key_stride=MLA_HCA_STRIDE,
            )

        # Map the attended latent back to per-head values via the value half of wkv_b.
        attn_unproj = torch.einsum(
            "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])

        return self.wo(attn_unproj.flatten(2))