""" Chimera 5.2 — inference-time helpers (CPU-first). This module collects all the lightweight components that run *after* the trunk produces hidden states: * :class:`SpanBank` — vectorised semantic memory. * :class:`STreeVerifier` — tiny scoring head. * :class:`CertificateVerifier`— per-token risk projection. * :class:`SpanInferenceEngine`— glue + risk gating. * :class:`GrammarFST` — additive constraint penalty. * :class:`EntropyValve` — adaptive loop-count router. * :class:`DebtLedger` — bias logits to honour outstanding obligations. * :class:`BraidState` — runtime scratch state. Optimisations vs the previous draft: * Grammar / Debt are *true* identity ops when their constraints are empty (no tensors allocated, no projections run) — this matters because they sit on the per-token logits path. * Entropy is computed on the slice the model actually scores (not the full 200K-vocab logits): the model passes us the last-token logits. * Everything that does not depend on the input shape is allocated once. """ from __future__ import annotations import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # SpanBank # --------------------------------------------------------------------------- class SpanBank(nn.Module): """Cosine-similarity span memory used for retrieval-augmented inference.""" def __init__(self, max_entries: int = 524288, max_tokens: int = 64, hidden_size: int = 2560, memory_mb: int = 384): super().__init__() self.max_entries = int(max_entries) self.max_tokens = int(max_tokens) self.hidden_size = int(hidden_size) proj_dim = max(8, hidden_size // 4) # Estimate entries the user can actually afford in RAM. budget = int(memory_mb) * 1024 * 1024 per_entry = (proj_dim + hidden_size) * 4 + 8 actual = max(1, min(self.max_entries, budget // per_entry)) self.proj_dim = proj_dim self.register_buffer("bank_keys", torch.zeros(actual, proj_dim)) self.register_buffer("bank_values", torch.zeros(actual, hidden_size)) self.register_buffer("bank_lengths", torch.zeros(actual, dtype=torch.long)) self.register_buffer("bank_count", torch.zeros((), dtype=torch.long)) self.semantic_proj = nn.Linear(hidden_size, proj_dim, bias=False) @property def capacity(self) -> int: return int(self.bank_keys.size(0)) def query_scores(self, hidden_state: torch.Tensor, top_k: int = 64 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: c = int(self.bank_count.item()) if c == 0: return None, None q = F.normalize(self.semantic_proj(hidden_state), dim=-1) keys = F.normalize(self.bank_keys[:c], dim=-1) sims = torch.matmul(q, keys.t()) k = min(top_k, c) return torch.topk(sims, k, dim=-1) def query(self, hidden_state: torch.Tensor, top_k: int = 64) -> torch.Tensor: scores, indices = self.query_scores(hidden_state, top_k=top_k) if scores is None: return torch.zeros_like(hidden_state) c = int(self.bank_count.item()) values = self.bank_values[:c][indices] weights = torch.softmax(scores, dim=-1).unsqueeze(-1) return (values * weights).sum(dim=-2) @torch.no_grad() def add(self, keys: torch.Tensor, values: torch.Tensor) -> None: """Bulk insert; vectorised, falls back to overwriting once full.""" keys = keys.detach().reshape(-1, self.hidden_size) values = values.detach().reshape(-1, self.hidden_size) n = keys.size(0) if n == 0: return cap = self.capacity start = int(self.bank_count.item()) end = min(start + n, cap) write = end - start if write > 0: self.bank_keys[start:end] = self.semantic_proj(keys[:write]) self.bank_values[start:end] = values[:write] self.bank_lengths[start:end] = 1 self.bank_count.add_(write) @torch.no_grad() def add_span(self, hidden_state: torch.Tensor, length: int, value: Optional[torch.Tensor] = None) -> None: h = hidden_state.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True) v = (value.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True) if value is not None else h) self.add(h, v) # --------------------------------------------------------------------------- # Verifiers # --------------------------------------------------------------------------- class STreeVerifier(nn.Module): """Tiny scoring head used by speculative-tree decoding.""" def __init__(self, tree_width: int = 4, tree_depth: int = 5, hidden_size: int = 256): super().__init__() self.tree_width = int(tree_width) self.tree_depth = int(tree_depth) h_mid = max(8, hidden_size // 4) self.score_net = nn.Sequential( nn.Linear(hidden_size, h_mid), nn.ReLU(inplace=True), nn.Linear(h_mid, 1), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.score_net(hidden_states)).squeeze(-1) class CertificateVerifier(nn.Module): """Per-token certificate fields (semantic / grammar / entity / risk).""" def __init__(self, hidden_size: int): super().__init__() self.semantic_proj = nn.Linear(hidden_size, 64, bias=False) self.grammar_proj = nn.Linear(hidden_size, 16, bias=False) self.entity_proj = nn.Linear(hidden_size, 32, bias=False) self.boundary_proj = nn.Linear(hidden_size, 1, bias=False) self.risk_proj = nn.Linear(hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> dict: return { "semantic": self.semantic_proj(hidden_states), "grammar": self.grammar_proj(hidden_states), "entity": self.entity_proj(hidden_states), "boundary": self.boundary_proj(hidden_states), "risk": torch.sigmoid(self.risk_proj(hidden_states)), } class SpanInferenceEngine(nn.Module): """Risk-gated post-trunk hidden-state modulation.""" def __init__(self, hidden_size: int, config: dict): super().__init__() self.enabled = bool(config.get("enabled", True)) self.hidden_size = int(hidden_size) self.span_bank = SpanBank( max_entries=config.get("bank_entries", 524288), max_tokens=config.get("bank_max_tokens", 64), hidden_size=self.hidden_size, memory_mb=config.get("bank_memory_mb", 384), ) self.tree_verifier = STreeVerifier( tree_width=config.get("tree_verify", {}).get("tree_width", 4), tree_depth=config.get("tree_verify", {}).get("tree_depth", 5), hidden_size=self.hidden_size, ) self.certificate = CertificateVerifier(self.hidden_size) self.scoring_weights = nn.Parameter( torch.tensor(config.get("scoring_weights_fast", [1.0, 0.8, 0.5, 0.7, 0.35]))) self.fallback_threshold = float(config.get("fallback_below_acceptance", 0.5)) # Single fused gate from concatenated hidden + risk. self.risk_gate = nn.Linear(self.hidden_size + 1, self.hidden_size, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if not self.enabled: return hidden_states risk = torch.sigmoid(self.certificate.risk_proj(hidden_states)) gate_input = torch.cat([hidden_states, risk], dim=-1) modulation = torch.sigmoid(self.risk_gate(gate_input)) return hidden_states * modulation # --------------------------------------------------------------------------- # Grammar FST — additive penalty (no-op when no constraints) # --------------------------------------------------------------------------- class GrammarFST(nn.Module): """Soft-constraint penalty on next-token logits. *Identity* when ``enabled`` is false **or** there are no constraints – no entropy computation, no projection allocations. """ def __init__(self, config: dict): super().__init__() self.enabled = bool(config.get("enabled", True)) self.hard_constraints = list(config.get("hard_constraints", [])) self.soft_constraints = list(config.get("soft_constraints", [])) n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1 self._n_hard = len(self.hard_constraints) self._n_soft = len(self.soft_constraints) self._n_features = n_features self._is_noop = (not self.enabled) or n_features <= 1 self.constraint_proj = nn.Linear(n_features, 1, bias=True) nn.init.normal_(self.constraint_proj.weight, std=0.01) nn.init.zeros_(self.constraint_proj.bias) def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor: if self._is_noop: return logits B, T, V = logits.shape # Single log_softmax pass for entropy. log_probs = F.log_softmax(logits, dim=-1) entropy = -(log_probs.exp() * log_probs).sum(-1) # [B, T] features = logits.new_zeros(B, T, self._n_features) features[..., 0] = entropy if self._n_soft > 0 and T > 1: cos = F.cosine_similarity(logits[:, 1:], logits[:, :-1], dim=-1) features[:, 1:, self._n_hard] = cos.clamp_min(0.0) penalty = self.constraint_proj(features) # [B, T, 1] return logits + penalty # --------------------------------------------------------------------------- # Entropy valve # --------------------------------------------------------------------------- class EntropyValve(nn.Module): """Maps logits entropy → adaptive loop count for the looped trunk.""" def __init__(self, config: dict): super().__init__() self.enabled = bool(config.get("enabled", True)) self.threshold_bits = float(config.get("threshold_bits", 2.0)) self.levels = dict(config.get("levels", { "low": {"loops": 1, "min_span": 8, "audit": 0.125}, "medium": {"loops": 2, "min_span": 4, "audit": 0.5}, "high": {"loops": 4, "min_span": 1, "audit": 1.0}, })) self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(inplace=True), nn.Linear(32, 3)) self._inv_log2 = 1.0 / math.log(2.0) def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor: log_probs = F.log_softmax(logits.to(torch.float32), dim=-1) return -(log_probs.exp() * log_probs).sum(dim=-1) * self._inv_log2 def get_level(self, entropy: torch.Tensor) -> str: if not self.enabled: return "medium" mean_h = float(entropy.mean().item()) if mean_h < self.threshold_bits * 0.5: return "low" if mean_h < self.threshold_bits: return "medium" return "high" def get_loop_count(self, logits: torch.Tensor) -> int: if not self.enabled: return self.levels.get("medium", {}).get("loops", 2) level = self.get_level(self.compute_entropy(logits)) return self.levels.get(level, self.levels["medium"])["loops"] def forward(self, logits: torch.Tensor): entropy = self.compute_entropy(logits) level = self.get_level(entropy) return level, self.levels.get(level, self.levels["medium"]) # --------------------------------------------------------------------------- # Debt ledger — additive bias (no-op when no obligations) # --------------------------------------------------------------------------- class DebtLedger(nn.Module): def __init__(self, config: dict): super().__init__() self.enabled = bool(config.get("enabled", True)) self.obligations = list(config.get("obligations", [])) self.max_outstanding = int(config.get("max_outstanding", 64)) self.pressure_weight = float(config.get("pressure_weight", 0.3)) self.active_debts: list = [] self.debt_bias_scale = nn.Parameter(torch.tensor(0.5)) self.debt_proj = nn.Linear(1, 1, bias=True) nn.init.ones_(self.debt_proj.weight) nn.init.zeros_(self.debt_proj.bias) def add_debt(self, debt_type: str) -> None: if len(self.active_debts) < self.max_outstanding: self.active_debts.append(debt_type) def resolve_debt(self, debt_type: str) -> None: try: self.active_debts.remove(debt_type) except ValueError: pass def get_pressure(self) -> float: return self.pressure_weight * len(self.active_debts) / max(self.max_outstanding, 1) def forward(self, logits: torch.Tensor) -> torch.Tensor: if not self.enabled or not self.active_debts: return logits pressure = self.get_pressure() if pressure <= 0.0: return logits boost = self.debt_bias_scale * pressure boosted = self.debt_proj(boost.view(1, 1, 1)) return logits + boosted * 0.01 # --------------------------------------------------------------------------- # BraidState — runtime scratch container # --------------------------------------------------------------------------- class BraidState: """Plain-Python structure holding the runtime working memory.""" __slots__ = ["continuous", "fast", "semantic_sketch", "entity_slots", "grammar_stack", "debt_ledger_slots"] def __init__(self, config: dict, device: str = "cpu"): D = int(config.get("continuous_hidden", [2560, "float32"])[0]) self.continuous = torch.zeros(1, D, dtype=torch.float32, device=device) self.fast = torch.zeros(1, D, dtype=torch.int8, device=device) bits = int(config.get("semantic_sketch", [8192, "uint64_x128"])[0]) self.semantic_sketch = torch.zeros(1, bits // 8, dtype=torch.uint8, device=device) et = config.get("entity_table", {}) self.entity_slots = torch.zeros( int(et.get("slots", 256)), int(et.get("slot_bits", 512)) // 8, dtype=torch.uint8, device=device) gs = config.get("grammar_stack", {}) self.grammar_stack = torch.zeros( int(gs.get("slots", 64)), int(gs.get("width_bits", 128)) // 8, dtype=torch.uint8, device=device) self.debt_ledger_slots = torch.zeros( int(config.get("debt_ledger_slots", 64)), dtype=torch.int32, device=device) def reset(self) -> None: self.continuous.zero_() self.fast.zero_() self.semantic_sketch.zero_() __all__ = [ "SpanBank", "STreeVerifier", "CertificateVerifier", "SpanInferenceEngine", "GrammarFST", "EntropyValve", "DebtLedger", "BraidState", ]