"""
Chimera 5.2 — inference-time helpers (CPU-first).

This module collects all the lightweight components that run *after* the
trunk produces hidden states:

* :class:`SpanBank`            — vectorised semantic memory.
* :class:`STreeVerifier`       — tiny scoring head.
* :class:`CertificateVerifier` — per-token risk projection.
* :class:`SpanInferenceEngine` — glue + risk gating.
* :class:`GrammarFST`          — additive constraint penalty.
* :class:`EntropyValve`        — adaptive loop-count router.
* :class:`DebtLedger`          — bias logits to honour outstanding obligations.
* :class:`BraidState`          — runtime scratch state.

Optimisations vs the previous draft:
  * Grammar / Debt are *true* identity ops when their constraints are empty
    (no tensors allocated, no projections run) — this matters because they
    sit on the per-token logits path.
  * Entropy is computed on the slice the model actually scores (not the
    full 200K-vocab logits): the model passes us the last-token logits.
  * Everything that does not depend on the input shape is allocated once.
"""
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
class SpanBank(nn.Module):
    """Cosine-similarity span memory used for retrieval-augmented inference.

    Entries live in pre-allocated buffers: ``bank_keys`` holds the
    ``semantic_proj`` projection of each inserted key, ``bank_values`` the
    raw value vector, ``bank_lengths`` the span length recorded for the
    entry and ``bank_count`` the number of filled rows.

    The capacity is ``min(max_entries, memory budget // bytes per entry)``
    and never less than one row. ``max_tokens`` is stored on the instance
    but not otherwise used by this class.
    """

    def __init__(self, max_entries: int = 524288, max_tokens: int = 64,
                 hidden_size: int = 2560, memory_mb: int = 384):
        super().__init__()
        self.max_entries = int(max_entries)
        self.max_tokens = int(max_tokens)
        self.hidden_size = int(hidden_size)
        proj_dim = max(8, self.hidden_size // 4)

        # Per-entry cost: 4 bytes per key/value element plus 8 extra bytes
        # (presumably the int64 length slot).
        budget = int(memory_mb) * 1024 * 1024
        per_entry = (proj_dim + self.hidden_size) * 4 + 8
        actual = max(1, min(self.max_entries, budget // per_entry))
        self.proj_dim = proj_dim
        self.register_buffer("bank_keys", torch.zeros(actual, proj_dim))
        self.register_buffer("bank_values", torch.zeros(actual, self.hidden_size))
        self.register_buffer("bank_lengths", torch.zeros(actual, dtype=torch.long))
        self.register_buffer("bank_count", torch.zeros((), dtype=torch.long))
        self.semantic_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)

    @property
    def capacity(self) -> int:
        """Number of rows allocated in the bank buffers."""
        return int(self.bank_keys.size(0))

    def query_scores(self, hidden_state: torch.Tensor, top_k: int = 64
                     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the top-``k`` cosine similarities and their row indices.

        Returns ``(None, None)`` while the bank is empty; otherwise the
        ``(values, indices)`` result of ``torch.topk`` over the filled rows,
        with ``k`` capped at the number of filled rows.
        """
        c = int(self.bank_count.item())
        if c == 0:
            return None, None
        q = F.normalize(self.semantic_proj(hidden_state), dim=-1)
        keys = F.normalize(self.bank_keys[:c], dim=-1)
        sims = torch.matmul(q, keys.t())
        k = min(top_k, c)
        return torch.topk(sims, k, dim=-1)

    def query(self, hidden_state: torch.Tensor, top_k: int = 64) -> torch.Tensor:
        """Return a softmax-weighted mix of the ``top_k`` most similar values.

        Returns zeros shaped like ``hidden_state`` while the bank is empty.
        """
        scores, indices = self.query_scores(hidden_state, top_k=top_k)
        if scores is None:
            return torch.zeros_like(hidden_state)
        # The indices are offsets into the filled prefix, which starts at
        # row 0, so the full buffer can be indexed directly without reading
        # ``bank_count`` a second time.
        values = self.bank_values[indices]
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
        return (values * weights).sum(dim=-2)

    @torch.no_grad()
    def _append(self, keys: torch.Tensor, values: torch.Tensor,
                length: int) -> None:
        """Write rows after the filled prefix, tagging each with ``length``."""
        keys = keys.detach().reshape(-1, self.hidden_size)
        values = values.detach().reshape(-1, self.hidden_size)
        n = keys.size(0)
        if n == 0:
            return
        start = int(self.bank_count.item())
        end = min(start + n, self.capacity)
        write = end - start
        if write > 0:
            self.bank_keys[start:end] = self.semantic_proj(keys[:write])
            self.bank_values[start:end] = values[:write]
            self.bank_lengths[start:end] = length
            self.bank_count.add_(write)

    @torch.no_grad()
    def add(self, keys: torch.Tensor, values: torch.Tensor) -> None:
        """Bulk insert, vectorised.

        ``keys`` and ``values`` are flattened to rows of ``hidden_size``.
        Each inserted row is recorded with length 1. Rows that do not fit in
        the remaining capacity are dropped; existing entries are never
        overwritten.
        """
        self._append(keys, values, 1)

    @torch.no_grad()
    def add_span(self, hidden_state: torch.Tensor, length: int,
                 value: Optional[torch.Tensor] = None) -> None:
        """Insert one entry summarising a span.

        The key is the mean of ``hidden_state`` over its rows; the value is
        the mean of ``value`` when given, otherwise the same mean key vector.
        ``length`` is stored in ``bank_lengths`` for the new entry.
        """
        h = hidden_state.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
        v = (value.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
             if value is not None else h)
        self._append(h, v, int(length))
|
|
|
|
| |
| |
| |
|
|
class STreeVerifier(nn.Module):
    """Small two-layer scoring head used by speculative-tree decoding.

    ``tree_width`` and ``tree_depth`` are stored on the instance; the
    forward pass itself only depends on ``hidden_size``.
    """

    def __init__(self, tree_width: int = 4, tree_depth: int = 5,
                 hidden_size: int = 256):
        super().__init__()
        self.tree_width, self.tree_depth = int(tree_width), int(tree_depth)
        # Bottleneck is a quarter of the input width, but never below 8.
        bottleneck = max(8, hidden_size // 4)
        layers = [
            nn.Linear(hidden_size, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, 1),
        ]
        self.score_net = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a sigmoid score per position, with the last axis squeezed."""
        raw_scores = self.score_net(hidden_states)
        return raw_scores.sigmoid().squeeze(-1)
|
|
|
|
class CertificateVerifier(nn.Module):
    """Bias-free linear heads producing per-token certificate fields.

    The heads are ``semantic`` (64), ``grammar`` (16), ``entity`` (32),
    ``boundary`` (1) and ``risk`` (1); only ``risk`` is passed through a
    sigmoid.
    """

    def __init__(self, hidden_size: int):
        super().__init__()

        def head(width: int) -> nn.Linear:
            return nn.Linear(hidden_size, width, bias=False)

        self.semantic_proj = head(64)
        self.grammar_proj = head(16)
        self.entity_proj = head(32)
        self.boundary_proj = head(1)
        self.risk_proj = head(1)

    def forward(self, hidden_states: torch.Tensor) -> dict:
        """Return a dict with one tensor per certificate field."""
        fields = {
            name: getattr(self, f"{name}_proj")(hidden_states)
            for name in ("semantic", "grammar", "entity", "boundary")
        }
        # Risk is the only field squashed into (0, 1).
        fields["risk"] = self.risk_proj(hidden_states).sigmoid()
        return fields
|
|
|
|
class SpanInferenceEngine(nn.Module):
    """Risk-gated post-trunk hidden-state modulation.

    The constructor builds a :class:`SpanBank`, an :class:`STreeVerifier`
    and a :class:`CertificateVerifier` from ``config``. ``forward`` only
    uses the certificate's risk head and ``risk_gate``; the span bank, tree
    verifier, ``scoring_weights`` and ``fallback_threshold`` are held for
    callers and are not read by ``forward``.

    Recognised ``config`` keys: ``enabled``, ``bank_entries``,
    ``bank_max_tokens``, ``bank_memory_mb``, ``tree_verify`` (with
    ``tree_width`` / ``tree_depth``), ``scoring_weights_fast`` and
    ``fallback_below_acceptance``.
    """

    def __init__(self, hidden_size: int, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.hidden_size = int(hidden_size)
        self.span_bank = SpanBank(
            max_entries=config.get("bank_entries", 524288),
            max_tokens=config.get("bank_max_tokens", 64),
            hidden_size=self.hidden_size,
            memory_mb=config.get("bank_memory_mb", 384),
        )
        # Look the sub-config up once; tolerate an explicit ``None``.
        tree_cfg = config.get("tree_verify") or {}
        self.tree_verifier = STreeVerifier(
            tree_width=tree_cfg.get("tree_width", 4),
            tree_depth=tree_cfg.get("tree_depth", 5),
            hidden_size=self.hidden_size,
        )
        self.certificate = CertificateVerifier(self.hidden_size)
        # Force a floating dtype: a config list of plain integers would
        # otherwise produce an int64 tensor, which cannot be a Parameter
        # that requires grad.
        self.scoring_weights = nn.Parameter(torch.tensor(
            config.get("scoring_weights_fast", [1.0, 0.8, 0.5, 0.7, 0.35]),
            dtype=torch.get_default_dtype()))
        self.fallback_threshold = float(config.get("fallback_below_acceptance", 0.5))

        # The gate sees the hidden state concatenated with the scalar risk.
        self.risk_gate = nn.Linear(self.hidden_size + 1, self.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Scale ``hidden_states`` by a sigmoid gate conditioned on risk.

        Returns the input object unchanged when the engine is disabled.
        """
        if not self.enabled:
            return hidden_states
        risk = torch.sigmoid(self.certificate.risk_proj(hidden_states))
        gate_input = torch.cat([hidden_states, risk], dim=-1)
        modulation = torch.sigmoid(self.risk_gate(gate_input))
        return hidden_states * modulation
|
|
|
|
| |
| |
| |
|
|
class GrammarFST(nn.Module):
    """Soft-constraint penalty on next-token logits.

    ``forward`` is an *identity* when ``enabled`` is false **or** there are
    no constraints — no entropy computation and no feature tensors are
    built on that path.

    Feature layout (last axis of the feature tensor, width
    ``1 + n_hard + n_soft``):

    * column ``0`` — entropy of the softmax over the last logits axis;
    * columns ``1 .. n_hard`` — hard-constraint slots (left at zero here);
    * columns ``1 + n_hard ..`` — soft-constraint slots; the first one
      carries the clamped cosine similarity between consecutive positions.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.hard_constraints = list(config.get("hard_constraints", []))
        self.soft_constraints = list(config.get("soft_constraints", []))
        n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1
        self._n_hard = len(self.hard_constraints)
        self._n_soft = len(self.soft_constraints)
        self._n_features = n_features
        # Only the entropy column exists when no constraints are configured.
        self._is_noop = (not self.enabled) or n_features <= 1
        self.constraint_proj = nn.Linear(n_features, 1, bias=True)
        nn.init.normal_(self.constraint_proj.weight, std=0.01)
        nn.init.zeros_(self.constraint_proj.bias)

    def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor:
        """Add a learned per-position scalar penalty to ``logits``.

        ``logits`` has the vocabulary on its last axis. The
        consecutive-position cosine feature is only filled when there is a
        sequence axis (rank >= 3) with more than one position; lower-rank
        input gets the entropy feature alone. ``state`` is accepted but not
        used by this method.
        """
        if self._is_noop:
            return logits

        log_probs = F.log_softmax(logits, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(-1)
        features = logits.new_zeros((*logits.shape[:-1], self._n_features))
        features[..., 0] = entropy
        if self._n_soft > 0 and logits.dim() >= 3 and logits.size(-2) > 1:
            cos = F.cosine_similarity(logits[..., 1:, :], logits[..., :-1, :], dim=-1)
            # First soft slot sits after the entropy column and the hard
            # slots; writing at ``_n_hard`` would clobber the entropy
            # column whenever there are no hard constraints.
            features[..., 1:, 1 + self._n_hard] = cos.clamp_min(0.0)
        penalty = self.constraint_proj(features)
        return logits + penalty
|
|
|
|
| |
| |
| |
|
|
class EntropyValve(nn.Module):
    """Maps logits entropy → adaptive loop count for the looped trunk.

    The mean entropy (in bits) is bucketed against ``threshold_bits``:
    below half the threshold is ``"low"``, below the threshold is
    ``"medium"``, anything else is ``"high"``. Each level name indexes
    ``levels``, a dict of per-level settings (``loops``, ``min_span``,
    ``audit`` in the defaults).

    ``router`` is constructed here but not called by any method of this
    class.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.threshold_bits = float(config.get("threshold_bits", 2.0))
        self.levels = dict(config.get("levels", {
            "low": {"loops": 1, "min_span": 8, "audit": 0.125},
            "medium": {"loops": 2, "min_span": 4, "audit": 0.5},
            "high": {"loops": 4, "min_span": 1, "audit": 1.0},
        }))
        self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(inplace=True),
                                    nn.Linear(32, 3))
        # log_softmax works in nats; multiply by 1/ln(2) to report bits.
        self._inv_log2 = 1.0 / math.log(2.0)

    def _level_config(self, level: str) -> dict:
        """Return the settings for ``level``, falling back to ``"medium"``.

        The fallback is looked up lazily so that a custom ``levels`` table
        without a ``"medium"`` entry does not raise when ``level`` exists.
        Returns an empty dict when neither entry is present.
        """
        cfg = self.levels.get(level)
        if cfg is None:
            cfg = self.levels.get("medium", {})
        return cfg

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the softmax entropy over the last axis, in bits (float32)."""
        log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1) * self._inv_log2

    def get_level(self, entropy: torch.Tensor) -> str:
        """Bucket the mean of ``entropy``; always ``"medium"`` when disabled."""
        if not self.enabled:
            return "medium"
        mean_h = float(entropy.mean().item())
        if mean_h < self.threshold_bits * 0.5:
            return "low"
        if mean_h < self.threshold_bits:
            return "medium"
        return "high"

    def get_loop_count(self, logits: torch.Tensor) -> int:
        """Return the ``loops`` setting for the level selected by ``logits``.

        Falls back to 2 when the selected settings carry no ``loops`` key,
        matching the disabled path.
        """
        if not self.enabled:
            return self.levels.get("medium", {}).get("loops", 2)
        level = self.get_level(self.compute_entropy(logits))
        return self._level_config(level).get("loops", 2)

    def forward(self, logits: torch.Tensor):
        """Return ``(level_name, level_settings)`` for ``logits``."""
        if self.enabled:
            level = self.get_level(self.compute_entropy(logits))
        else:
            # get_level ignores the entropy when disabled, so skip the
            # softmax entirely.
            level = "medium"
        return level, self._level_config(level)
|
|
|
|
| |
| |
| |
|
|
class DebtLedger(nn.Module):
    """Tracks outstanding debts and adds a pressure-scaled bias to logits.

    ``active_debts`` is a plain Python list of debt-type labels capped at
    ``max_outstanding``. The pressure is
    ``pressure_weight * len(active_debts) / max_outstanding``. ``forward``
    returns the input object unchanged when the ledger is disabled, empty,
    or the pressure is not positive. ``obligations`` is stored from the
    config but not consulted by any method of this class.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.obligations = list(config.get("obligations", []))
        self.max_outstanding = int(config.get("max_outstanding", 64))
        self.pressure_weight = float(config.get("pressure_weight", 0.3))
        self.active_debts: list = []
        self.debt_bias_scale = nn.Parameter(torch.tensor(0.5))
        # Starts as the identity map (weight 1, bias 0).
        self.debt_proj = nn.Linear(1, 1, bias=True)
        nn.init.ones_(self.debt_proj.weight)
        nn.init.zeros_(self.debt_proj.bias)

    def add_debt(self, debt_type: str) -> None:
        """Record a debt; silently ignored once ``max_outstanding`` is reached."""
        if len(self.active_debts) < self.max_outstanding:
            self.active_debts.append(debt_type)

    def resolve_debt(self, debt_type: str) -> None:
        """Remove one occurrence of ``debt_type``; unknown debts are ignored."""
        try:
            self.active_debts.remove(debt_type)
        except ValueError:
            # Resolving a debt that was never recorded is deliberately a no-op.
            pass

    def get_pressure(self) -> float:
        """Return ``pressure_weight`` scaled by the fraction of slots in use."""
        return self.pressure_weight * len(self.active_debts) / max(self.max_outstanding, 1)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Add the debt bias to ``logits``, preserving their shape.

        NOTE(review): the bias is a single scalar broadcast over every
        element, so all entries shift by the same amount; confirm whether a
        per-token bias was intended.
        """
        if not self.enabled or not self.active_debts:
            return logits
        pressure = self.get_pressure()
        if pressure <= 0.0:
            return logits
        boost = self.debt_bias_scale * pressure
        # Keep the bias one-dimensional so it broadcasts against logits of
        # any rank; a (1, 1, 1) bias would promote 2-D logits to 3-D.
        boosted = self.debt_proj(boost.reshape(1))
        return logits + boosted * 0.01
|
|
|
|
| |
| |
| |
|
|
class BraidState:
    """Plain-Python container for the runtime working-memory tensors.

    All tensors are zero-initialised on ``device``. Sizes come from
    ``config``; bit widths are converted to bytes with integer division
    by 8.
    """

    __slots__ = ["continuous", "fast", "semantic_sketch", "entity_slots",
                 "grammar_stack", "debt_ledger_slots"]

    def __init__(self, config: dict, device: str = "cpu"):
        def zeros(*shape: int, dtype: torch.dtype) -> torch.Tensor:
            return torch.zeros(*shape, dtype=dtype, device=device)

        # Only the first element (the width) of the config pair is read.
        width = int(config.get("continuous_hidden", [2560, "float32"])[0])
        self.continuous = zeros(1, width, dtype=torch.float32)
        self.fast = zeros(1, width, dtype=torch.int8)

        sketch_bits = int(config.get("semantic_sketch", [8192, "uint64_x128"])[0])
        self.semantic_sketch = zeros(1, sketch_bits // 8, dtype=torch.uint8)

        entity_cfg = config.get("entity_table", {})
        entity_rows = int(entity_cfg.get("slots", 256))
        entity_bytes = int(entity_cfg.get("slot_bits", 512)) // 8
        self.entity_slots = zeros(entity_rows, entity_bytes, dtype=torch.uint8)

        grammar_cfg = config.get("grammar_stack", {})
        grammar_rows = int(grammar_cfg.get("slots", 64))
        grammar_bytes = int(grammar_cfg.get("width_bits", 128)) // 8
        self.grammar_stack = zeros(grammar_rows, grammar_bytes, dtype=torch.uint8)

        ledger_slots = int(config.get("debt_ledger_slots", 64))
        self.debt_ledger_slots = zeros(ledger_slots, dtype=torch.int32)

    def reset(self) -> None:
        """Zero ``continuous``, ``fast`` and ``semantic_sketch`` in place.

        The entity, grammar and debt-ledger tensors are left untouched.
        """
        for name in ("continuous", "fast", "semantic_sketch"):
            getattr(self, name).zero_()
|
|
|
|
# Public API of this module, listed in definition order.
__all__ = [
    "SpanBank", "STreeVerifier", "CertificateVerifier", "SpanInferenceEngine",
    "GrammarFST", "EntropyValve", "DebtLedger", "BraidState",
]
|
|