"""
CREST Block v2 — Cognitively Recurrent Estimation of Step Termination
=====================================================================

Authors: ENI & LO (0labs, Gujarat, India)
Date: April 2026 — v2 (fixed initialization stability)

Key fix from v1: step 1's FFN output is now IDENTICAL to the original MLP's
output — no extra RMSNorm or h0 residual on step 1. Steps 2+ use lightweight
residual mixing instead of a full norm, preventing gradient instability.
"""

import copy
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CRESTBlock(nn.Module):
    """
    CREST Block v2 — replaces a standard FFN/MLP sublayer.

    The block runs up to ``max_steps`` FFN steps in sequence. Before each step
    a halting probability is predicted from the current hidden state, and the
    block output is the halting-weighted mixture of the step outputs
    (ACT-style). Step 1 reuses ``original_mlp`` unchanged; steps 2+ are
    independent deep copies of it.

    v2 design:
      - Step 1's own output is IDENTICAL to the original MLP
        (no extra norm/residual).
      - Steps 2+ compute ``mlp_i(h) + sigmoid(gate_i) * h0`` (a learned
        residual gate) instead of using RMSNorm. The raw gate parameter starts
        at 0, so the effective mixing weight starts at sigmoid(0) = 0.5.
      - The halting bias is initialized high (6.0, sigmoid ≈ 0.9975), so at
        init ≈ 99.75% of the output mass comes from step 1.

    After every forward pass:
      - ``_ponder_cost`` is a differentiable scalar: the token-averaged
        expected number of FFN steps. Minimizing it pushes halting earlier.
      - ``_steps_used`` is a Python float: the token-averaged number of steps
        that received more than ``eps`` probability mass.

    Args:
        original_mlp: pretrained FFN module. Its output must have the same
            shape as its input, because step outputs are fed back as the next
            step's input and accumulated into a tensor shaped like the input.
        hidden_size: feature dimension of the hidden states (input size of
            the halting head).
        max_steps: maximum number of FFN steps; must be >= 1.
        eps: tolerance used by the halting/early-exit thresholds.

    Raises:
        ValueError: if ``max_steps < 1``.
    """

    def __init__(
        self,
        original_mlp: nn.Module,
        hidden_size: int,
        max_steps: int = 4,
        eps: float = 1e-5,
    ):
        super().__init__()
        if max_steps < 1:
            # With zero steps the forward pass would silently return its input.
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self.max_steps = max_steps
        self.hidden_size = hidden_size
        self.eps = eps

        # ── Independent FFN per step ────────────────────────────────
        self.steps = nn.ModuleList([original_mlp])  # Step 1 = original (untouched)
        for _ in range(max_steps - 1):
            self.steps.append(copy.deepcopy(original_mlp))

        # ── Halting mechanism ───────────────────────────────────────
        self.halt_linear = nn.Linear(hidden_size, 1, bias=True)
        nn.init.zeros_(self.halt_linear.weight)
        nn.init.constant_(self.halt_linear.bias, 6.0)  # sigmoid(6) ≈ 0.9975

        # ── Residual gates for steps 2+ ─────────────────────────────
        # Raw parameter starts at 0 → effective mix weight sigmoid(0) = 0.5,
        # i.e. steps 2+ begin as mlp_i(h) + 0.5 * h0 (NOT identical to step 1).
        self.residual_gates = nn.ParameterList(
            [nn.Parameter(torch.zeros(1)) for _ in range(max_steps - 1)]
        )

        # The new parameters above are created on CPU in the default dtype.
        # Place them next to the wrapped MLP so that retrofitting a model that
        # already lives on GPU / in bf16 does not fail with a device or dtype
        # mismatch in forward(). Meta tensors are left alone (no real storage).
        ref = next(original_mlp.parameters(), None)
        if ref is not None and ref.device.type != "meta":
            if ref.is_floating_point():
                self.halt_linear.to(device=ref.device, dtype=ref.dtype)
                self.residual_gates.to(device=ref.device, dtype=ref.dtype)
            else:
                # Module.to only accepts floating/complex dtypes; keep ours.
                self.halt_linear.to(device=ref.device)
                self.residual_gates.to(device=ref.device)

        # ── Runtime storage (overwritten by each forward) ───────────
        self._ponder_cost = 0.0
        self._steps_used = 0.0

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Run up to ``max_steps`` FFN steps and return their halting-weighted mix.

        Args:
            hidden_states: 3-D input of shape (B, S, D).
            **kwargs: accepted for call-site compatibility; ignored.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        h0 = hidden_states
        B, S, _ = h0.shape  # also enforces a 3-D (B, S, D) input
        total_prob = torch.zeros(B, S, 1, device=h0.device, dtype=h0.dtype)
        accumulated = torch.zeros_like(h0)
        # Per-token expected step count: sum_i (i + 1) * p_use_i (+ remainder).
        expected_steps = torch.zeros_like(total_prob)
        steps_taken = torch.zeros_like(total_prob)
        h = h0
        steps_run = 0

        for i in range(self.max_steps):
            # ── Halting probability, capped by the still-unassigned mass ──
            p_halt = torch.sigmoid(self.halt_linear(h))
            remaining = 1.0 - total_prob
            p_use = torch.min(p_halt, remaining)

            # ── Compute step i ──────────────────────────────────────
            if i == 0:
                # Step 1: EXACT same as original MLP (no modifications).
                h_new = self.steps[0](h)
            else:
                # Steps 2+: MLP output + learned residual gate from h0.
                gate = torch.sigmoid(self.residual_gates[i - 1])  # 0.5 at init
                h_new = self.steps[i](h) + gate * h0

            # ── Weighted accumulation ───────────────────────────────
            accumulated = accumulated + p_use * h_new
            total_prob = total_prob + p_use
            expected_steps = expected_steps + (i + 1) * p_use
            steps_taken = steps_taken + (p_use > self.eps).float()
            steps_run = i + 1

            # ── Early exit once every token has (numerically) halted ──
            if (total_prob >= (1.0 - self.eps)).all():
                break
            h = h_new

        # ── Leftover probability goes to the last computed step output ──
        remainder = 1.0 - total_prob
        if remainder.max() > self.eps:
            accumulated = accumulated + remainder * h
        # Charge leftover mass as having used every step that ran. This is
        # added unconditionally so the ponder gradient stays consistent even
        # when the remainder is below eps and omitted from the output.
        expected_steps = expected_steps + steps_run * remainder

        self._ponder_cost = expected_steps.mean()
        self._steps_used = steps_taken.mean().item()
        return accumulated


def retrofit_model_with_crest(
    model: nn.Module,
    max_steps: int = 4,
    target_layers: Optional[list] = None,
) -> Tuple[nn.Module, dict]:
    """
    Retrofit a pretrained transformer with CREST blocks, in place.

    Decoder layers are looked up at ``model.model.layers`` or
    ``model.transformer.h``. In each targeted layer, the first attribute found
    among ``mlp``, ``feed_forward`` and ``ffn`` is replaced by a CRESTBlock
    whose step 1 is that original module (weights preserved). Layers without
    such an attribute, or whose FFN is already a CRESTBlock, are skipped.

    Args:
        model: model to modify in place (also returned).
        max_steps: forwarded to every CRESTBlock.
        target_layers: indices into the decoder layer list; all layers if None.

    Returns:
        ``(model, stats)`` where ``stats`` has the keys ``total_layers``,
        ``converted``, ``skipped``, ``max_steps``, ``hidden_size``,
        ``params_before``, ``params_after``, ``params_added`` and
        ``param_overhead_pct``.

    Raises:
        ValueError: if no decoder layer list is found, or if a layer must be
            converted but the hidden size cannot be determined.
    """
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        layers = model.model.layers
    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        layers = model.transformer.h
    else:
        raise ValueError("Cannot find decoder layers")

    # getattr on a missing config (None) simply falls through to the default.
    config = getattr(model, 'config', None)
    hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None))
    if hidden_size is None and len(layers) > 0:
        # Heuristic fallback: last dim of the first parameter of the first layer.
        first_param = next(layers[0].parameters(), None)
        if first_param is not None:
            hidden_size = first_param.shape[-1]

    n_layers = len(layers)
    if target_layers is None:
        target_layers = list(range(n_layers))

    params_before = sum(p.numel() for p in model.parameters())
    converted = 0
    skipped = 0

    for idx in target_layers:
        layer = layers[idx]
        mlp = None
        mlp_attr = None
        for attr_name in ('mlp', 'feed_forward', 'ffn'):
            if hasattr(layer, attr_name):
                mlp = getattr(layer, attr_name)
                mlp_attr = attr_name
                break

        # Re-running the retrofit must not nest a CRESTBlock in a CRESTBlock.
        if mlp is None or isinstance(mlp, CRESTBlock):
            skipped += 1
            continue
        if hidden_size is None:
            raise ValueError(
                "Cannot determine hidden_size from model.config or layer parameters"
            )

        crest_block = CRESTBlock(
            original_mlp=mlp,
            hidden_size=hidden_size,
            max_steps=max_steps,
        )
        setattr(layer, mlp_attr, crest_block)
        converted += 1

    params_after = sum(p.numel() for p in model.parameters())
    params_added = params_after - params_before
    stats = {
        'total_layers': n_layers,
        'converted': converted,
        'skipped': skipped,
        'max_steps': max_steps,
        'hidden_size': hidden_size,
        'params_before': params_before,
        'params_after': params_after,
        'params_added': params_added,
        # Guard against a parameter-free model (division by zero).
        'param_overhead_pct': (
            round(100 * params_added / params_before, 1) if params_before else 0.0
        ),
    }
    return model, stats


def collect_ponder_costs(model: nn.Module) -> torch.Tensor:
    """
    Average the most recent ponder cost over all CREST blocks in ``model``.

    Tensor costs keep their autograd graph, so the result can be added to a
    training loss. Blocks that have not run yet contribute their initial 0.0.
    Returns a zero tensor when the model contains no CREST blocks.
    """
    total = torch.tensor(0.0, device='cpu')
    count = 0
    for module in model.modules():
        if not isinstance(module, CRESTBlock):
            continue
        cost = module._ponder_cost
        if isinstance(cost, torch.Tensor):
            total = total.to(cost.device)
        total = total + cost
        count += 1
    return total / max(count, 1)


def get_crest_stats(model: nn.Module) -> dict:
    """
    Get monitoring stats from all CREST blocks in ``model``.

    Returns a dict with ``avg_steps``, ``avg_ponder``, ``n_blocks`` and
    ``per_block`` (one entry per block with ``name``, ``avg_steps`` and
    ``ponder_cost`` as Python numbers). ``per_block`` is an empty list when
    the model has no CREST blocks.
    """
    per_block = []
    for name, module in model.named_modules():
        if isinstance(module, CRESTBlock):
            cost = module._ponder_cost
            per_block.append({
                'name': name,
                'avg_steps': module._steps_used,
                'ponder_cost': cost.item() if isinstance(cost, torch.Tensor) else cost,
            })

    if not per_block:
        return {'avg_steps': 0, 'avg_ponder': 0, 'n_blocks': 0, 'per_block': []}

    n_blocks = len(per_block)
    return {
        'avg_steps': sum(s['avg_steps'] for s in per_block) / n_blocks,
        'avg_ponder': sum(s['ponder_cost'] for s in per_block) / n_blocks,
        'n_blocks': n_blocks,
        'per_block': per_block,
    }