"""
Chimera 5.2 β€” inference-time helpers (CPU-first).
This module collects all the lightweight components that run *after* the
trunk produces hidden states:
* :class:`SpanBank` β€” vectorised semantic memory.
* :class:`STreeVerifier` β€” tiny scoring head.
* :class:`CertificateVerifier`β€” per-token risk projection.
* :class:`SpanInferenceEngine`β€” glue + risk gating.
* :class:`GrammarFST` β€” additive constraint penalty.
* :class:`EntropyValve` β€” adaptive loop-count router.
* :class:`DebtLedger` β€” bias logits to honour outstanding obligations.
* :class:`BraidState` β€” runtime scratch state.
Optimisations vs the previous draft:
* Grammar / Debt are *true* identity ops when their constraints are empty
(no tensors allocated, no projections run) β€” this matters because they
sit on the per-token logits path.
* Entropy is computed on the slice the model actually scores (not the
full 200K-vocab logits): the model passes us the last-token logits.
* Everything that does not depend on the input shape is allocated once.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# SpanBank
# ---------------------------------------------------------------------------
class SpanBank(nn.Module):
"""Cosine-similarity span memory used for retrieval-augmented inference."""
def __init__(self, max_entries: int = 524288, max_tokens: int = 64,
hidden_size: int = 2560, memory_mb: int = 384):
super().__init__()
self.max_entries = int(max_entries)
self.max_tokens = int(max_tokens)
self.hidden_size = int(hidden_size)
proj_dim = max(8, hidden_size // 4)
# Estimate entries the user can actually afford in RAM.
budget = int(memory_mb) * 1024 * 1024
per_entry = (proj_dim + hidden_size) * 4 + 8
actual = max(1, min(self.max_entries, budget // per_entry))
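        # With the defaults (hidden_size=2560 -> proj_dim=640, memory_mb=384):
        # per_entry = (640 + 2560) * 4 + 8 = 12,808 bytes, so the 384 MiB budget
        # yields roughly 31k usable entries -- well below max_entries, which only
        # binds when a much larger memory_mb is configured.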
self.proj_dim = proj_dim
self.register_buffer("bank_keys", torch.zeros(actual, proj_dim))
self.register_buffer("bank_values", torch.zeros(actual, hidden_size))
self.register_buffer("bank_lengths", torch.zeros(actual, dtype=torch.long))
self.register_buffer("bank_count", torch.zeros((), dtype=torch.long))
self.semantic_proj = nn.Linear(hidden_size, proj_dim, bias=False)
@property
def capacity(self) -> int:
return int(self.bank_keys.size(0))
def query_scores(self, hidden_state: torch.Tensor, top_k: int = 64
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
c = int(self.bank_count.item())
if c == 0:
return None, None
q = F.normalize(self.semantic_proj(hidden_state), dim=-1)
keys = F.normalize(self.bank_keys[:c], dim=-1)
sims = torch.matmul(q, keys.t())
k = min(top_k, c)
return torch.topk(sims, k, dim=-1)
def query(self, hidden_state: torch.Tensor, top_k: int = 64) -> torch.Tensor:
scores, indices = self.query_scores(hidden_state, top_k=top_k)
if scores is None:
return torch.zeros_like(hidden_state)
c = int(self.bank_count.item())
values = self.bank_values[:c][indices]
weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
return (values * weights).sum(dim=-2)
@torch.no_grad()
def add(self, keys: torch.Tensor, values: torch.Tensor) -> None:
"""Bulk insert; vectorised, falls back to overwriting once full."""
keys = keys.detach().reshape(-1, self.hidden_size)
values = values.detach().reshape(-1, self.hidden_size)
n = keys.size(0)
if n == 0:
return
cap = self.capacity
        start = int(self.bank_count.item())
        if start >= cap:
            start = 0  # bank full: fall back to overwriting the oldest slots
        end = min(start + n, cap)
        write = end - start
        if write > 0:
            self.bank_keys[start:end] = self.semantic_proj(keys[:write])
            self.bank_values[start:end] = values[:write]
            self.bank_lengths[start:end] = 1
            self.bank_count.fill_(max(int(self.bank_count.item()), start + write))
@torch.no_grad()
def add_span(self, hidden_state: torch.Tensor, length: int,
value: Optional[torch.Tensor] = None) -> None:
h = hidden_state.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
v = (value.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
if value is not None else h)
self.add(h, v)
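# Usage sketch (illustrative only; `h` stands for an [N, hidden_size] trunk state):
#     bank = SpanBank(hidden_size=2560, memory_mb=64)
#     bank.add_span(h, length=h.size(0))       # compress a span into one slot
#     ctx = bank.query(h[-1:], top_k=16)       # softmax-weighted value read-out
# Until the first add(), query() returns zeros, so callers can blend `ctx`
# into the hidden state unconditionally.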
# ---------------------------------------------------------------------------
# Verifiers
# ---------------------------------------------------------------------------
class STreeVerifier(nn.Module):
"""Tiny scoring head used by speculative-tree decoding."""
def __init__(self, tree_width: int = 4, tree_depth: int = 5,
hidden_size: int = 256):
super().__init__()
self.tree_width = int(tree_width)
self.tree_depth = int(tree_depth)
h_mid = max(8, hidden_size // 4)
self.score_net = nn.Sequential(
nn.Linear(hidden_size, h_mid),
nn.ReLU(inplace=True),
nn.Linear(h_mid, 1),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return torch.sigmoid(self.score_net(hidden_states)).squeeze(-1)
class CertificateVerifier(nn.Module):
"""Per-token certificate fields (semantic / grammar / entity / risk)."""
def __init__(self, hidden_size: int):
super().__init__()
self.semantic_proj = nn.Linear(hidden_size, 64, bias=False)
self.grammar_proj = nn.Linear(hidden_size, 16, bias=False)
self.entity_proj = nn.Linear(hidden_size, 32, bias=False)
self.boundary_proj = nn.Linear(hidden_size, 1, bias=False)
self.risk_proj = nn.Linear(hidden_size, 1, bias=False)
def forward(self, hidden_states: torch.Tensor) -> dict:
return {
"semantic": self.semantic_proj(hidden_states),
"grammar": self.grammar_proj(hidden_states),
"entity": self.entity_proj(hidden_states),
"boundary": self.boundary_proj(hidden_states),
"risk": torch.sigmoid(self.risk_proj(hidden_states)),
}
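# For an input of shape [..., hidden_size] the certificate fields come out as
# semantic [..., 64], grammar [..., 16], entity [..., 32], boundary [..., 1]
# (raw score) and risk [..., 1] squashed into (0, 1) by the sigmoid.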
class SpanInferenceEngine(nn.Module):
"""Risk-gated post-trunk hidden-state modulation."""
def __init__(self, hidden_size: int, config: dict):
super().__init__()
self.enabled = bool(config.get("enabled", True))
self.hidden_size = int(hidden_size)
self.span_bank = SpanBank(
max_entries=config.get("bank_entries", 524288),
max_tokens=config.get("bank_max_tokens", 64),
hidden_size=self.hidden_size,
memory_mb=config.get("bank_memory_mb", 384),
)
self.tree_verifier = STreeVerifier(
tree_width=config.get("tree_verify", {}).get("tree_width", 4),
tree_depth=config.get("tree_verify", {}).get("tree_depth", 5),
hidden_size=self.hidden_size,
)
self.certificate = CertificateVerifier(self.hidden_size)
self.scoring_weights = nn.Parameter(
torch.tensor(config.get("scoring_weights_fast", [1.0, 0.8, 0.5, 0.7, 0.35])))
self.fallback_threshold = float(config.get("fallback_below_acceptance", 0.5))
# Single fused gate from concatenated hidden + risk.
self.risk_gate = nn.Linear(self.hidden_size + 1, self.hidden_size, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return hidden_states
risk = torch.sigmoid(self.certificate.risk_proj(hidden_states))
gate_input = torch.cat([hidden_states, risk], dim=-1)
modulation = torch.sigmoid(self.risk_gate(gate_input))
return hidden_states * modulation
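# Note: forward() uses only the certificate's risk head; the sigmoid gate lies
# in (0, 1), so hidden states are attenuated, never amplified. `span_bank`,
# `tree_verifier` and `scoring_weights` are registered here for use elsewhere
# (presumably the outer decode loop) and are not applied inside forward().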
# ---------------------------------------------------------------------------
# Grammar FST -- additive penalty (no-op when no constraints)
# ---------------------------------------------------------------------------
class GrammarFST(nn.Module):
"""Soft-constraint penalty on next-token logits.
*Identity* when ``enabled`` is false **or** there are no constraints –
no entropy computation, no projection allocations.
"""
def __init__(self, config: dict):
super().__init__()
self.enabled = bool(config.get("enabled", True))
self.hard_constraints = list(config.get("hard_constraints", []))
self.soft_constraints = list(config.get("soft_constraints", []))
n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1
self._n_hard = len(self.hard_constraints)
self._n_soft = len(self.soft_constraints)
self._n_features = n_features
self._is_noop = (not self.enabled) or n_features <= 1
self.constraint_proj = nn.Linear(n_features, 1, bias=True)
nn.init.normal_(self.constraint_proj.weight, std=0.01)
nn.init.zeros_(self.constraint_proj.bias)
def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor:
if self._is_noop:
return logits
B, T, V = logits.shape
# Single log_softmax pass for entropy.
log_probs = F.log_softmax(logits, dim=-1)
entropy = -(log_probs.exp() * log_probs).sum(-1) # [B, T]
features = logits.new_zeros(B, T, self._n_features)
features[..., 0] = entropy
        if self._n_soft > 0 and T > 1:
            cos = F.cosine_similarity(logits[:, 1:], logits[:, :-1], dim=-1)
            # Feature layout: slot 0 is entropy, slots 1..n_hard are hard
            # constraints, so the first soft-constraint slot is 1 + n_hard.
            features[:, 1:, 1 + self._n_hard] = cos.clamp_min(0.0)
penalty = self.constraint_proj(features) # [B, T, 1]
return logits + penalty
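# Note: the penalty is a single scalar per position broadcast across the
# vocabulary, and `constraint_proj` is initialised near zero (std=0.01, zero
# bias), so a freshly constructed GrammarFST perturbs the logits only slightly
# until the constraint weights are trained.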
# ---------------------------------------------------------------------------
# Entropy valve
# ---------------------------------------------------------------------------
class EntropyValve(nn.Module):
"""Maps logits entropy β†’ adaptive loop count for the looped trunk."""
def __init__(self, config: dict):
super().__init__()
self.enabled = bool(config.get("enabled", True))
self.threshold_bits = float(config.get("threshold_bits", 2.0))
self.levels = dict(config.get("levels", {
"low": {"loops": 1, "min_span": 8, "audit": 0.125},
"medium": {"loops": 2, "min_span": 4, "audit": 0.5},
"high": {"loops": 4, "min_span": 1, "audit": 1.0},
}))
self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(inplace=True),
nn.Linear(32, 3))
self._inv_log2 = 1.0 / math.log(2.0)
def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
return -(log_probs.exp() * log_probs).sum(dim=-1) * self._inv_log2
def get_level(self, entropy: torch.Tensor) -> str:
if not self.enabled:
return "medium"
mean_h = float(entropy.mean().item())
if mean_h < self.threshold_bits * 0.5:
return "low"
if mean_h < self.threshold_bits:
return "medium"
return "high"
def get_loop_count(self, logits: torch.Tensor) -> int:
if not self.enabled:
return self.levels.get("medium", {}).get("loops", 2)
level = self.get_level(self.compute_entropy(logits))
return self.levels.get(level, self.levels["medium"])["loops"]
def forward(self, logits: torch.Tensor):
entropy = self.compute_entropy(logits)
level = self.get_level(entropy)
return level, self.levels.get(level, self.levels["medium"])
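# Worked example with the default threshold_bits=2.0: a mean next-token entropy
# below 1.0 bit routes to "low" (1 loop), below 2.0 bits to "medium" (2 loops),
# and anything higher to "high" (4 loops, full audit).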
# ---------------------------------------------------------------------------
# Debt ledger -- additive bias (no-op when no obligations)
# ---------------------------------------------------------------------------
class DebtLedger(nn.Module):
def __init__(self, config: dict):
super().__init__()
self.enabled = bool(config.get("enabled", True))
self.obligations = list(config.get("obligations", []))
self.max_outstanding = int(config.get("max_outstanding", 64))
self.pressure_weight = float(config.get("pressure_weight", 0.3))
self.active_debts: list = []
self.debt_bias_scale = nn.Parameter(torch.tensor(0.5))
self.debt_proj = nn.Linear(1, 1, bias=True)
nn.init.ones_(self.debt_proj.weight)
nn.init.zeros_(self.debt_proj.bias)
def add_debt(self, debt_type: str) -> None:
if len(self.active_debts) < self.max_outstanding:
self.active_debts.append(debt_type)
def resolve_debt(self, debt_type: str) -> None:
try:
self.active_debts.remove(debt_type)
except ValueError:
pass
def get_pressure(self) -> float:
return self.pressure_weight * len(self.active_debts) / max(self.max_outstanding, 1)
def forward(self, logits: torch.Tensor) -> torch.Tensor:
if not self.enabled or not self.active_debts:
return logits
pressure = self.get_pressure()
if pressure <= 0.0:
return logits
boost = self.debt_bias_scale * pressure
boosted = self.debt_proj(boost.view(1, 1, 1))
return logits + boosted * 0.01
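# Worked example with the defaults (pressure_weight=0.3, max_outstanding=64):
# 8 open debts give pressure = 0.3 * 8 / 64 = 0.0375; with freshly initialised
# parameters the bias is ~0.5 * 0.0375 * 0.01 = 1.9e-4, a gentle uniform nudge
# that the learnable scale and projection can amplify during training.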
# ---------------------------------------------------------------------------
# BraidState -- runtime scratch container
# ---------------------------------------------------------------------------
class BraidState:
"""Plain-Python structure holding the runtime working memory."""
__slots__ = ["continuous", "fast", "semantic_sketch", "entity_slots",
"grammar_stack", "debt_ledger_slots"]
def __init__(self, config: dict, device: str = "cpu"):
D = int(config.get("continuous_hidden", [2560, "float32"])[0])
self.continuous = torch.zeros(1, D, dtype=torch.float32, device=device)
self.fast = torch.zeros(1, D, dtype=torch.int8, device=device)
bits = int(config.get("semantic_sketch", [8192, "uint64_x128"])[0])
self.semantic_sketch = torch.zeros(1, bits // 8, dtype=torch.uint8, device=device)
et = config.get("entity_table", {})
self.entity_slots = torch.zeros(
int(et.get("slots", 256)), int(et.get("slot_bits", 512)) // 8,
dtype=torch.uint8, device=device)
gs = config.get("grammar_stack", {})
self.grammar_stack = torch.zeros(
int(gs.get("slots", 64)), int(gs.get("width_bits", 128)) // 8,
dtype=torch.uint8, device=device)
self.debt_ledger_slots = torch.zeros(
int(config.get("debt_ledger_slots", 64)), dtype=torch.int32, device=device)
def reset(self) -> None:
self.continuous.zero_()
self.fast.zero_()
self.semantic_sketch.zero_()
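# With the default config this allocates, per sequence: continuous [1, 2560]
# float32 (~10 KiB), fast [1, 2560] int8, an 8192-bit semantic sketch (1 KiB),
# a 256 x 512-bit entity table (16 KiB), a 64 x 128-bit grammar stack (1 KiB)
# and 64 int32 debt slots. reset() clears only the three dense states; the
# entity, grammar and debt tables persist across resets.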
__all__ = [
"SpanBank",
"STreeVerifier",
"CertificateVerifier",
"SpanInferenceEngine",
"GrammarFST",
"EntropyValve",
"DebtLedger",
"BraidState",
]
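
# Minimal CPU smoke test -- illustrative only: the tensor sizes, the "style"
# constraint and the "open_parenthesis" debt below are made-up placeholders.
# Running `python inference.py` exercises the logits-path modules on random
# tensors and prints the loop level chosen by the entropy valve.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_size, vocab = 64, 128
    logits = torch.randn(1, 4, vocab)

    bank = SpanBank(max_entries=256, hidden_size=hidden_size, memory_mb=1)
    span = torch.randn(6, hidden_size)
    bank.add_span(span, length=span.size(0))
    ctx = bank.query(span[-1:], top_k=4)

    valve = EntropyValve({"enabled": True})
    level, level_cfg = valve(logits[:, -1])

    fst = GrammarFST({"enabled": True, "soft_constraints": ["style"]})
    ledger = DebtLedger({"enabled": True})
    ledger.add_debt("open_parenthesis")
    biased = ledger(fst(logits))

    print("retrieved context:", tuple(ctx.shape))
    print("entropy level:", level, level_cfg)
    print("biased logits:", tuple(biased.shape))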