"""Context Attention Scheduler — sliding window + full context orchestration. Schedules 4 sliding window (d=64, CSA-compressed to d=16) and 4 full context (d=32, HCA-compressed to d=8) MLA attention passes. Combines both via gating. Pipeline: GNN output → ContextAttentionScheduler → MoE input """ import torch import torch.nn as nn from ..config import HIDDEN_DIM, MLA_HCA_STRIDE from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType from .mla import (MultiHeadLatentAttention, precompute_freqs_cis, MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM, MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM, MLA_V_HEAD_DIM, MLA_ROPE_THETA, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE) SLIDING_WINDOW_SIZE = 32768 KV_LEDGER_SIZE = 262144 class ContextAttentionScheduler(nn.Module): def __init__(self, dim=HIDDEN_DIM): super().__init__() self.dim = dim # Slide layers with CSA compression (d=64 → d=16) — half of total layers n_layers_per_pass = max(1, MLA_N_LAYERS // 2) self.slide_layers = nn.ModuleList([ MultiHeadLatentAttention( dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM, qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM, v_head_dim=MLA_V_HEAD_DIM, csa_dim=MLA_CSA_DIM, hca_dim=None, ) for _ in range(n_layers_per_pass) ]) # CSA: embed motif IDs → kv_lora_rank, then compress → csa_dim self.slide_embed = TernaryScaleTensor(1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32) self.slide_compress = TernaryScaleTensor(MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32) # Full context layers with HCA compression (d=32 → d=8) — half of total layers self.full_layers = nn.ModuleList([ MultiHeadLatentAttention( dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_FULL_DIM, qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM, v_head_dim=MLA_V_HEAD_DIM, csa_dim=None, hca_dim=MLA_HCA_DIM, ) for _ in range(n_layers_per_pass) ]) # HCA: embed motif IDs → kv_lora_rank, then compress → hca_dim self.full_embed = TernaryScaleTensor(1, MLA_FULL_DIM, tscale_type=TScaleType.T32) self.full_compress = TernaryScaleTensor(MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32) self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32) self._freqs_cis = None self._max_freq_len = 0 def _ensure_freqs(self, seq_len, device): needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE) if self._freqs_cis is None or needed > self._max_freq_len: self._max_freq_len = needed self._freqs_cis = precompute_freqs_cis( MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA ).to(device) return self._freqs_cis def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None): bsz, seqlen, _ = x.shape device = x.device freqs_cis = self._ensure_freqs(seqlen, device) full_ledger = full_ledger or kv_ledger window_size = min(SLIDING_WINDOW_SIZE, kv_ledger.size) if kv_ledger.size > 0 else 0 out_slide = x if window_size > 0: start = max(0, kv_ledger.size - SLIDING_WINDOW_SIZE) end = kv_ledger.size slide_ids = kv_ledger.get_range(start, end).float().unsqueeze(-1) # Embed to kv_lora_rank, then CSA compress to csa_dim slide_latent = self.slide_embed(slide_ids) csa_cache = self.slide_compress(slide_latent) pe_cache = torch.zeros(csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device) for layer in self.slide_layers: out_slide = layer(out_slide, slide_latent, pe_cache, start_pos=0, freqs_cis=freqs_cis, mask=None, csa_cache=csa_cache) out_full = x if full_ledger.size > 0: full = full_ledger.get_sparse(stride=MLA_HCA_STRIDE) full_ids = full.float().unsqueeze(-1) full_latent = self.full_embed(full_ids) hca_cache = self.full_compress(full_latent) pe_cache = torch.zeros(hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device) for layer in self.full_layers: out_full = layer(out_full, full_latent, pe_cache, start_pos=0, freqs_cis=freqs_cis, mask=None, hca_cache=hca_cache, hca_pe_cache=pe_cache) gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True))) out = gate * out_slide + (1 - gate) * out_full return out