chomera / chimera /evolution.py
Lgr54HFi's picture
perf: eliminate .item() graph breaks in evolution.py — use tensor comparisons for torch.compile compat"
fc678ef verified
"""
Chimera 5.2 — Functional Self-Evolution Engine (CPU-first, optimized).
All components are now WIRED into the training/inference loop:
* InPlaceTTT: applied to target MLP layers during forward pass
* SemanticMemory: reads at every layer, writes on surprise threshold
* EpisodicCaseMemory: retrieves similar past cases, stores on outcome
* MetaGuidelineBank: stores contrastive-eval-failed guidelines
* SelfFeedback: triggers refinement when confidence < threshold
* LoopDepthClassifier: predicts optimal loop depth from hidden state
Optimizations:
* Vectorised bit ops (no Python loops)
* Lazy sparse updates (only top-K% weights touched per step)
* Gradient-free memory operations (no backward through HDC)
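* Query cache for semantic lookups (``_query_cache``; cleared on writes and
  session reset, not populated anywhere in this module yet)
* torch.compile: .item() is still used for memory-count checks and scalar
  metrics; these cause graph breaks under torch.compile (fullgraph=False)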
"""
from __future__ import annotations

import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
_BIT_SHIFTS = torch.arange(8, dtype=torch.uint8)
def _unpack_bits(x: torch.Tensor) -> torch.Tensor:
"""Unpack uint8 ``[..., D]`` into ``[..., D, 8]`` of {0,1} fp32."""
shifts = _BIT_SHIFTS.to(x.device)
return ((x.unsqueeze(-1) >> shifts) & 1).to(torch.float32)
def _pack_bits(b: torch.Tensor) -> torch.Tensor:
"""Inverse of :func:`_unpack_bits`."""
shifts = _BIT_SHIFTS.to(b.device).to(torch.uint8)
return (b.to(torch.uint8) << shifts).sum(dim=-1).to(torch.uint8)
# ---------------------------------------------------------------------------
# SemanticMemory (HDC) — Hyperdimensional Computing
# ---------------------------------------------------------------------------
class SemanticMemory(nn.Module):
    """Binary hypervector memory searched by Hamming distance.

    Hypervectors are stored packed (8 bits per uint8, LSB first) in ``memory``
    rows of ``vector_bits // 8`` bytes. At most ``min(capacity, 50_000)`` rows
    are allocated. ``count`` is the number of filled rows and
    ``access_counts`` counts how often each row was returned by :meth:`query`.
    When ``pool_size_fixed`` is set and the pool is full, :meth:`store`
    overwrites the least-accessed row.
    """
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.vector_bits = int(config.get("vector_bits", 8192))
        self.capacity = int(config.get("capacity", 200_000))
        self.pool_fixed = bool(config.get("pool_size_fixed", True))
        self.lsh_tables = int(config.get("lsh_tables", 64))
        self.lsh_bits = int(config.get("lsh_bits_per_table", 14))
        self.write_threshold = float(config.get("write_surprise_threshold", 2.0))
        # The allocated buffer is capped at 50k rows whatever `capacity` says.
        actual_cap = max(1, min(self.capacity, 50_000))
        n_bytes = self.vector_bits // 8
        self.register_buffer("memory", torch.zeros(actual_cap, n_bytes, dtype=torch.uint8))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.register_buffer("access_counts", torch.zeros(actual_cap, dtype=torch.long))
        # Random projection used to hash hidden states into hypervector bits;
        # weight shape is [lsh_tables * lsh_bits, n_bytes].
        self.lsh_proj = nn.Linear(n_bytes, self.lsh_tables * self.lsh_bits, bias=False)
        nn.init.normal_(self.lsh_proj.weight, std=0.01)
        # Query cache, cleared by store().
        # NOTE(review): nothing in this class populates it.
        self._query_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    @staticmethod
    def xor_bind(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Bind two packed hypervectors with bitwise XOR."""
        return torch.bitwise_xor(a, b)
    @staticmethod
    def xor_unbind(bound: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`xor_bind` (XOR is its own inverse)."""
        return torch.bitwise_xor(bound, key)
    @staticmethod
    def majority_bundle(hvs: torch.Tensor) -> torch.Tensor:
        """Vectorised majority rule over a ``[N, n_bytes]`` batch of hypervectors.

        A bit is set in the result when it is set in more than half of the
        N inputs. Empty input yields an all-zero vector.
        """
        if hvs.numel() == 0:
            return torch.zeros(hvs.shape[-1] if hvs.ndim else 0, dtype=torch.uint8,
                               device=hvs.device)
        bits = _unpack_bits(hvs)
        majority = (bits.sum(dim=0) > (hvs.size(0) / 2.0)).to(torch.uint8)
        return _pack_bits(majority)
    @staticmethod
    def hamming_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Batched Hamming distance over (broadcastable) uint8 byte tensors.

        Returns float32 counts of differing bits, reduced over the last
        (byte) axis.
        """
        xor = torch.bitwise_xor(a, b)
        bits = _unpack_bits(xor)
        return bits.sum(dim=(-1, -2))
    def project_to_hypervector(self, x: torch.Tensor) -> torch.Tensor:
        """Project continuous hidden states to packed binary hypervectors.

        ``x`` is ``[B, H]`` or ``[B, T, H]`` (only the last position is used).
        Each bit is the sign of a projection through ``lsh_proj``. The bits
        are packed LSB-first into ``[B, vector_bits // 8]`` uint8 bytes. Only
        the first ``min(lsh_tables * lsh_bits, vector_bits)`` bits can be
        non-zero; the rest are always 0.
        """
        if x.dim() == 3:
            x = x[:, -1, :]
        weight = self.lsh_proj.weight  # [lsh_tables * lsh_bits, n_bytes]
        n_bytes = self.memory.size(1)
        target_dim = n_bytes * 8
        # The projection accepts at most n_bytes input features. Extra hidden
        # features are ignored instead of raising a shape error.
        in_dim = min(x.size(-1), weight.size(1))
        with torch.no_grad():  # output is uint8; no gradient can flow anyway
            proj = F.linear(x[..., :in_dim].to(weight.dtype),
                            weight[:target_dim, :in_dim])
            # Zero-pad to a whole number of bytes, then pack 8 bits per byte.
            bits = torch.zeros(*proj.shape[:-1], target_dim,
                               dtype=torch.uint8, device=x.device)
            bits[..., :proj.size(-1)] = (proj > 0).to(torch.uint8)
            return _pack_bits(bits.view(*bits.shape[:-1], n_bytes, 8))
    def _count_int(self) -> int:
        """Get count as Python int. Use ONLY outside torch.compile traced paths."""
        return int(self.count.item())
    def query(self, query_vec: torch.Tensor, top_k: int = 16
              ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Find the ``top_k`` stored rows nearest to ``query_vec`` by Hamming distance.

        Returns ``(distances, indices)``, or ``(None, None)`` when memory is
        empty. Every returned index increments ``access_counts``; duplicates
        across the batch are counted once per hit.
        """
        c = self._count_int()
        if c == 0:
            return None, None
        dists = self.hamming_distance(query_vec.unsqueeze(-2),
                                      self.memory[:c].unsqueeze(0))
        k = min(top_k, c)
        values, indices = dists.topk(k, dim=-1, largest=False)
        with torch.no_grad():
            # index_add_ accumulates repeated indices, whereas
            # `counts[idx] += 1` would count a row once per call.
            flat = indices.reshape(-1)
            self.access_counts.index_add_(0, flat, torch.ones_like(flat))
        return (values, indices)
    @torch.no_grad()
    def store(self, vec: torch.Tensor, surprise_magnitude: float = 0.0) -> bool:
        """Store ``vec`` if ``surprise_magnitude`` reaches the write threshold.

        The first ``vector_bits // 8`` elements of the flattened ``vec`` are
        written as bytes. Returns True only if a row was actually written.
        A full pool with ``pool_size_fixed=False`` rejects the write.
        """
        if surprise_magnitude < self.write_threshold:
            return False
        vec_flat = vec.detach().reshape(-1)[:self.memory.size(1)].to(torch.uint8)
        cap = self.memory.size(0)
        c = self._count_int()
        if self.pool_fixed and c >= cap:
            # Evict the least-accessed row.
            min_idx = int(self.access_counts[:cap].argmin().item())
            self.memory[min_idx] = vec_flat
            self.access_counts[min_idx] = 0
        elif c < cap:
            self.memory[c] = vec_flat
            self.count.add_(1)
        else:
            # Full pool and eviction disabled: nothing was written.
            return False
        self._query_cache.clear()
        return True
    @torch.no_grad()
    def read_and_modulate(self, hidden: torch.Tensor) -> torch.Tensor:
        """Read the nearest memory row and return a modulation for ``hidden``.

        The result has the shape and dtype of ``hidden``. For ``[B, T, H]``
        input the per-sample modulation (computed from the last position) is
        broadcast over all T positions. Returns zeros when memory is empty.
        """
        c = self._count_int()
        if c == 0:
            return torch.zeros_like(hidden)
        hv = self.project_to_hypervector(hidden)
        dists, indices = self.query(hv, top_k=8)
        if dists is None:
            return torch.zeros_like(hidden)
        retrieved = self.memory[indices[:, 0]].float()  # [B, n_bytes]
        weight = self.lsh_proj.weight  # [lsh_tables * lsh_bits, n_bytes]
        hidden_size = hidden.size(-1)
        # Map the stored bytes back to hidden space through the LSH weight.
        # Both dims are clipped to what the weight provides, so configs with
        # lsh_tables * lsh_bits < n_bytes or hidden_size > n_bytes no longer
        # raise shape errors. Missing output features are zero.
        in_dim = min(retrieved.size(-1), weight.size(0))
        out_dim = min(hidden_size, weight.size(1))
        proj_back = F.linear(retrieved[:, :in_dim], weight[:in_dim, :out_dim].t())
        if out_dim < hidden_size:
            proj_back = F.pad(proj_back, (0, hidden_size - out_dim))
        similarity = 1.0 - (dists[:, 0].float() / self.vector_bits).clamp(0, 1)
        modulation = (proj_back * similarity.unsqueeze(-1)).to(hidden.dtype)
        if hidden.dim() == 3:
            return modulation.unsqueeze(1).expand_as(hidden).contiguous()
        return modulation.view_as(hidden)
# ---------------------------------------------------------------------------
# In-place test-time training (TTT)
# ---------------------------------------------------------------------------
class InPlaceTTT(nn.Module):
    """Single-step in-place TTT update on MLP down-projection.

    :meth:`compute_update` builds a weight delta from a causally convolved
    copy of the layer input and a target signal ``z``. :meth:`apply_update`
    folds it into a momentum buffer and returns
    ``w_down + inner_lr * momentum``.
    """
    def __init__(self, config: dict, hidden_size: int):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.target_layers = list(config.get("target_layers", [13, 23]))
        self.inner_lr = float(config.get("inner_lr", 3e-4))
        self.momentum = float(config.get("momentum", 0.9))
        self.chunk_size = int(config.get("chunk_size", 1024))
        self.reset_decay = float(config.get("reset_decay", 0.95))
        self.delta_clip = float(config.get("delta_clip", 1e-5))
        # Values < 1 are treated as 1. A 0 would make the modulo in forward()
        # raise ZeroDivisionError.
        self.apply_every_n = max(1, int(config.get("apply_every_n", 1)))
        self.conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=5,
                                padding=4, groups=hidden_size, bias=False)
        nn.init.zeros_(self.conv1d.weight)
        self.w_target = nn.Parameter(torch.eye(hidden_size) * 0.01)
        self.register_buffer("momentum_buffer", torch.zeros(hidden_size, hidden_size))
        self.step_count = 0
    def compute_update(self, x_raw: torch.Tensor, z: torch.Tensor,
                       w_down: torch.Tensor) -> torch.Tensor:
        """Compute the clipped TTT weight delta.

        Args:
            x_raw: ``[B, T, H]`` input of the target layer.
            z: ``[B, T, D]`` target signal.
            w_down: only used to shape the zero delta when TTT is disabled.

        Returns:
            ``[H, D]`` delta summed over batch and time, rescaled so that its
            Frobenius norm does not exceed ``delta_clip``.
        """
        if not self.enabled:
            return torch.zeros_like(w_down)
        T = x_raw.shape[1]
        # Depthwise conv padded by 4 on both sides. Keeping the first T
        # outputs means position t only sees x[t-4 .. t] (causal).
        x_shifted = self.conv1d(x_raw.transpose(1, 2))[:, :, :T].transpose(1, 2)
        v_hat = x_shifted @ self.w_target
        # Reduce over batch *and* time so the delta is a single 2-D weight
        # update. A per-sample [B, H, D] delta cannot be accumulated into the
        # 2-D momentum buffer.
        delta = (v_hat.reshape(-1, v_hat.size(-1)).transpose(0, 1)
                 @ z.reshape(-1, z.size(-1)))
        # Tensor-only norm clip (no .item() / data-dependent Python branch).
        # The scale is delta_clip / norm when norm > delta_clip, otherwise 1.
        norm = delta.norm()
        tiny = torch.finfo(norm.dtype).tiny
        scale = (self.delta_clip / norm.clamp_min(tiny)).clamp(max=1.0)
        return delta * scale
    def apply_update(self, w_down: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Accumulate ``delta`` into the momentum buffer and return the updated weight."""
        updated = self.momentum_buffer * self.momentum + delta
        # Persist a detached copy. Gradients still flow through `updated` for
        # this step, but the buffer does not chain autograd history across
        # steps.
        with torch.no_grad():
            self.momentum_buffer.copy_(updated)
        return w_down + self.inner_lr * updated
    def forward(self, x_raw: torch.Tensor, z: torch.Tensor,
                w_down: torch.Tensor) -> torch.Tensor:
        """Return ``w_down`` updated on every ``apply_every_n``-th call, else unchanged."""
        if not self.enabled:
            return w_down
        self.step_count += 1
        if self.step_count % self.apply_every_n != 0:
            return w_down
        delta = self.compute_update(x_raw, z, w_down)
        return self.apply_update(w_down, delta)
    @torch.no_grad()
    def reset_momentum(self):
        """Decay the momentum buffer by ``reset_decay`` and restart the step counter."""
        self.momentum_buffer.mul_(self.reset_decay)
        self.step_count = 0
# ---------------------------------------------------------------------------
# Episodic case memory
# ---------------------------------------------------------------------------
class EpisodicCaseMemory(nn.Module):
    """Case-based reasoning memory for interaction patterns.

    Holds up to ``max_cases`` float vectors of length ``case_dim`` plus a
    per-case weight. Once full, new cases overwrite the oldest slot
    (round-robin). Retrieval ranks cases by weight-scaled cosine similarity.
    """
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.max_cases = int(config.get("max_cases", 4096))
        self.case_bytes = int(config.get("case_bytes", 2048))
        case_dim = max(8, min(self.case_bytes, 512))
        self.case_dim = case_dim
        self.register_buffer("cases", torch.zeros(self.max_cases, case_dim))
        self.register_buffer("weights", torch.ones(self.max_cases))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.query_proj = nn.Linear(case_dim, case_dim, bias=False)
        self.ema_decay = 0.99
        self.softmax_temp = 1.0
        # Next slot to overwrite once the memory is full (FIFO eviction). It is
        # not part of the state_dict, so after a checkpoint load eviction
        # restarts at slot 0.
        self._evict_ptr = 0
    def _count_int(self) -> int:
        return int(self.count.item())
    def _fit(self, vec: torch.Tensor) -> torch.Tensor:
        """Truncate or zero-pad the last dimension of ``vec`` to ``case_dim``."""
        width = vec.size(-1)
        if width > self.case_dim:
            return vec[..., :self.case_dim]
        if width < self.case_dim:
            return F.pad(vec, (0, self.case_dim - width))
        return vec
    def retrieve(self, query: torch.Tensor, top_k: int = 5):
        """Return ``(cases, scores)`` for the ``top_k`` best-matching stored cases.

        ``query``'s last dimension is fitted to ``case_dim``. ``cases`` has
        shape ``[N, k, case_dim]`` and ``scores`` holds the softmax
        probabilities ``[N, k]``, where N is the number of flattened query
        rows. Returns ``(None, None)`` when empty.
        """
        c = self._count_int()
        if c == 0:
            return None, None
        q = self.query_proj(self._fit(query))
        q_flat = F.normalize(q.reshape(-1, q.shape[-1]), dim=-1)
        c_norm = F.normalize(self.cases[:c], dim=-1)
        sims = torch.matmul(q_flat, c_norm.t()) * self.weights[:c].unsqueeze(0)
        probs = F.softmax(sims / self.softmax_temp, dim=-1)
        k = min(top_k, c)
        scores, indices = probs.topk(k, dim=-1)
        return self.cases[indices], scores
    @torch.no_grad()
    def store(self, case_vec: torch.Tensor, outcome: float = 1.0) -> None:
        """Store the flattened ``case_vec`` (fitted to ``case_dim``) with weight ``outcome``."""
        c = self._count_int()
        full = c >= self.max_cases
        # Append while there is room; afterwards overwrite the oldest slot.
        # `count % max_cases` would stay 0 forever once count saturates.
        idx = self._evict_ptr if full else c
        self.cases[idx] = self._fit(case_vec.detach().reshape(-1))
        self.weights[idx] = float(outcome)
        if full:
            self._evict_ptr = (idx + 1) % self.max_cases
        else:
            self.count.add_(1)
    @torch.no_grad()
    def update_weight(self, idx: int, outcome: float) -> None:
        """Blend ``outcome`` into case ``idx``'s weight with an EMA (decay ``ema_decay``)."""
        self.weights[idx] = self.ema_decay * self.weights[idx] + (1.0 - self.ema_decay) * outcome
# ---------------------------------------------------------------------------
# Meta-guideline bank
# ---------------------------------------------------------------------------
class MetaGuidelineBank(nn.Module):
    """Stores meta-rules about when memory retrieval helps vs hurts.

    Guidelines are packed uint8 hypervectors (``bits // 8`` bytes) with an
    effectiveness score. Once ``max`` guidelines are stored, new ones
    overwrite the oldest (round-robin).
    """
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.max_guidelines = int(config.get("max", 256))
        bits = int(config.get("bits", 8192))
        self.register_buffer("guidelines",
                             torch.zeros(self.max_guidelines, bits // 8, dtype=torch.uint8))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.register_buffer("effectiveness", torch.zeros(self.max_guidelines))
        # Next slot to overwrite once the bank is full (FIFO eviction). It is
        # not part of the state_dict, so after a checkpoint load eviction
        # restarts at slot 0.
        self._evict_ptr = 0
    def _count_int(self) -> int:
        return int(self.count.item())
    @torch.no_grad()
    def add_guideline(self, vec: torch.Tensor, effectiveness: float = 0.0) -> None:
        """Store ``vec`` with its ``effectiveness`` score, evicting the oldest entry when full."""
        c = self._count_int()
        full = c >= self.max_guidelines
        # `count % max_guidelines` would stay 0 once count saturates, so
        # every new guideline would overwrite slot 0.
        idx = self._evict_ptr if full else c
        self.guidelines[idx] = vec.detach()
        self.effectiveness[idx] = effectiveness
        if full:
            self._evict_ptr = (idx + 1) % self.max_guidelines
        else:
            self.count.add_(1)
    def query(self, query_vec: torch.Tensor, top_k: int = 5):
        """Return ``(distances, indices, effectiveness)`` of the nearest guidelines.

        Returns a single ``None`` (not a tuple) when the bank is empty.
        """
        c = self._count_int()
        if c == 0:
            return None
        dists = SemanticMemory.hamming_distance(
            query_vec.unsqueeze(-2), self.guidelines[:c].unsqueeze(0))
        k = min(top_k, c)
        values, indices = dists.topk(k, dim=-1, largest=False)
        eff = self.effectiveness[indices]
        return values, indices, eff
# ---------------------------------------------------------------------------
# Self-feedback / refinement trigger
# ---------------------------------------------------------------------------
class SelfFeedback(nn.Module):
    """Request a refinement pass when the model's mean top-1 probability is low.

    At most ``max_refinement_rounds`` refinements are granted between calls
    to :meth:`reset`. Every confidence evaluation increments
    ``total_evaluations``.
    """
    def __init__(self, config: dict):
        super().__init__()
        get = config.get
        self.enabled = bool(get("enabled", True))
        self.confidence_threshold = float(get("confidence_threshold", 0.6))
        self.max_rounds = int(get("max_refinement_rounds", 1))
        self.refinement_count = 0
        self.total_evaluations = 0
    def compute_confidence(self, logits: torch.Tensor) -> float:
        """Return the mean over positions of the max softmax probability."""
        score = F.softmax(logits, dim=-1).amax(dim=-1).mean().item()
        self.total_evaluations += 1
        return score
    def should_refine(self, logits: torch.Tensor) -> bool:
        """True when enabled, under the round budget, and confidence is below threshold."""
        budget_left = self.enabled and self.refinement_count < self.max_rounds
        if not budget_left:
            return False
        refine = self.compute_confidence(logits) < self.confidence_threshold
        if refine:
            self.refinement_count += 1
        return refine
    def reset(self):
        """Restore the refinement budget (``total_evaluations`` is kept)."""
        self.refinement_count = 0
# ---------------------------------------------------------------------------
# Loop depth classifier
# ---------------------------------------------------------------------------
class LoopDepthClassifier(nn.Module):
    """Predict a Parcae loop depth in ``1..6`` from a hidden-state feature vector.

    When disabled, a constant depth of 2 is returned as a 0-dim long tensor.
    """
    def __init__(self, config: dict, in_features: int = 256):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        hidden_units = max(16, in_features // 4)
        # Layers are created in this order so parameter initialisation
        # consumes the RNG identically.
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(hidden_units, 6),
        ]
        self.net = nn.Sequential(*layers)
        # Small head weights keep early predictions near-uniform.
        nn.init.normal_(layers[-1].weight, std=0.01)
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            class_logits = self.net(features)
            return class_logits.argmax(dim=-1) + 1
        return torch.tensor(2, dtype=torch.long, device=features.device)
# ---------------------------------------------------------------------------
# Self-evolution engine — WIRED and FUNCTIONAL
# ---------------------------------------------------------------------------
class SelfEvolutionEngine(nn.Module):
    """Orchestrates all self-evolution components during forward pass.

    torch.compile strategy: the evolution forward() is called from
    model._run_layers() which runs inside torch.compile with fullgraph=False.
    Graph breaks happen at .item() calls (memory-count checks and scalar
    metrics). The semantic-memory read/write helpers run under
    torch.no_grad(), so they add nothing to the main autograd graph.
    """
    def __init__(self, config: dict, hidden_size: int):
        super().__init__()
        t1 = config.get("tier1", {})
        t2 = config.get("tier2", {})
        t3 = config.get("tier3", {})
        self.ttt = InPlaceTTT(t1.get("ttt", {}), hidden_size)
        self.semantic_memory = SemanticMemory(config.get("_semantic_memory_config", {}))
        self.episodic = EpisodicCaseMemory(t2.get("episodic_cases", {}))
        self.meta_guidelines = MetaGuidelineBank(t2.get("meta_guidelines", {}))
        self.self_feedback = SelfFeedback(t2.get("self_feedback", {}))
        self.loop_classifier = LoopDepthClassifier(t3.get("loop_depth_learning", {}), hidden_size)
        safety = config.get("safety", {})
        self.freeze_threshold = float(safety.get("freeze_threshold", 0.05))
        # Once set by check_safety(), forward() short-circuits to neutral outputs.
        self.frozen = False
        # Loss snapshots for the periodic evaluation in forward() (every 50 training steps).
        self.register_buffer("with_memory_loss", torch.zeros(1))
        self.register_buffer("without_memory_loss", torch.zeros(1))
        self.eval_steps = 0
        # Recent mean losses used by compute_surprise().
        self.surprise_window: list[float] = []
        self.max_window = 100
    def check_safety(self, cert_failure_rate: float) -> bool:
        """Freeze evolution once ``cert_failure_rate`` exceeds ``freeze_threshold``.

        Nothing in this class clears the flag again. Returns the frozen state.
        """
        if cert_failure_rate > self.freeze_threshold:
            self.frozen = True
        return self.frozen
    def compute_surprise(self, loss: torch.Tensor) -> float:
        """Return ``|loss - mean| / std`` over the recent loss window.

        The window holds up to ``max_window`` values, including the current
        one. Returns 0.0 until at least 10 values have been seen.
        """
        val = float(loss.detach().mean())
        self.surprise_window.append(val)
        if len(self.surprise_window) > self.max_window:
            self.surprise_window.pop(0)
        if len(self.surprise_window) < 10:
            return 0.0
        mean = sum(self.surprise_window) / len(self.surprise_window)
        std = math.sqrt(sum((x - mean) ** 2 for x in self.surprise_window) / len(self.surprise_window))
        return abs(val - mean) / (std + 1e-6)
    def forward(self, hidden_states: torch.Tensor, logits: Optional[torch.Tensor] = None,
                layer_idx: Optional[int] = None, loss: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        """Process evolution for current step.

        ``hidden_states`` is unpacked as ``[B, T, H]``. Returns a dict with
        ``modulation`` (like ``hidden_states``), ``ttt_delta`` (tensor or
        None), ``loop_depth`` (int), ``should_refine`` (bool),
        ``evolution_loss`` (scalar tensor) and ``metrics`` (dict).

        NOTE: This method uses .item() for memory count checks, which causes
        graph breaks under torch.compile. This is intentional: memory ops
        are side-effect-heavy (indexing into variable-length buffers) and
        cannot be symbolically traced.
        """
        if self.frozen:
            return {
                'modulation': torch.zeros_like(hidden_states),
                'ttt_delta': None,
                'loop_depth': 2,
                'should_refine': False,
                'evolution_loss': torch.tensor(0.0, device=hidden_states.device),
                'metrics': {'frozen': True}
            }
        result: Dict[str, Any] = {
            'modulation': torch.zeros_like(hidden_states),
            'ttt_delta': None,
            'loop_depth': 2,
            'should_refine': False,
            'evolution_loss': torch.tensor(0.0, device=hidden_states.device),
            'metrics': {}
        }
        B, T, H = hidden_states.shape
        # 1. Semantic memory read: modulate hidden states.
        # .item() graph break here is unavoidable (variable-length buffer).
        if self.semantic_memory.enabled and self.semantic_memory._count_int() > 0:
            modulation = self.semantic_memory.read_and_modulate(hidden_states)
            result['modulation'] = modulation * 0.1
        # 2. TTT: compute update for target layers from the last position.
        if self.ttt.enabled and layer_idx in self.ttt.target_layers and logits is not None:
            if loss is not None and hidden_states.requires_grad:
                grad = torch.autograd.grad(loss, hidden_states, retain_graph=True,
                                           create_graph=False)[0]
                z = -grad[:, -1:, :]
                x_raw = hidden_states[:, -1:, :]
                with torch.no_grad():
                    # compute_update only reads w_down when TTT is disabled,
                    # so an identity placeholder is sufficient here.
                    result['ttt_delta'] = self.ttt.compute_update(x_raw, z,
                                                                  torch.eye(H, device=hidden_states.device))
        # 3. Loop depth prediction (inference only)
        if not self.training and logits is not None:
            last_hidden = hidden_states[:, -1, :]
            depth_tensor = self.loop_classifier(last_hidden)
            # One depth is used for the whole batch. Take the deepest
            # prediction so no sample gets fewer loops than predicted.
            # int() on a multi-element tensor would raise for B > 1.
            result['loop_depth'] = int(depth_tensor.detach().max())
        # 4. Self-feedback confidence check
        if logits is not None:
            # NOTE(review): should_refine() already evaluates the confidence
            # internally, so the softmax (and total_evaluations) runs twice here.
            result['should_refine'] = self.self_feedback.should_refine(logits)
            result['metrics']['confidence'] = self.self_feedback.compute_confidence(logits)
        # 5. Contrastive memory evaluation
        if self.training and loss is not None:
            self.eval_steps += 1
            if self.eval_steps % 50 == 0:
                with_memory = float(loss.detach())
                self.with_memory_loss[0] = with_memory
                # NOTE(review): `without_memory_loss` holds the previous
                # evaluation's loss (it is overwritten below), so this compares
                # consecutive evaluations. The resulting tensor is built from
                # Python floats and carries no gradient.
                if self.without_memory_loss[0] > 0:
                    improvement = float(self.without_memory_loss[0]) - with_memory
                    result['evolution_loss'] = -torch.tensor(improvement * 0.01,
                                                             device=hidden_states.device)
                self.without_memory_loss[0] = with_memory
        # 6. Surprise-based memory write
        if loss is not None and self.semantic_memory.enabled:
            surprise = self.compute_surprise(loss)
            if surprise > self.semantic_memory.write_threshold:
                last_hv = self.semantic_memory.project_to_hypervector(hidden_states[:, -1:, :])
                stored = self.semantic_memory.store(last_hv.squeeze(0), surprise)
                result['metrics']['memory_stored'] = stored
        # 7. Episodic case retrieval
        if self.episodic.enabled and self.episodic._count_int() > 0:
            query = hidden_states[:, -1, :]
            cases, scores = self.episodic.retrieve(query, top_k=3)
            if cases is not None:
                result['metrics']['episodic_similarity'] = float(scores.detach().mean())
        return result
    @torch.no_grad()
    def store_episodic(self, hidden: torch.Tensor, outcome: float = 1.0):
        """Store the flattened ``hidden`` as an episodic case (if enabled)."""
        if self.episodic.enabled:
            self.episodic.store(hidden.reshape(-1), outcome)
    @torch.no_grad()
    def add_guideline(self, query_vec: torch.Tensor, effectiveness: float = 0.0):
        """Add a meta-guideline hypervector (if the bank is enabled)."""
        if self.meta_guidelines.enabled:
            self.meta_guidelines.add_guideline(query_vec, effectiveness)
    def reset_session(self):
        """Reset per-session state: TTT momentum, refinement budget, surprise window, query cache."""
        self.ttt.reset_momentum()
        self.self_feedback.reset()
        self.surprise_window.clear()
        self.semantic_memory._query_cache.clear()
# Public names exported by a star-import of this module.
__all__ = [
    "SemanticMemory", "InPlaceTTT", "EpisodicCaseMemory", "MetaGuidelineBank",
    "SelfFeedback", "LoopDepthClassifier", "SelfEvolutionEngine",
]