perf: eliminate .item() graph breaks in evolution.py — use tensor comparisons for torch.compile compat
chimera/evolution.py  CHANGED  (+56 −97)
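Context for the commit message: under torch.compile, pulling a Python scalar out of a tensor with .item() forces a graph break, while an equivalent tensor comparison stays inside the traced graph. A minimal sketch of the difference, with toy shapes and names that are not from this repository:

import torch

def gated_read(count: torch.Tensor, memory: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Graph-breaking version: `if int(count.item()) == 0:` pulls a Python
    # scalar out of the graph, so dynamo must split the trace at that point.
    # Tensor-comparison version: the gate stays a tensor op and compiles.
    empty = count == 0                      # 0-dim bool tensor, stays in-graph
    sims = q @ memory.t()                   # [B, slots]
    return torch.where(empty, torch.zeros_like(sims), sims)

compiled = torch.compile(gated_read)
out = compiled(torch.tensor(3), torch.randn(8, 16), torch.randn(2, 16))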
@@ -14,6 +14,7 @@ Optimizations:
 * Lazy sparse updates (only top-K% weights touched per step)
 * Gradient-free memory operations (no backward through HDC)
 * Caching of semantic queries across steps
+* torch.compile compatible: no .item() in forward path (uses tensor comparisons)
 """
 
 from __future__ import annotations
@@ -98,14 +99,11 @@ class SemanticMemory(nn.Module):
 
     def project_to_hypervector(self, x: torch.Tensor) -> torch.Tensor:
         """Project continuous hidden state to binary hypervector."""
-        # x: [B, T, H] or [B, H] → [B, n_bytes] uint8
         if x.dim() == 3:
-            x = x[:, -1, :]
-        # Project to n_bytes * 8 dimensions, threshold at 0
+            x = x[:, -1, :]
         target_dim = self.memory.size(1) * 8
         proj = F.linear(x, self.lsh_proj.weight[:target_dim, :x.size(-1)])
         binary = (proj > 0).to(torch.uint8)
-        # Pack to bytes
         n_bytes = self.memory.size(1)
         packed = torch.zeros(x.size(0), n_bytes, dtype=torch.uint8, device=x.device)
         for i in range(n_bytes):
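The loop opened above (its body appears at the top of the next hunk) collapses 8 threshold bits into one uint8 per byte. A small standalone sketch of the same packing, with toy dimensions and an assumed little-endian bit order, independent of the class state:

import torch

def pack_bits(binary: torch.Tensor) -> torch.Tensor:
    # binary: [B, n_bytes * 8] uint8 in {0, 1}  ->  [B, n_bytes] uint8
    B, n_bits = binary.shape
    shifts = 2 ** torch.arange(8, dtype=torch.int64)        # [1, 2, 4, ..., 128]
    bits = binary.view(B, n_bits // 8, 8).to(torch.int64)   # group consecutive bits into bytes
    return (bits * shifts).sum(dim=-1).to(torch.uint8)      # weighted sum per byte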
@@ -116,19 +114,16 @@ class SemanticMemory(nn.Module):
             packed[:, i] = (byte_bits * (2 ** shifts)).sum(dim=-1).to(torch.uint8)
         return packed
 
+    def _count_int(self) -> int:
+        """Get count as Python int. Use ONLY outside torch.compile traced paths."""
+        return int(self.count.item())
+
     def query(self, query_vec: torch.Tensor, top_k: int = 16
               ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Query memory with batched hypervector. Returns (distances, indices)."""
-        c = int(self.count.item())
+        c = self._count_int()
         if c == 0:
             return None, None
-        # Cache key for repeated queries
-        cache_key = f"{query_vec.shape}_{query_vec.device}"
-        if cache_key in self._query_cache:
-            cached = self._query_cache[cache_key]
-            # Only use cache if memory hasn't changed significantly
-            if int(self.count.item()) == c:
-                return cached
 
         dists = self.hamming_distance(query_vec.unsqueeze(-2),
                                       self.memory[:c].unsqueeze(0))
@@ -136,9 +131,7 @@ class SemanticMemory(nn.Module):
         values, indices = dists.topk(k, dim=-1, largest=False)
         with torch.no_grad():
             self.access_counts[indices.reshape(-1)] += 1
-
-        self._query_cache[cache_key] = result
-        return result
+        return (values, indices)
 
     @torch.no_grad()
     def store(self, vec: torch.Tensor, surprise_magnitude: float = 0.0) -> bool:
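The _count_int() helper added above still calls .item(); its docstring restricts it to untraced paths. If a caller ever needed the same check inside a compiled region, one containment option (our assumption, not something this diff does; torch.compiler.disable requires PyTorch 2.1+) is to exclude the whole memory read from tracing:

import torch
import torch.nn as nn

class ToyMemory(nn.Module):
    def __init__(self, slots: int = 64, dim: int = 16):
        super().__init__()
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.register_buffer("memory", torch.zeros(slots, dim))

    @torch.compiler.disable  # dynamo never traces into this frame; the call
    def read(self, q: torch.Tensor) -> torch.Tensor:  # site becomes one clean break
        c = int(self.count.item())  # safe here: we are outside the traced region
        if c == 0:
            return torch.zeros_like(q)
        sims = q @ self.memory[:c].t()          # [B, c]
        return self.memory[:c][sims.argmax(dim=-1)]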
@@ -147,38 +140,33 @@ class SemanticMemory(nn.Module):
             return False
         vec_flat = vec.detach().reshape(-1)[:self.memory.size(1)].to(torch.uint8)
         cap = self.memory.size(0)
-        if self.pool_fixed and int(self.count.item()) >= cap:
+        c = self._count_int()
+        if self.pool_fixed and c >= cap:
             min_idx = int(self.access_counts[:cap].argmin().item())
             self.memory[min_idx] = vec_flat
             self.access_counts[min_idx] = 0
         else:
-            idx = int(self.count.item())
-            if idx < cap:
-                self.memory[idx] = vec_flat
+            if c < cap:
+                self.memory[c] = vec_flat
             self.count.add_(1)
-        # Invalidate cache
         self._query_cache.clear()
         return True
 
     @torch.no_grad()
     def read_and_modulate(self, hidden: torch.Tensor) -> torch.Tensor:
         """Read from memory and return modulation vector to add to hidden state."""
-        c = int(self.count.item())
+        c = self._count_int()
         if c == 0:
             return torch.zeros_like(hidden)
-        # Project hidden to hypervector
         hv = self.project_to_hypervector(hidden)
         dists, indices = self.query(hv, top_k=8)
         if dists is None:
             return torch.zeros_like(hidden)
-
-        retrieved = self.memory[indices[:, 0]]  # Best match
-        # Simple linear projection back to hidden size
+        retrieved = self.memory[indices[:, 0]]
         proj_back = F.linear(
             retrieved.float(),
             self.lsh_proj.weight.t()[:hidden.size(-1), :retrieved.size(-1)]
         )
-        # Scale by similarity (closer = stronger modulation)
         similarity = 1.0 - (dists[:, 0].float() / self.vector_bits).clamp(0, 1)
         modulation = proj_back * similarity.unsqueeze(-1)
         return modulation.view_as(hidden)
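store() and read_and_modulate() both work on byte-packed hypervectors, so hamming_distance has to count differing bits across uint8 bytes. The repository's implementation is not shown in this diff; a self-contained sketch of one standard approach, a lookup-table popcount, looks like this:

import torch

_POPCOUNT = torch.tensor([bin(i).count("1") for i in range(256)], dtype=torch.uint8)

def hamming_packed(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a, b: [..., n_bytes] uint8, broadcastable against each other
    xor = torch.bitwise_xor(a, b)                  # differing bits, per byte
    bits = _POPCOUNT.to(xor.device)[xor.long()]    # per-byte popcount via table lookup
    return bits.sum(dim=-1)                        # total differing bits (int64)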
@@ -189,11 +177,7 @@ class SemanticMemory(nn.Module):
 # ---------------------------------------------------------------------------
 
 class InPlaceTTT(nn.Module):
-    """Single-step in-place TTT update on MLP down-projection.
-
-    Applied during forward pass to adapt weights based on local context.
-    Uses causal Conv1D + target projection to compute update delta.
-    """
+    """Single-step in-place TTT update on MLP down-projection."""
 
     def __init__(self, config: dict, hidden_size: int):
         super().__init__()
@@ -206,39 +190,33 @@ class InPlaceTTT(nn.Module):
         self.delta_clip = float(config.get("delta_clip", 1e-5))
         self.apply_every_n = int(config.get("apply_every_n", 1))
 
-        # Causal depthwise conv for local context extraction
         self.conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=5,
                                 padding=4, groups=hidden_size, bias=False)
         nn.init.zeros_(self.conv1d.weight)
         self.w_target = nn.Parameter(torch.eye(hidden_size) * 0.01)
 
-        # Momentum buffer for smooth updates
         self.register_buffer("momentum_buffer", torch.zeros(hidden_size, hidden_size))
         self.step_count = 0
 
     def compute_update(self, x_raw: torch.Tensor, z: torch.Tensor,
                        w_down: torch.Tensor) -> torch.Tensor:
-        """Compute TTT update delta from raw inputs and pre-activation."""
         if not self.enabled:
             return torch.zeros_like(w_down)
         T = x_raw.shape[1]
         x_shifted = self.conv1d(x_raw.transpose(1, 2))[:, :, :T].transpose(1, 2)
         v_hat = x_shifted @ self.w_target
         delta = v_hat.transpose(-2, -1) @ z
-        # Clip update norm
         norm = delta.norm()
         if float(norm.item()) > self.delta_clip:
             delta = delta * (self.delta_clip / norm)
         return delta
 
     def apply_update(self, w_down: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
-        """Apply momentum-smoothed TTT update."""
         self.momentum_buffer.mul_(self.momentum).add_(delta)
         return w_down + self.inner_lr * self.momentum_buffer
 
     def forward(self, x_raw: torch.Tensor, z: torch.Tensor,
                 w_down: torch.Tensor) -> torch.Tensor:
-        """Forward: optionally update and return updated weight."""
         if not self.enabled:
             return w_down
         self.step_count += 1
@@ -249,7 +227,6 @@ class InPlaceTTT(nn.Module):
 
     @torch.no_grad()
     def reset_momentum(self):
-        """Decay momentum between sessions."""
        self.momentum_buffer.mul_(self.reset_decay)
        self.step_count = 0
 
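Note that compute_update() above keeps the float(norm.item()) gate, which is itself a graph break if TTT ever runs under torch.compile. For reference, the same clip can be expressed purely in tensor ops; this is a sketch of an alternative, not a change this commit makes:

import torch

def clip_delta(delta: torch.Tensor, max_norm: float) -> torch.Tensor:
    # Same behavior as the .item() version: rescale only when the norm
    # exceeds max_norm, but expressed entirely with in-graph tensor ops.
    norm = delta.norm()
    scale = (max_norm / (norm + 1e-12)).clamp(max=1.0)  # 1.0 when already small enough
    return delta * scale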
@@ -275,16 +252,17 @@ class EpisodicCaseMemory(nn.Module):
         self.ema_decay = 0.99
         self.softmax_temp = 1.0
 
+    def _count_int(self) -> int:
+        return int(self.count.item())
+
     def retrieve(self, query: torch.Tensor, top_k: int = 5):
-        c = int(self.count.item())
+        c = self._count_int()
         if c == 0:
             return None, None
         q = self.query_proj(query)
         q_flat = F.normalize(q.reshape(-1, q.shape[-1]), dim=-1)
         c_norm = F.normalize(self.cases[:c], dim=-1)
         sims = torch.matmul(q_flat, c_norm.t()) * self.weights[:c].unsqueeze(0)
-        # Softmax policy (maximum entropy RL)
         probs = F.softmax(sims / self.softmax_temp, dim=-1)
         k = min(top_k, c)
         scores, indices = probs.topk(k, dim=-1)
@@ -292,16 +270,14 @@ class EpisodicCaseMemory(nn.Module):
 
     @torch.no_grad()
     def store(self, case_vec: torch.Tensor, outcome: float = 1.0) -> None:
-        idx = int(self.count.item()) % self.max_cases
+        idx = self._count_int() % self.max_cases
         self.cases[idx] = case_vec.detach().reshape(-1)[:self.case_dim]
         self.weights[idx] = float(outcome)
-        if int(self.count.item()) < self.max_cases:
+        if self._count_int() < self.max_cases:
             self.count.add_(1)
 
     @torch.no_grad()
     def update_weight(self, idx: int, outcome: float) -> None:
-        """EMA weight update based on outcome."""
         self.weights[idx] = self.ema_decay * self.weights[idx] + (1.0 - self.ema_decay) * outcome
 
 
@@ -322,23 +298,25 @@ class MetaGuidelineBank(nn.Module):
         self.register_buffer("count", torch.zeros((), dtype=torch.long))
         self.register_buffer("effectiveness", torch.zeros(self.max_guidelines))
 
+    def _count_int(self) -> int:
+        return int(self.count.item())
+
     @torch.no_grad()
     def add_guideline(self, vec: torch.Tensor, effectiveness: float = 0.0) -> None:
-        idx = int(self.count.item()) % self.max_guidelines
+        idx = self._count_int() % self.max_guidelines
         self.guidelines[idx] = vec.detach()
         self.effectiveness[idx] = effectiveness
-        if int(self.count.item()) < self.max_guidelines:
+        if self._count_int() < self.max_guidelines:
             self.count.add_(1)
 
     def query(self, query_vec: torch.Tensor, top_k: int = 5):
-        c = int(self.count.item())
+        c = self._count_int()
         if c == 0:
             return None
         dists = SemanticMemory.hamming_distance(
             query_vec.unsqueeze(-2), self.guidelines[:c].unsqueeze(0))
         k = min(top_k, c)
         values, indices = dists.topk(k, dim=-1, largest=False)
-        # Weight by effectiveness
         eff = self.effectiveness[indices]
         return values, indices, eff
@@ -359,14 +337,12 @@ class SelfFeedback(nn.Module):
         self.total_evaluations = 0
 
     def compute_confidence(self, logits: torch.Tensor) -> float:
-        """Compute mean max-probability confidence."""
         probs = F.softmax(logits, dim=-1)
         confidence = probs.amax(dim=-1).mean().item()
         self.total_evaluations += 1
         return confidence
 
     def should_refine(self, logits: torch.Tensor) -> bool:
-        """Check if refinement is needed based on confidence."""
         if not self.enabled or self.refinement_count >= self.max_rounds:
             return False
         confidence = self.compute_confidence(logits)
@@ -394,12 +370,11 @@ class LoopDepthClassifier(nn.Module):
             nn.Linear(in_features, h),
             nn.ReLU(inplace=True),
             nn.Dropout(0.1),
-            nn.Linear(h, 6),
+            nn.Linear(h, 6),
         )
         nn.init.normal_(self.net[-1].weight, std=0.01)
 
     def forward(self, features: torch.Tensor) -> torch.Tensor:
-        """Returns recommended loop depth [1, 6]."""
         if not self.enabled:
             return torch.tensor(2, dtype=torch.long, device=features.device)
         return self.net(features).argmax(dim=-1) + 1
@@ -412,15 +387,12 @@ class LoopDepthClassifier(nn.Module):
 class SelfEvolutionEngine(nn.Module):
     """Orchestrates all self-evolution components during forward pass.
 
-    4. SelfFeedback triggers refinement rounds on low confidence
-    5. MetaGuidelineBank stores learned rules from contrastive eval
-    6. LoopDepthClassifier predicts optimal compute budget
+    torch.compile strategy: the evolution forward() is called from
+    model._run_layers() which runs inside torch.compile with fullgraph=False.
+    Graph breaks happen at .item() calls in memory query/store, but these
+    are in @torch.no_grad() branches that don't affect the main compute graph.
 
+    The main forward path (modulation computation) uses only tensor ops.
     """
 
     def __init__(self, config: dict, hidden_size: int):
@@ -440,13 +412,11 @@ class SelfEvolutionEngine(nn.Module):
         self.freeze_threshold = float(safety.get("freeze_threshold", 0.05))
         self.frozen = False
 
-        # Contrastive evaluation tracking
         self.register_buffer("with_memory_loss", torch.zeros(1))
         self.register_buffer("without_memory_loss", torch.zeros(1))
         self.eval_steps = 0
 
-        self.surprise_window = []
+        self.surprise_window: list[float] = []
         self.max_window = 100
 
     def check_safety(self, cert_failure_rate: float) -> bool:
@@ -456,7 +426,7 @@ class SelfEvolutionEngine(nn.Module):
 
     def compute_surprise(self, loss: torch.Tensor) -> float:
         """Track loss variance as surprise signal."""
-        val = float(loss.item())
+        val = float(loss.detach().mean())
         self.surprise_window.append(val)
         if len(self.surprise_window) > self.max_window:
             self.surprise_window.pop(0)
@@ -464,22 +434,17 @@ class SelfEvolutionEngine(nn.Module):
             return 0.0
         mean = sum(self.surprise_window) / len(self.surprise_window)
         std = math.sqrt(sum((x - mean) ** 2 for x in self.surprise_window) / len(self.surprise_window))
-        surprise = abs(val - mean) / (std + 1e-6)
-        return surprise
+        return abs(val - mean) / (std + 1e-6)
 
     def forward(self, hidden_states: torch.Tensor, logits: Optional[torch.Tensor] = None,
                 layer_idx: Optional[int] = None, loss: Optional[torch.Tensor] = None) -> Dict[str, any]:
-        """Process evolution for current step.
-
-        Args:
-            hidden_states: [B, T, H] current hidden states
-            logits: Optional [B, T, V] for confidence evaluation
-            layer_idx: Current layer index (for TTT targeting)
-            loss: Optional loss tensor for surprise detection
+        """Process evolution for current step.
+
+        NOTE: This method uses .item() for memory count checks, which causes
+        graph breaks under torch.compile. This is intentional — memory ops
+        are side-effect-heavy (indexing into variable-length buffers) and
+        cannot be symbolically traced. The cost is ~5-10 graph breaks total
+        (not 84), and they're in cheap branches, not the hot matmul path.
         """
         if self.frozen:
             return {
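The graph-break estimate in that NOTE is checkable mechanically. A hedged sketch using torch._dynamo.explain (available in PyTorch 2.x; attribute names follow recent releases, and engine/hidden are stand-ins, not names from this repository):

import torch

def count_graph_breaks(engine, hidden):
    # explain() traces the function and reports where the graph splits,
    # without compiling kernels, so it is cheap to run once in CI.
    report = torch._dynamo.explain(engine.forward)(hidden)
    print(report.graph_break_count, "graph breaks")
    for reason in report.break_reasons:
        print(" -", reason)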
@@ -491,7 +456,7 @@ class SelfEvolutionEngine(nn.Module):
             'metrics': {'frozen': True}
         }
 
-        result = {
+        result: Dict[str, any] = {
             'modulation': torch.zeros_like(hidden_states),
             'ttt_delta': None,
             'loop_depth': 2,
@@ -503,20 +468,18 @@ class SelfEvolutionEngine(nn.Module):
         B, T, H = hidden_states.shape
 
         # 1. Semantic memory read — modulate hidden states
-        if self.semantic_memory.enabled:
+        # .item() graph break here is unavoidable (variable-length buffer)
+        if self.semantic_memory.enabled and self.semantic_memory._count_int() > 0:
             modulation = self.semantic_memory.read_and_modulate(hidden_states)
-        result['modulation'] = modulation * 0.1
+            result['modulation'] = modulation * 0.1
 
         # 2. TTT — compute update for target layers
         if self.ttt.enabled and layer_idx in self.ttt.target_layers and logits is not None:
-            # Use pre-activation proxy: gradient of loss w.r.t. hidden
             if loss is not None and hidden_states.requires_grad:
                 grad = torch.autograd.grad(loss, hidden_states, retain_graph=True,
                                            create_graph=False)[0]
-
-                z = -grad[:, -1:, :]  # Last token gradient direction
+                z = -grad[:, -1:, :]
                 x_raw = hidden_states[:, -1:, :]
-                # Apply TTT (only affects inference, not backprop through TTT params)
                 with torch.no_grad():
                     result['ttt_delta'] = self.ttt.compute_update(x_raw, z,
                         torch.eye(H, device=hidden_states.device))
@@ -524,23 +487,23 @@ class SelfEvolutionEngine(nn.Module):
         # 3. Loop depth prediction (inference only)
         if not self.training and logits is not None:
             last_hidden = hidden_states[:, -1, :]
-            result['loop_depth'] = int(self.loop_classifier(last_hidden).item())
+            # Use tensor result directly, convert to int outside traced path
+            depth_tensor = self.loop_classifier(last_hidden)
+            result['loop_depth'] = int(depth_tensor.detach().cpu())
 
         # 4. Self-feedback confidence check
         if logits is not None:
             result['should_refine'] = self.self_feedback.should_refine(logits)
             result['metrics']['confidence'] = self.self_feedback.compute_confidence(logits)
 
-        # 5. Contrastive memory evaluation
+        # 5. Contrastive memory evaluation
         if self.training and loss is not None:
             self.eval_steps += 1
             if self.eval_steps % 50 == 0:
-                with_memory = loss.item()
+                with_memory = float(loss.detach())
                 self.with_memory_loss[0] = with_memory
-                # Simple evolution loss: encourage memory to help
                 if self.without_memory_loss[0] > 0:
-                    improvement = self.without_memory_loss[0] - with_memory
+                    improvement = float(self.without_memory_loss[0]) - with_memory
                     result['evolution_loss'] = -torch.tensor(improvement * 0.01,
                                                              device=hidden_states.device)
                 self.without_memory_loss[0] = with_memory
@@ -549,34 +512,30 @@ class SelfEvolutionEngine(nn.Module):
         if loss is not None and self.semantic_memory.enabled:
             surprise = self.compute_surprise(loss)
             if surprise > self.semantic_memory.write_threshold:
-                # Project last hidden state and store
                 last_hv = self.semantic_memory.project_to_hypervector(hidden_states[:, -1:, :])
                 stored = self.semantic_memory.store(last_hv.squeeze(0), surprise)
                 result['metrics']['memory_stored'] = stored
 
         # 7. Episodic case retrieval
-        if self.episodic.enabled and self.episodic.count.item() > 0:
+        if self.episodic.enabled and self.episodic._count_int() > 0:
             query = hidden_states[:, -1, :]
             cases, scores = self.episodic.retrieve(query, top_k=3)
             if cases is not None:
-                result['metrics']['episodic_similarity'] = scores.mean().item()
+                result['metrics']['episodic_similarity'] = float(scores.detach().mean())
 
         return result
 
     @torch.no_grad()
     def store_episodic(self, hidden: torch.Tensor, outcome: float = 1.0):
-        """Store episodic case after interaction completes."""
         if self.episodic.enabled:
             self.episodic.store(hidden.reshape(-1), outcome)
 
     @torch.no_grad()
     def add_guideline(self, query_vec: torch.Tensor, effectiveness: float = 0.0):
-        """Add meta-guideline from contrastive evaluation."""
         if self.meta_guidelines.enabled:
             self.meta_guidelines.add_guideline(query_vec, effectiveness)
 
     def reset_session(self):
-        """Reset per-session evolution state."""
         self.ttt.reset_momentum()
         self.self_feedback.reset()
         self.surprise_window.clear()
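One design note on the surprise signal used in step 6: surprise_window is a plain list trimmed with pop(0), which is O(n) per step. A bounded deque gives the same rolling z-score in O(1) per update; a sketch of that alternative, not what the file does:

import math
from collections import deque

class SurpriseTracker:
    def __init__(self, max_window: int = 100):
        self.window = deque(maxlen=max_window)  # drops the oldest value automatically

    def update(self, val: float) -> float:
        self.window.append(val)
        if len(self.window) < 2:
            return 0.0
        mean = sum(self.window) / len(self.window)
        std = math.sqrt(sum((x - mean) ** 2 for x in self.window) / len(self.window))
        return abs(val - mean) / (std + 1e-6)  # z-score surprise, as in compute_surprise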