ARBS / arbitor /attention /context_attention.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Context Attention Scheduler — sliding window + full context orchestration.
Schedules 4 sliding window (d=64, CSA-compressed to d=16) and 4 full context
(d=32, HCA-compressed to d=8) MLA attention passes. Combines both via gating.
Pipeline: GNN output → ContextAttentionScheduler → MoE input
"""
import torch
import torch.nn as nn
from ..config import HIDDEN_DIM, MLA_HCA_STRIDE
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
from .mla import (MultiHeadLatentAttention, precompute_freqs_cis,
MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM,
MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM,
MLA_V_HEAD_DIM, MLA_ROPE_THETA,
MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE)
# Upper bound on how many of the most recent ledger entries the sliding-window
# pass reads; also a lower bound on the cached RoPE table length.
SLIDING_WINDOW_SIZE = 32 * 1024
# Lower bound on the cached RoPE table length (the table is sized to the
# largest of the sequence length, SLIDING_WINDOW_SIZE and this value).
KV_LEDGER_SIZE = 256 * 1024
class ContextAttentionScheduler(nn.Module):
    """Run a sliding-window MLA pass and a full-context MLA pass, then gate them.

    Two independent stacks of ``MultiHeadLatentAttention`` layers are applied
    to the same input ``x``:

    * the *slide* stack attends over the most recent ``SLIDING_WINDOW_SIZE``
      entries of ``kv_ledger`` and is given a CSA-compressed cache;
    * the *full* stack attends over a strided sample of ``full_ledger`` and is
      given an HCA-compressed cache.

    The two outputs are mixed with a per-sample sigmoid gate computed from the
    mean of ``x`` over dimension 1.
    """

    def __init__(self, dim=HIDDEN_DIM):
        """Build both attention stacks, their ledger projections and the gate.

        Args:
            dim: Model width handed to every attention layer and to the gate
                projection.
        """
        super().__init__()
        self.dim = dim

        # Each pass gets half of the configured MLA layers, but never fewer
        # than one.
        n_layers_per_pass = max(1, MLA_N_LAYERS // 2)

        # Slide layers with CSA compression (MLA_SLIDE_DIM → MLA_CSA_DIM).
        self.slide_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim,
                n_heads=MLA_N_HEADS,
                kv_lora_rank=MLA_SLIDE_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=MLA_CSA_DIM,
                hca_dim=None,
            )
            for _ in range(n_layers_per_pass)
        ])
        # CSA: project each scalar ledger ID → kv_lora_rank, then compress
        # → csa_dim.
        self.slide_embed = TernaryScaleTensor(
            1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32)
        self.slide_compress = TernaryScaleTensor(
            MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32)

        # Full-context layers with HCA compression (MLA_FULL_DIM → MLA_HCA_DIM).
        self.full_layers = nn.ModuleList([
            MultiHeadLatentAttention(
                dim=dim,
                n_heads=MLA_N_HEADS,
                kv_lora_rank=MLA_FULL_DIM,
                qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
                qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
                v_head_dim=MLA_V_HEAD_DIM,
                csa_dim=None,
                hca_dim=MLA_HCA_DIM,
            )
            for _ in range(n_layers_per_pass)
        ])
        # HCA: project each scalar ledger ID → kv_lora_rank, then compress
        # → hca_dim.
        self.full_embed = TernaryScaleTensor(
            1, MLA_FULL_DIM, tscale_type=TScaleType.T32)
        self.full_compress = TernaryScaleTensor(
            MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32)

        # Scalar gate projection: dim → 1.
        self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32)

        # Lazily built RoPE table. It is a plain attribute, not a registered
        # buffer, so Module.to() does not move it; _ensure_freqs handles the
        # device explicitly.
        self._freqs_cis = None
        self._max_freq_len = 0

    def _ensure_freqs(self, seq_len, device):
        """Return the cached RoPE table, (re)building or moving it as needed.

        The table is rebuilt when none exists yet or when the required length
        exceeds the cached one. The required length is the largest of
        ``seq_len``, ``SLIDING_WINDOW_SIZE`` and ``KV_LEDGER_SIZE``.

        Args:
            seq_len: Current query sequence length.
            device: Device the returned table must live on.

        Returns:
            The result of ``precompute_freqs_cis``, placed on ``device``.
        """
        needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)
        if self._freqs_cis is None or needed > self._max_freq_len:
            self._freqs_cis = precompute_freqs_cis(
                MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA
            )
            # Record the length only after the table was built successfully.
            self._max_freq_len = needed
        # Always honour the requested device: the cache may have been built
        # while the model lived elsewhere. `.to` returns the same tensor when
        # it is already on `device`, so this is free in the common case.
        self._freqs_cis = self._freqs_cis.to(device)
        return self._freqs_cis

    def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None):
        """Apply both attention passes to ``x`` and blend them with the gate.

        Args:
            x: 3-D input tensor (unpacked as three dimensions; presumably
                ``(batch, seq, dim)`` — TODO confirm against callers).
            kv_ledger: Ledger object exposing ``size`` and
                ``get_range(start, end)``; feeds the sliding-window pass.
            full_ledger: Ledger object exposing ``size`` and
                ``get_sparse(stride=...)``; feeds the full-context pass.
                Falls back to ``kv_ledger`` only when it is ``None``.
            kq_cache: Accepted for interface compatibility; not used here.

        Returns:
            ``gate * out_slide + (1 - gate) * out_full``. A pass whose ledger
            is empty contributes ``x`` unchanged.
        """
        # The 3-way unpack doubles as a rank check on x.
        _, seqlen, _ = x.shape
        device = x.device
        freqs_cis = self._ensure_freqs(seqlen, device)

        # Compare against None explicitly: relying on truthiness would swap
        # out a caller-supplied ledger that merely evaluates as falsy.
        if full_ledger is None:
            full_ledger = kv_ledger

        # Sliding-window pass over the newest ledger entries.
        out_slide = x
        ledger_len = kv_ledger.size
        if ledger_len > 0:
            start = max(0, ledger_len - SLIDING_WINDOW_SIZE)
            slide_ids = kv_ledger.get_range(start, ledger_len).float().unsqueeze(-1)
            # Embed to kv_lora_rank, then CSA-compress to csa_dim.
            slide_latent = self.slide_embed(slide_ids)
            csa_cache = self.slide_compress(slide_latent)
            # Zero rotary cache, one row per compressed entry.
            slide_pe = torch.zeros(
                csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
            for layer in self.slide_layers:
                out_slide = layer(
                    out_slide, slide_latent, slide_pe,
                    start_pos=0, freqs_cis=freqs_cis, mask=None,
                    csa_cache=csa_cache,
                )

        # Full-context pass over a strided sample of the ledger.
        out_full = x
        if full_ledger.size > 0:
            full_ids = full_ledger.get_sparse(stride=MLA_HCA_STRIDE)
            full_ids = full_ids.float().unsqueeze(-1)
            # Embed to kv_lora_rank, then HCA-compress to hca_dim.
            full_latent = self.full_embed(full_ids)
            hca_cache = self.full_compress(full_latent)
            full_pe = torch.zeros(
                hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
            for layer in self.full_layers:
                out_full = layer(
                    out_full, full_latent, full_pe,
                    start_pos=0, freqs_cis=freqs_cis, mask=None,
                    hca_cache=hca_cache, hca_pe_cache=full_pe,
                )

        # One gate value per sample, from the mean of x over dimension 1.
        gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True)))
        return gate * out_slide + (1 - gate) * out_full