"""Chimera 5.2 — Parcae Prelude / Loop / Coda controller.

Same numerics as the previous draft but cleaner:

* Loop count is deterministic during training so gradient checkpointing
  recompute is consistent.
* Backward truncation only retains gradients on the last ``n_loops // 2``
  iterations; earlier iterates are detached, mirroring the original intuition
  while keeping the implementation in pure PyTorch.
* Adaptive early-exit during inference based on residual magnitude.
"""

from __future__ import annotations

import torch
import torch.nn as nn


class ParcaeInjection(nn.Module):
    """ZOH-stable diagonal injection: ``h' = exp(-Δ·A)·h + Δ·B·e``.

    All parameters are per-channel vectors of length ``hidden_size`` and are
    applied element-wise along the last dimension of the inputs.
    """

    __constants__ = ["hidden_size"]

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = int(hidden_size)
        # A is stored in log space so that ``A = exp(log_A)`` is always > 0.
        self.log_A = nn.Parameter(torch.zeros(self.hidden_size))
        # Marker attribute; presumably read by the optimizer setup to exclude
        # this parameter from weight decay (the consumer is not in this file).
        self.log_A._no_weight_decay = True
        self.B_raw = nn.Parameter(torch.randn(self.hidden_size) * 0.02)
        # NOTE: delta is unconstrained; the decay factor ``exp(-delta * A)``
        # stays below 1 only while delta remains positive.
        self.delta = nn.Parameter(torch.full((self.hidden_size,), 0.5))

    def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Blend the previous state with the injected signal.

        Args:
            h_prev: Previous loop state.
            e: Signal injected at every iteration.

        Returns:
            ``exp(-delta * exp(log_A)) * h_prev + delta * B_raw * e``.
        """
        A_bar = (-self.delta * self.log_A.exp()).exp()
        B_bar = self.delta * self.B_raw
        return A_bar * h_prev + B_bar * e


class ParcaeLoopController(nn.Module):
    """Iterative refinement controller used by the looped trunk.

    Each iteration injects the layer-normalised prelude output into the
    running state through :class:`ParcaeInjection` and passes the result to a
    caller-supplied ``loop_fn``.
    """

    __constants__ = ["loop_min", "loop_max", "loop_default"]

    def __init__(
        self,
        hidden_size: int,
        loop_range: tuple = (1, 6),
        loop_default: int = 2,
        adaptive_exit_threshold: float = 0.01,
        spectral_radius_bound: float = 1.0,
    ):
        """Build the controller.

        Args:
            hidden_size: Size of the last (feature) dimension.
            loop_range: ``(min, max)`` bounds the loop count is clamped to.
            loop_default: Loop count used when ``forward`` gets no
                ``num_loops``.
            adaptive_exit_threshold: Inference-only early-exit threshold on
                the mean absolute change between consecutive iterates.
            spectral_radius_bound: Accepted for interface compatibility; not
                used by this implementation.
        """
        super().__init__()
        self.injection = ParcaeInjection(hidden_size)
        self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
        self.loop_default = int(loop_default)
        self.exit_threshold = float(adaptive_exit_threshold)
        self.e_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        prelude_output: torch.Tensor,
        loop_fn,
        num_loops: int | None = None,
    ) -> torch.Tensor:
        """Run the refinement loop and return the final state.

        Args:
            prelude_output: Output of the prelude; layer-normalised once and
                re-injected at every iteration.
            loop_fn: Callable applied to the injected state on each
                iteration. Its result becomes the next state, so it must be
                broadcast-compatible with ``prelude_output``.
            num_loops: Requested iteration count; ``None`` selects
                ``loop_default``. The value is clamped to
                ``[loop_min, loop_max]``.

        Returns:
            The state after the last executed iteration.
        """
        e = self.e_norm(prelude_output)
        h = torch.zeros_like(e)

        n_loops = int(num_loops) if num_loops is not None else self.loop_default
        n_loops = max(self.loop_min, min(self.loop_max, n_loops))
        # Training keeps the autograd graph only for the trailing half of the
        # iterations (at least one); inference keeps every iterate as-is.
        n_bwd = max(1, n_loops // 2) if self.training else n_loops

        for t in range(n_loops):
            # Keep the previous iterate: the early-exit residual must compare
            # the new state with the one it was computed from. Comparing
            # against ``h`` after it has been rebound always gives zero.
            h_prev = h
            h_new = loop_fn(self.injection(h_prev, e))

            keep_graph = (not self.training) or (t >= n_loops - n_bwd)
            h = h_new if keep_graph else h_new.detach()

            # Adaptive early exit (inference only, never on the first step).
            # ``.item()`` synchronises with the device once per iteration.
            if not self.training and t > 0:
                residual = (h_new - h_prev).abs().mean().item()
                if residual < self.exit_threshold:
                    break
        return h


__all__ = ["ParcaeInjection", "ParcaeLoopController"]