chimera / chimera/looping.py
Lgr54HFi's picture
Upload folder using huggingface_hub
11c11f8 verified
"""
Chimera 5.2 — Parcae Prelude / Loop / Coda controller.
Same numerics as the previous draft but cleaner:
* Loop count is deterministic during training so gradient checkpointing
recompute is consistent.
* Backward truncation only retains gradients on the last ``n_loops // 2``
iterations; earlier iterates are detached, mirroring the original
intuition while keeping the implementation in pure PyTorch.
* Adaptive early-exit during inference based on residual magnitude.
"""
from __future__ import annotations
import torch
import torch.nn as nn
class ParcaeInjection(nn.Module):
    """ZOH-stable diagonal injection: ``h' = exp(-Δ·A)·h + Δ·B·e``.

    The decay rate ``A`` is parameterised through ``log_A`` so that
    ``A = exp(log_A) > 0``; with a positive step size ``Δ`` the decay
    factor ``exp(-Δ·A)`` then stays inside ``(0, 1)``, keeping the
    diagonal recurrence stable.
    """

    __constants__ = ["hidden_size"]

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = int(hidden_size)
        # Log-parameterised positive decay rates; flagged so the optimizer
        # can skip weight decay on them.
        self.log_A = nn.Parameter(torch.zeros(self.hidden_size))
        self.log_A._no_weight_decay = True
        # Small random per-channel input gain and per-channel step size Δ.
        self.B_raw = nn.Parameter(torch.randn(self.hidden_size) * 0.02)
        self.delta = nn.Parameter(torch.full((self.hidden_size,), 0.5))

    def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Advance the state one step: discretized decay plus driven input."""
        decay = torch.exp(-self.delta * torch.exp(self.log_A))
        drive = self.delta * self.B_raw
        return decay * h_prev + drive * e
class ParcaeLoopController(nn.Module):
    """Iterative refinement controller used by the looped trunk.

    Repeatedly applies ``loop_fn`` to the injected state for a bounded
    number of iterations.  During training the iteration count is
    deterministic (gradient-checkpoint friendly) and only the last
    ``n_loops // 2`` iterations retain gradients (truncated backprop);
    during inference the loop exits early once successive iterates stop
    changing by more than ``adaptive_exit_threshold`` on average.
    """

    __constants__ = ["loop_min", "loop_max", "loop_default"]

    def __init__(self, hidden_size: int,
                 loop_range: tuple = (1, 6), loop_default: int = 2,
                 adaptive_exit_threshold: float = 0.01,
                 spectral_radius_bound: float = 1.0):
        # NOTE(review): ``spectral_radius_bound`` is accepted for interface
        # compatibility but is not used anywhere in this block.
        super().__init__()
        self.injection = ParcaeInjection(hidden_size)
        self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
        self.loop_default = int(loop_default)
        self.exit_threshold = float(adaptive_exit_threshold)
        # Normalise the prelude output once; it is re-injected every loop.
        self.e_norm = nn.LayerNorm(hidden_size)

    def forward(self, prelude_output: torch.Tensor, loop_fn,
                num_loops: int | None = None) -> torch.Tensor:
        """Run the refinement loop.

        Args:
            prelude_output: trunk input; last dim must equal ``hidden_size``.
            loop_fn: callable mapping a state tensor to a same-shaped tensor.
            num_loops: optional iteration count, clamped to
                ``[loop_min, loop_max]``; defaults to ``loop_default``.

        Returns:
            The final refined state, same shape as ``prelude_output``.
        """
        e = self.e_norm(prelude_output)
        h = torch.zeros_like(e)
        n_loops = int(num_loops) if num_loops is not None else self.loop_default
        n_loops = max(self.loop_min, min(self.loop_max, n_loops))
        # Only the trailing half of the iterations receives gradient.
        n_bwd = max(1, n_loops // 2) if self.training else n_loops
        for t in range(n_loops):
            h_prev = h
            h_new = loop_fn(self.injection(h, e))
            backprop = (not self.training) or (t >= n_loops - n_bwd)
            h = h_new if backprop else h_new.detach()
            # Adaptive early exit (inference only): stop once the iterate
            # stops moving.  BUG FIX: compare against the *previous*
            # iterate; the old code compared ``h_new`` with ``h`` after the
            # assignment above, so the residual was identically zero and
            # the loop always exited at t == 1 during inference.
            if not self.training and t > 0:
                if (h_new - h_prev).abs().mean().item() < self.exit_threshold:
                    break
        return h
# Public API of this module.
__all__ = ["ParcaeInjection", "ParcaeLoopController"]