File size: 4,958 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Context Attention Scheduler β€” sliding window + full context orchestration.

Schedules 4 sliding window (d=64, CSA-compressed to d=16) and 4 full context
(d=32, HCA-compressed to d=8) MLA attention passes. Combines both via gating.

Pipeline: GNN output β†’ ContextAttentionScheduler β†’ MoE input
"""
import torch
import torch.nn as nn
from ..config import HIDDEN_DIM, MLA_HCA_STRIDE
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
from .mla import (MultiHeadLatentAttention, precompute_freqs_cis,
                  MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM,
                  MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM,
                  MLA_V_HEAD_DIM, MLA_ROPE_THETA,
                  MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE)

# Upper bound on how many trailing ledger entries the sliding-window pass
# reads in ContextAttentionScheduler.forward (2**15 = 32768).
SLIDING_WINDOW_SIZE = 2 ** 15
# Used as a lower bound on the length of the precomputed rotary frequency
# table in ContextAttentionScheduler._ensure_freqs (2**18 = 262144).
# NOTE(review): the name suggests this is also the ledger capacity -- confirm
# against the ledger implementation.
KV_LEDGER_SIZE = 2 ** 18


class ContextAttentionScheduler(nn.Module):
    """Blend a sliding-window attention pass with a full-context attention pass.

    Two stacks of ``MultiHeadLatentAttention`` layers are run over the same
    input: the "slide" stack is fed the most recent ledger entries, the "full"
    stack a strided sample of the ledger. The two outputs are mixed with a
    sigmoid gate computed from the mean of the input over its second axis.

    Args:
        dim: Feature width passed to every attention layer and used to size
            the gate projection. Defaults to ``HIDDEN_DIM``.
    """

    def __init__(self, dim=HIDDEN_DIM):
        super().__init__()
        self.dim = dim

        # Each pass gets half of the configured MLA layers (never fewer than 1).
        n_layers_per_pass = max(1, MLA_N_LAYERS // 2)

        # Slide layers with CSA compression (MLA_SLIDE_DIM → MLA_CSA_DIM).
        self.slide_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=MLA_CSA_DIM, hca_dim=None,
            ) for _ in range(n_layers_per_pass)
        ])
        # CSA: embed motif IDs → kv_lora_rank, then compress → csa_dim
        self.slide_embed = TernaryScaleTensor(1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32)
        self.slide_compress = TernaryScaleTensor(MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32)

        # Full context layers with HCA compression (MLA_FULL_DIM → MLA_HCA_DIM).
        self.full_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_FULL_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=None, hca_dim=MLA_HCA_DIM,
            ) for _ in range(n_layers_per_pass)
        ])
        # HCA: embed motif IDs → kv_lora_rank, then compress → hca_dim
        self.full_embed = TernaryScaleTensor(1, MLA_FULL_DIM, tscale_type=TScaleType.T32)
        self.full_compress = TernaryScaleTensor(MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32)

        # Projects the pooled input to a single mixing logit.
        self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32)

        # Lazily built rotary frequency table and the length it was built for.
        # Kept as a plain attribute (not a registered buffer), so Module.to()
        # does not move it; _ensure_freqs handles device placement itself.
        self._freqs_cis = None
        self._max_freq_len = 0

    def _ensure_freqs(self, seq_len, device):
        """Return the cached rotary frequency table on ``device``.

        The table is built on first use and rebuilt whenever a longer one is
        required. If the cached table lives on a different device than the
        one requested, it is moved there.

        Args:
            seq_len: Sequence length of the current input.
            device: Device the returned table must live on.

        Returns:
            The result of ``precompute_freqs_cis`` for at least
            ``max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)`` positions.
        """
        needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)
        if self._freqs_cis is None or needed > self._max_freq_len:
            self._freqs_cis = precompute_freqs_cis(
                MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA
            ).to(device)
            # Record the length only after the table was built successfully.
            self._max_freq_len = needed
        else:
            # Tensor.to returns the same tensor when it is already on the
            # requested device, so this is free in the common case.
            self._freqs_cis = self._freqs_cis.to(device)
        return self._freqs_cis

    def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None):
        """Run both attention passes over ``x`` and gate their outputs.

        Args:
            x: Input tensor; must be 3-D (unpacked as batch, sequence, features).
            kv_ledger: Ledger object exposing ``size`` and
                ``get_range(start, end)``; feeds the sliding-window pass.
            full_ledger: Ledger object exposing ``size`` and
                ``get_sparse(stride=...)``; feeds the full-context pass.
                When ``None``, ``kv_ledger`` is used for both passes.
            kq_cache: Unused; accepted for interface compatibility.

        Returns:
            ``gate * out_slide + (1 - gate) * out_full``. A pass whose ledger
            is empty contributes ``x`` unchanged.
        """
        _, seqlen, _ = x.shape
        device = x.device
        freqs_cis = self._ensure_freqs(seqlen, device)

        # Test for "not supplied" explicitly: relying on truthiness would
        # discard an explicitly passed ledger that happens to evaluate false.
        if full_ledger is None:
            full_ledger = kv_ledger

        out_slide = x
        ledger_len = kv_ledger.size
        if ledger_len > 0:
            # Only the trailing SLIDING_WINDOW_SIZE entries are read.
            start = max(0, ledger_len - SLIDING_WINDOW_SIZE)
            slide_ids = kv_ledger.get_range(start, ledger_len).float().unsqueeze(-1)
            # Embed to kv_lora_rank, then CSA compress to csa_dim
            slide_latent = self.slide_embed(slide_ids)
            csa_cache = self.slide_compress(slide_latent)
            slide_pe = torch.zeros(csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)

            for layer in self.slide_layers:
                out_slide = layer(out_slide, slide_latent, slide_pe,
                                  start_pos=0, freqs_cis=freqs_cis, mask=None,
                                  csa_cache=csa_cache)

        out_full = x
        if full_ledger.size > 0:
            full_ids = full_ledger.get_sparse(stride=MLA_HCA_STRIDE).float().unsqueeze(-1)
            # Embed to kv_lora_rank, then HCA compress to hca_dim
            full_latent = self.full_embed(full_ids)
            hca_cache = self.full_compress(full_latent)
            full_pe = torch.zeros(hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)

            for layer in self.full_layers:
                out_full = layer(out_full, full_latent, full_pe,
                                 start_pos=0, freqs_cis=freqs_cis, mask=None,
                                 hca_cache=hca_cache, hca_pe_cache=full_pe)

        # One gate value per batch element, from the input pooled over axis 1.
        gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True)))
        return gate * out_slide + (1 - gate) * out_full