"""Context Attention Scheduler — sliding-window + full-context orchestration.

Schedules 4 sliding-window (d=64, CSA-compressed to d=16) and 4 full-context
(d=32, HCA-compressed to d=8) MLA attention passes, and combines the two
outputs via a sigmoid gate.

Pipeline: GNN output → ContextAttentionScheduler → MoE input
"""
| import torch |
| import torch.nn as nn |
| from ..config import HIDDEN_DIM, MLA_HCA_STRIDE |
| from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType |
| from .mla import (MultiHeadLatentAttention, precompute_freqs_cis, |
| MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM, |
| MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM, |
| MLA_V_HEAD_DIM, MLA_ROPE_THETA, |
| MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE) |
|
|
# Number of most-recent kv-ledger entries the sliding-window pass attends to
# (2**15 = 32768).
SLIDING_WINDOW_SIZE = 1 << 15
# Presumably the ledger capacity — in this module it is only used as a lower
# bound on the length of the precomputed RoPE table (2**18 = 262144).
KV_LEDGER_SIZE = 1 << 18
|
|
|
|
class ContextAttentionScheduler(nn.Module):
    """Blend a sliding-window MLA pass and a full-context MLA pass with a gate.

    * Sliding-window pass: the last ``SLIDING_WINDOW_SIZE`` entries of
      ``kv_ledger`` are embedded into an ``MLA_SLIDE_DIM`` latent, compressed
      to a CSA cache of width ``MLA_CSA_DIM``, and fed through
      ``self.slide_layers``.
    * Full-context pass: the strided view returned by
      ``full_ledger.get_sparse(stride=MLA_HCA_STRIDE)`` is embedded into an
      ``MLA_FULL_DIM`` latent, compressed to an HCA cache of width
      ``MLA_HCA_DIM``, and fed through ``self.full_layers``.
    * Outputs are mixed as ``g * slide + (1 - g) * full`` with
      ``g = sigmoid(gate(mean of x over dim 1))``.

    When a ledger is empty, the corresponding pass is skipped and contributes
    the unmodified input ``x``.
    """

    def __init__(self, dim=HIDDEN_DIM):
        super().__init__()
        self.dim = dim

        # Split the MLA layer budget between the two passes (at least one each).
        n_layers_per_pass = max(1, MLA_N_LAYERS // 2)

        # Sliding-window pass: latent rank MLA_SLIDE_DIM with a CSA cache only.
        self.slide_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=MLA_CSA_DIM, hca_dim=None,
            ) for _ in range(n_layers_per_pass)
        ])
        # Ledger ids arrive as a single feature (see unsqueeze(-1) in
        # _slide_pass); lift them to the latent width, then compress.
        self.slide_embed = TernaryScaleTensor(1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32)
        self.slide_compress = TernaryScaleTensor(MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32)

        # Full-context pass: latent rank MLA_FULL_DIM with an HCA cache only.
        self.full_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_FULL_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=None, hca_dim=MLA_HCA_DIM,
            ) for _ in range(n_layers_per_pass)
        ])
        self.full_embed = TernaryScaleTensor(1, MLA_FULL_DIM, tscale_type=TScaleType.T32)
        self.full_compress = TernaryScaleTensor(MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32)

        # Projects a dim-wide summary of x to one logit for the blend gate.
        self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32)

        # Lazily built RoPE table. It is a plain attribute, not a registered
        # buffer, so nn.Module.to()/cuda() does NOT move it; _ensure_freqs
        # takes care of device placement instead.
        self._freqs_cis = None
        self._max_freq_len = 0

    def _ensure_freqs(self, seq_len, device):
        """Return the cached RoPE table, (re)built and placed on ``device``.

        The table covers at least ``max(seq_len, SLIDING_WINDOW_SIZE,
        KV_LEDGER_SIZE)`` positions and is only rebuilt when a longer one is
        needed.
        """
        needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)
        if self._freqs_cis is None or needed > self._max_freq_len:
            self._max_freq_len = needed
            self._freqs_cis = precompute_freqs_cis(
                MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA
            )
        # Always re-place the cache: the module may have been moved to another
        # device since the table was built. Tensor.to returns self unchanged
        # when the device already matches, so this is free in the common case.
        self._freqs_cis = self._freqs_cis.to(device)
        return self._freqs_cis

    def _slide_pass(self, x, kv_ledger, freqs_cis, device):
        # Attend over the most recent SLIDING_WINDOW_SIZE ledger entries;
        # returns x unchanged when the ledger is empty.
        size = kv_ledger.size
        if size <= 0:
            return x

        start = max(0, size - SLIDING_WINDOW_SIZE)
        slide_ids = kv_ledger.get_range(start, size).float().unsqueeze(-1)

        slide_latent = self.slide_embed(slide_ids)
        csa_cache = self.slide_compress(slide_latent)
        # Rope-dim part of the cache is all zeros (no positional signal on keys).
        pe_cache = torch.zeros(csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)

        out = x
        for layer in self.slide_layers:
            out = layer(out, slide_latent, pe_cache,
                        start_pos=0, freqs_cis=freqs_cis, mask=None,
                        csa_cache=csa_cache)
        return out

    def _full_pass(self, x, full_ledger, freqs_cis, device):
        # Attend over a strided sample of the whole ledger; returns x
        # unchanged when the ledger is empty.
        if full_ledger.size <= 0:
            return x

        full_ids = full_ledger.get_sparse(stride=MLA_HCA_STRIDE).float().unsqueeze(-1)
        full_latent = self.full_embed(full_ids)
        hca_cache = self.full_compress(full_latent)
        # Same zero tensor serves as both the rope cache and the HCA rope cache.
        pe_cache = torch.zeros(hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)

        out = x
        for layer in self.full_layers:
            out = layer(out, full_latent, pe_cache,
                        start_pos=0, freqs_cis=freqs_cis, mask=None,
                        hca_cache=hca_cache, hca_pe_cache=pe_cache)
        return out

    def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None):
        """Run both attention passes over ``x`` and return their gated blend.

        Args:
            x: Input activations. The shape unpack requires a 3-D tensor;
                presumably (batch, seq, dim) — TODO confirm with caller.
            kv_ledger: Ledger exposing ``size`` and ``get_range(start, end)``;
                feeds the sliding-window pass.
            full_ledger: Ledger exposing ``size`` and ``get_sparse(stride=...)``;
                feeds the full-context pass. Defaults to ``kv_ledger`` when
                ``None``.
            kq_cache: Accepted for interface compatibility; currently unused.

        Returns:
            ``gate * out_slide + (1 - gate) * out_full``. Assumes the MLA
            layers preserve the shape of ``x`` — TODO confirm.
        """
        _, seqlen, _ = x.shape  # also enforces a 3-D input
        device = x.device
        freqs_cis = self._ensure_freqs(seqlen, device)

        # Explicit None check: `full_ledger or kv_ledger` would test the
        # ledger's truthiness, silently swapping in kv_ledger for an empty
        # ledger that defines __len__/__bool__ (or raising for tensor-likes).
        if full_ledger is None:
            full_ledger = kv_ledger

        out_slide = self._slide_pass(x, kv_ledger, freqs_cis, device)
        out_full = self._full_pass(x, full_ledger, freqs_cis, device)

        # Mean over dim 1 with keepdim, so the gate broadcasts across that axis;
        # assumes self.gate maps dim -> 1 like a linear layer — TODO confirm.
        gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True)))
        return gate * out_slide + (1 - gate) * out_full
|
|