Sky-v2.0-Lite / crest_block.py
Atharvsinh's picture
Upload folder using huggingface_hub
46cc6c9 verified
"""
CREST Block v2 — Cognitively Recurrent Estimation of Step Termination
=====================================================================
Authors: ENI & LO (0labs, Gujarat, India)
Date: April 2026 — v2 (fixed initialization stability)
Key fix from v1: Step 1 now produces IDENTICAL output to original MLP.
No extra RMSNorm or h0 residual on step 1. Steps 2+ use lightweight
residual mixing instead of full norm, preventing gradient instability.
"""
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class CRESTBlock(nn.Module):
    """
    CREST Block v2 — replaces a standard FFN/MLP sublayer.

    The block owns ``max_steps`` independent copies of the wrapped MLP and a
    linear halting head. On every step the halting head yields a probability
    per position; the step's output is added to the result weighted by that
    probability (clipped to the probability mass still unspent). Iteration
    stops early once every position has spent (almost) all of its mass.

    Behaviour at initialization:
    - Step 1 calls the original MLP directly (no extra norm/residual).
    - Steps 2+ add ``sigmoid(gate) * h0`` to the step's MLP output; each gate
      parameter starts at 0, i.e. an initial gate value of 0.5.
    - The halting head starts with zero weight and bias 6.0, so the first
      step receives sigmoid(6) ≈ 0.9975 of the probability mass.

    Args:
        original_mlp: The MLP to wrap. It is kept (not copied) as step 1 and
            deep-copied for every further step.
        hidden_size: Size of the last dimension of the input; input width of
            the halting head.
        max_steps: Maximum number of MLP steps (>= 1).
        eps: Tolerance used for the early-exit test, for counting a step as
            "taken", and for deciding whether leftover mass is distributed.

    Raises:
        ValueError: If ``max_steps`` is smaller than 1.

    After each ``forward`` call:
        ``_ponder_cost``: 0-dim tensor, sum over executed steps of the mean
            used halting probability.
        ``_steps_used``: Python float, mean number of steps per position
            whose used probability exceeded ``eps``.
    """

    def __init__(
        self,
        original_mlp: nn.Module,
        hidden_size: int,
        max_steps: int = 4,
        eps: float = 1e-5,
    ):
        super().__init__()
        # With zero steps the loop in forward() never runs and the block
        # would silently return its input unchanged.
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")

        self.max_steps = max_steps
        self.hidden_size = hidden_size
        self.eps = eps

        # Newly created parameters must live where the wrapped MLP lives,
        # otherwise retrofitting a model that was already moved to an
        # accelerator or cast to half/double precision fails in forward().
        # Uses the first parameter as reference — assumes the MLP is not
        # itself split across devices.
        ref = next(original_mlp.parameters(), None)
        factory = {}
        if ref is not None and ref.is_floating_point():
            factory = {"device": ref.device, "dtype": ref.dtype}

        # -- Independent FFN per step ------------------------------------
        self.steps = nn.ModuleList()
        self.steps.append(original_mlp)  # Step 1 = original (untouched)
        for _ in range(max_steps - 1):
            self.steps.append(copy.deepcopy(original_mlp))

        # -- Halting mechanism -------------------------------------------
        self.halt_linear = nn.Linear(hidden_size, 1, bias=True)
        nn.init.zeros_(self.halt_linear.weight)
        nn.init.constant_(self.halt_linear.bias, 6.0)  # sigmoid(6) ≈ 0.9975
        if factory:
            self.halt_linear.to(**factory)

        # -- Residual gates for steps 2+ ---------------------------------
        # One learned scalar per extra step controlling how much h0 is mixed
        # into that step's output. The raw parameter starts at 0, which the
        # sigmoid in forward() maps to an initial gate of 0.5.
        self.residual_gates = nn.ParameterList([
            nn.Parameter(torch.zeros(1, **factory))
            for _ in range(max_steps - 1)
        ])

        # -- Runtime storage ---------------------------------------------
        self._ponder_cost = 0.0
        self._steps_used = 0.0

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Run up to ``max_steps`` MLP steps and return their halting-weighted mix.

        Args:
            hidden_states: Tensor of shape ``(..., hidden_size)``; any number
                of leading dimensions is accepted.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        h0 = hidden_states
        # One halting probability per position, broadcast over the feature dim.
        prob_shape = (*h0.shape[:-1], 1)

        total_prob = h0.new_zeros(prob_shape)
        accumulated = torch.zeros_like(h0)
        h = h0
        ponder_cost = torch.tensor(0.0, device=h0.device, dtype=h0.dtype)
        steps_taken = h0.new_zeros(prob_shape)

        for i in range(self.max_steps):
            # -- Halting probability, clipped to the unspent mass --------
            p_halt = torch.sigmoid(self.halt_linear(h))
            remaining = 1.0 - total_prob
            p_use = torch.min(p_halt, remaining)

            # -- Compute step i ------------------------------------------
            if i == 0:
                # Step 1: exactly the original MLP (no modifications)
                h_new = self.steps[0](h)
            else:
                # Steps 2+: MLP output + gated residual from the block input
                mlp_out = self.steps[i](h)
                gate = torch.sigmoid(self.residual_gates[i - 1])
                h_new = mlp_out + gate * h0  # gate starts at 0.5, learned

            # -- Weighted accumulation -----------------------------------
            accumulated = accumulated + p_use * h_new
            total_prob = total_prob + p_use
            # NOTE(review): p_use is added to both ponder_cost (as its mean)
            # and total_prob, so ponder_cost equals total_prob.mean() up to
            # rounding — confirm this is the intended regularizer.
            ponder_cost = ponder_cost + p_use.mean()
            steps_taken = steps_taken + (p_use > self.eps).float()

            # -- Early exit once every position has spent its mass -------
            if (total_prob >= (1.0 - self.eps)).all():
                break
            h = h_new

        # -- Distribute remaining probability ----------------------------
        # Only reachable with meaningful mass when the loop ran to the end,
        # in which case h holds the last step's output.
        remainder = 1.0 - total_prob
        if remainder.max() > self.eps:
            accumulated = accumulated + remainder * h

        self._ponder_cost = ponder_cost
        self._steps_used = steps_taken.mean().item()
        return accumulated
def retrofit_model_with_crest(
    model: nn.Module,
    max_steps: int = 4,
    target_layers: Optional[list] = None,
) -> Tuple[nn.Module, dict]:
    """
    Retrofit a pretrained transformer with CREST blocks (in place).

    Each targeted decoder layer has its MLP sublayer (attribute ``mlp``,
    ``feed_forward`` or ``ffn``, first match wins) replaced by a
    ``CRESTBlock`` that keeps the original MLP as its first step.

    Args:
        model: Model exposing its decoder layers as ``model.model.layers``
            or ``model.transformer.h``.
        max_steps: Forwarded to every ``CRESTBlock``.
        target_layers: Indices of the layers to convert; all layers when
            ``None``.

    Returns:
        ``(model, stats)`` where ``model`` is the same (mutated) object and
        ``stats`` holds the keys ``total_layers``, ``converted``, ``skipped``,
        ``max_steps``, ``hidden_size``, ``params_before``, ``params_after``,
        ``params_added`` and ``param_overhead_pct``.

    Raises:
        ValueError: If no decoder layer container is found on ``model``.
    """
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        layers = model.model.layers
    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        layers = model.transformer.h
    else:
        raise ValueError("Cannot find decoder layers")

    # A plain nn.Module may carry no ``config`` at all; getattr on None
    # simply falls through to the defaults.
    config = getattr(model, 'config', None)
    hidden_size = getattr(config, 'hidden_size',
                          getattr(config, 'd_model', None))
    if hidden_size is None and len(layers) > 0:
        # Fallback heuristic: last dimension of the first layer's first
        # parameter.
        first_param = next(iter(layers[0].parameters()), None)
        if first_param is not None:
            hidden_size = first_param.shape[-1]

    n_layers = len(layers)
    if target_layers is None:
        target_layers = list(range(n_layers))

    params_before = sum(p.numel() for p in model.parameters())
    converted = 0
    skipped = 0

    for idx in target_layers:
        layer = layers[idx]
        mlp_attr = next(
            (name for name in ('mlp', 'feed_forward', 'ffn')
             if hasattr(layer, name)),
            None,
        )
        mlp = getattr(layer, mlp_attr) if mlp_attr is not None else None

        # Skip layers without a recognizable MLP, and layers that are already
        # retrofitted (repeated call or duplicate index) so a CRESTBlock is
        # never nested inside another CRESTBlock.
        if mlp is None or isinstance(mlp, CRESTBlock):
            skipped += 1
            continue

        crest_block = CRESTBlock(
            original_mlp=mlp,
            hidden_size=hidden_size,
            max_steps=max_steps,
        )
        setattr(layer, mlp_attr, crest_block)
        converted += 1

    params_after = sum(p.numel() for p in model.parameters())
    params_added = params_after - params_before
    # Guard the percentage against a parameter-free model.
    overhead_pct = (
        round(100 * params_added / params_before, 1) if params_before else 0.0
    )

    stats = {
        'total_layers': n_layers,
        'converted': converted,
        'skipped': skipped,
        'max_steps': max_steps,
        'hidden_size': hidden_size,
        'params_before': params_before,
        'params_after': params_after,
        'params_added': params_added,
        'param_overhead_pct': overhead_pct,
    }
    return model, stats
def collect_ponder_costs(model: nn.Module) -> torch.Tensor:
    """
    Average the most recent ponder cost over the model's CREST blocks.

    Only blocks whose ``_ponder_cost`` is a tensor contribute to the sum and
    to the divisor. When no block contributes, a zero tensor on the CPU is
    returned.
    """
    tensor_costs = [
        block._ponder_cost
        for block in model.modules()
        if isinstance(block, CRESTBlock)
        and isinstance(block._ponder_cost, torch.Tensor)
    ]

    running = torch.tensor(0.0, device='cpu')
    for cost in tensor_costs:
        # Follow each cost onto its own device before adding, so the result
        # ends up on the device of the last contributing block.
        running = running.to(cost.device) + cost

    return running / max(len(tensor_costs), 1)
def get_crest_stats(model: nn.Module) -> dict:
    """
    Get monitoring stats from all CREST blocks.

    Returns:
        A dict with the keys ``avg_steps`` and ``avg_ponder`` (means over all
        blocks of ``_steps_used`` and ``_ponder_cost``), ``n_blocks`` and
        ``per_block`` (one ``{'name', 'avg_steps', 'ponder_cost'}`` entry per
        block). For a model without CREST blocks the averages and count are
        0 and ``per_block`` is an empty list, so the result always has the
        same set of keys.
    """
    per_block = []
    for name, module in model.named_modules():
        if not isinstance(module, CRESTBlock):
            continue
        cost = module._ponder_cost
        per_block.append({
            'name': name,
            'avg_steps': module._steps_used,
            # _ponder_cost is a plain float until the block's first forward.
            'ponder_cost': cost.item() if isinstance(cost, torch.Tensor) else cost,
        })

    n_blocks = len(per_block)
    if n_blocks == 0:
        return {'avg_steps': 0, 'avg_ponder': 0, 'n_blocks': 0, 'per_block': []}

    return {
        'avg_steps': sum(entry['avg_steps'] for entry in per_block) / n_blocks,
        'avg_ponder': sum(entry['ponder_cost'] for entry in per_block) / n_blocks,
        'n_blocks': n_blocks,
        'per_block': per_block,
    }