File size: 2,824 Bytes
11c11f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Chimera 5.2 — Parcae Prelude / Loop / Coda controller.

Same numerics as the previous draft but cleaner:
* Loop count is deterministic during training so gradient checkpointing
  recompute is consistent.
* Backward truncation only retains gradients on the last ``n_loops // 2``
  iterations; earlier iterates are detached, mirroring the original
  intuition while keeping the implementation in pure PyTorch.
* Adaptive early-exit during inference based on residual magnitude.
"""

from __future__ import annotations

import torch
import torch.nn as nn


class ParcaeInjection(nn.Module):
    """Diagonal zero-order-hold injection: ``h' = exp(-Δ·A)·h + Δ·B·e``.

    Per-channel learnable decay (``log_A``), input gain (``B_raw``) and
    step size (``delta``); all state is elementwise, so any broadcastable
    shapes for ``h_prev`` and ``e`` work.
    """

    __constants__ = ["hidden_size"]

    def __init__(self, hidden_size: int):
        super().__init__()
        size = int(hidden_size)
        self.hidden_size = size
        # Log-parameterized decay rate A = exp(log_A) > 0; flagged so the
        # optimizer setup can exempt it from weight decay.
        self.log_A = nn.Parameter(torch.zeros(size))
        self.log_A._no_weight_decay = True
        # Small random input projection, one gain per channel.
        self.B_raw = nn.Parameter(torch.randn(size) * 0.02)
        # Discretization step Δ, initialized to 0.5 per channel.
        self.delta = nn.Parameter(torch.full((size,), 0.5))

    def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Advance the state one discretized step.

        Args:
            h_prev: previous hidden state.
            e: injected input (e.g. normalized prelude activations).

        Returns:
            ``exp(-Δ·exp(log_A)) * h_prev + Δ·B_raw * e``.
        """
        decay = torch.exp(-self.delta * torch.exp(self.log_A))
        drive = self.delta * self.B_raw
        return decay * h_prev + drive * e


class ParcaeLoopController(nn.Module):
    """Iterative refinement controller used by the looped trunk."""

    __constants__ = ["loop_min", "loop_max", "loop_default"]

    def __init__(self, hidden_size: int,
                 loop_range: tuple = (1, 6), loop_default: int = 2,
                 adaptive_exit_threshold: float = 0.01,
                 spectral_radius_bound: float = 1.0):
        super().__init__()
        self.injection = ParcaeInjection(hidden_size)
        self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
        self.loop_default = int(loop_default)
        self.exit_threshold = float(adaptive_exit_threshold)
        self.e_norm = nn.LayerNorm(hidden_size)

    def forward(self, prelude_output: torch.Tensor, loop_fn,
                num_loops: int = None) -> torch.Tensor:
        e = self.e_norm(prelude_output)
        h = torch.zeros_like(e)
        n_loops = int(num_loops) if num_loops is not None else self.loop_default
        n_loops = max(self.loop_min, min(self.loop_max, n_loops))

        n_bwd = max(1, n_loops // 2) if self.training else n_loops

        for t in range(n_loops):
            h_new = loop_fn(self.injection(h, e))
            backprop = (not self.training) or (t >= n_loops - n_bwd)
            h = h_new if backprop else h_new.detach()
            if not self.training and t > 0:
                if (h_new - h).abs().mean().item() < self.exit_threshold:
                    break
        return h


__all__ = ["ParcaeInjection", "ParcaeLoopController"]