| """ |
| Chimera 5.2 — Parcae Prelude / Loop / Coda controller. |
| |
| Same numerics as the previous draft but cleaner: |
| * Loop count is deterministic during training so gradient checkpointing |
| recompute is consistent. |
| * Backward truncation only retains gradients on the last ``n_loops // 2`` |
| iterations; earlier iterates are detached, mirroring the original |
| intuition while keeping the implementation in pure PyTorch. |
| * Adaptive early-exit during inference based on residual magnitude. |
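
A ``__main__`` smoke test at the bottom of the file sketches typical usage.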
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class ParcaeInjection(nn.Module):
    """ZOH-stable diagonal injection: ``h' = exp(-Δ·A)·h + Δ·B·e``."""

    __constants__ = ["hidden_size"]

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = int(hidden_size)
        # A = exp(log_A) > 0, so the decay factor exp(-Δ·A) stays below 1
        # whenever Δ > 0; keep the rate out of weight decay, as is usual
        # for SSM-style parameters.
        self.log_A = nn.Parameter(torch.zeros(self.hidden_size))
        self.log_A._no_weight_decay = True
        self.B_raw = nn.Parameter(torch.randn(self.hidden_size) * 0.02)
        self.delta = nn.Parameter(torch.full((self.hidden_size,), 0.5))

    def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        # Discretised update: exact exponential decay on the carried state,
        # Euler-scaled injection of the prelude embedding.
        A_bar = (-self.delta * self.log_A.exp()).exp()
        B_bar = self.delta * self.B_raw
        return A_bar * h_prev + B_bar * e


class ParcaeLoopController(nn.Module):
    """Iterative refinement controller used by the looped trunk."""

    __constants__ = ["loop_min", "loop_max", "loop_default"]

    def __init__(self, hidden_size: int,
                 loop_range: tuple[int, int] = (1, 6), loop_default: int = 2,
                 adaptive_exit_threshold: float = 0.01):
        super().__init__()
        self.injection = ParcaeInjection(hidden_size)
        self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
        self.loop_default = int(loop_default)
        self.exit_threshold = float(adaptive_exit_threshold)
        self.e_norm = nn.LayerNorm(hidden_size)

    def forward(self, prelude_output: torch.Tensor,
                loop_fn: Callable[[torch.Tensor], torch.Tensor],
                num_loops: int | None = None) -> torch.Tensor:
        e = self.e_norm(prelude_output)
        h = torch.zeros_like(e)
        n_loops = int(num_loops) if num_loops is not None else self.loop_default
        n_loops = max(self.loop_min, min(self.loop_max, n_loops))

        # Deterministic during training so gradient-checkpointing recompute
        # sees the same loop count; only the last ``n_bwd`` iterations keep
        # gradients.
        n_bwd = max(1, n_loops // 2) if self.training else n_loops

        for t in range(n_loops):
            h_new = loop_fn(self.injection(h, e))
            if not self.training and t > 0:
                # Early exit: measure the residual against the *previous*
                # iterate before overwriting ``h``; comparing after the
                # update would always read zero.
                residual = (h_new - h).abs().mean().item()
                h = h_new
                if residual < self.exit_threshold:
                    break
            else:
                # Truncated backprop: detach all but the last n_bwd iterates.
                keep_grad = (not self.training) or (t >= n_loops - n_bwd)
                h = h_new if keep_grad else h_new.detach()
        return h


| __all__ = ["ParcaeInjection", "ParcaeLoopController"] |
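

# Usage sketch (smoke test). The two-layer MLP below is a stand-in for the
# real looped trunk, not part of Chimera itself; any Tensor -> Tensor
# callable of matching width works as ``loop_fn``.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = 32
    controller = ParcaeLoopController(hidden, loop_range=(1, 6), loop_default=4)
    trunk = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                          nn.Linear(hidden, hidden))
    x = torch.randn(2, hidden)

    # Training: deterministic loop count; only the last n_loops // 2
    # iterations keep gradients, so the backward pass stays shallow.
    controller.train()
    out = controller(x, trunk)
    out.sum().backward()
    print("train:", tuple(out.shape))

    # Inference: the residual-based early exit may stop before num_loops.
    controller.eval()
    with torch.no_grad():
        print("eval:", tuple(controller(x, trunk, num_loops=6).shape))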