AbstractPhil committed on
Commit
cc3eb9a
·
verified ·
1 Parent(s): 461d412

Create prototype.py

Browse files
Files changed (1) hide show
  1. prototype.py +348 -0
prototype.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
FLEighConduit — Evidence-Emitting Eigendecomposition
======================================================
Extends FLEigh with judicial conduit telemetry.
Read-only observation of the solver's internal states.

Council-ratified specification (11 rounds, 3 AI participants):
  - Theorem 1: Lens Preservation (shared arithmetic path)
  - Theorem 2: Dynamic Non-Reconstructibility (friction, settle, order)
  - Theorem 4: Continuity (static continuous, dynamic piecewise)
  - Theorem 5: Gauge-Safe Directional Observation (sign canonicalization)

Classes:
  ConduitPacket — fixed-shape tensor bundle, batch-first, dimension-agnostic
  FLEighConduit — extends FLEigh, emits ConduitPacket

Usage:
    from geolip_core.linalg.conduit import FLEighConduit

    solver = FLEighConduit()
    packet = solver(A)          # A: (B, n, n) symmetric
    packet.eigenvalues          # (B, n)
    packet.friction             # (B, n) — per-root solver struggle
    packet.settle               # (B, n) — iterations to convergence
    packet.extraction_order     # (B, n) — root extraction sequence

    # Standard eigenpairs (identical to FLEigh)
    evals, evecs = packet.eigenpairs()

License: MIT

Author: AbstractPhil + Claude 4.6 Opus Extended
Assistants: Gemini Pro, GPT 5.4 Extended Thinking
34
+ """
35
+
36
+ import math
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch import Tensor
40
+ from typing import Tuple, Optional
41
+ from dataclasses import dataclass
42
+
43
+
44
@dataclass
class ConduitPacket:
    """Fixed-shape telemetry bundle emitted by FLEighConduit.

    All tensors are batch-first and dimension-agnostic.
    Production packet: scalar-dominant, bounded overhead.
    Research fields (mstore, trajectories) are populated only when requested.

    Field order is part of the constructor contract (positional arguments),
    so new fields must be appended with defaults.
    """

    # -- Spectral evidence (static, deterministic) --
    eigenvalues: Tensor   # (B, n) sorted ascending
    eigenvectors: Tensor  # (B, n, n) sign-canonicalized, columns are eigenvectors
    char_coeffs: Tensor   # (B, n) characteristic-polynomial coefficients, monic 1 omitted

    # -- Adjudication evidence (dynamic, non-reconstructible) --
    friction: Tensor             # (B, n) per-root sum of 1/(|p'(z_t)| + delta)
    settle: Tensor               # (B, n) iterations to convergence per root
    extraction_order: Tensor     # (B, n) which root found first (0-indexed)
    refinement_residual: Tensor  # (B,) ||V^T V - I||_F after Newton-Schulz

    # -- Release fidelity --
    # Note: release_residual is owned by the SVDConduit layer in the full
    # architecture. Included here for v1 convenience when the enc_out
    # matrix M is provided.
    release_residual: Optional[Tensor] = None  # (B,) ||M - U diag(S) Vt||^2

    # -- Research mode (populated only with research=True) --
    mstore: Optional[Tensor] = None         # (n+1, B, n, n) FL matrix states
    z_trajectory: Optional[Tensor] = None   # (B, n, laguerre_iters) root guesses
    dp_trajectory: Optional[Tensor] = None  # (B, n, laguerre_iters) |p'| at each step

    def eigenpairs(self) -> Tuple[Tensor, Tensor]:
        """Return ``(eigenvalues, eigenvectors)``, the standard FLEigh-style output."""
        return self.eigenvalues, self.eigenvectors
77
+
78
+
79
def canonicalize_eigenvectors(V: Tensor) -> Tensor:
    """Force a deterministic sign convention on eigenvector columns.

    For each column (eigenvector), flip the sign so that the entry with the
    largest absolute value is positive. This resolves the +/- gauge
    ambiguity, so V and any column-wise sign flip of V map to the same
    output.

    Args:
        V: (B, n, n) eigenvector matrix whose columns are eigenvectors.
            Only the last two dimensions are indexed, so any number of
            leading batch dimensions works.

    Returns:
        Tensor of the same shape as V with deterministic column signs.
        An all-zero column stays all-zero (its sign factor is 0).
    """
    # Row index of the largest-magnitude entry in every column.
    max_idx = V.abs().argmax(dim=-2, keepdim=True)  # (B, 1, n)
    # Sign of that pivot entry; multiplying by it makes the pivot positive.
    sign = V.gather(-2, max_idx).sign()             # (B, 1, n)
    return V * sign
96
+
97
+
98
class FLEighConduit(nn.Module):
    """Evidence-emitting eigendecomposition.

    Intended to follow the same arithmetic as FLEigh while capturing
    telemetry at phase boundaries without altering the numerical path.

    Phases:
        1. Faddeev-LeVerrier characteristic polynomial (fp64, n bmm)
        2. Laguerre root-finding + Newton polish (with telemetry capture)
        3. FL adjugate eigenvectors (fp64 Horner + max-norm column)
        4. Newton-Schulz orthogonalization (fp32)
        5. Rayleigh quotient refinement (fp32)

    Args:
        laguerre_iters: Root-finding iterations per eigenvalue (default 5)
        polish_iters: Newton refinement iterations (default 3)
        ns_iters: Newton-Schulz orthogonalization iterations (default 2)
        friction_delta: stability constant for friction computation (default 1e-8)
        settle_threshold: convergence threshold for settle count (default 1e-6)
        research: if True, populate full trajectory and Mstore fields
    """

    def __init__(self, laguerre_iters: int = 5, polish_iters: int = 3,
                 ns_iters: int = 2, friction_delta: float = 1e-8,
                 settle_threshold: float = 1e-6, research: bool = False):
        super().__init__()
        self.laguerre_iters = laguerre_iters
        self.polish_iters = polish_iters
        self.ns_iters = ns_iters
        self.friction_delta = friction_delta
        self.settle_threshold = settle_threshold
        self.research = research

    def forward(self, A: Tensor) -> ConduitPacket:
        """Run the eigendecomposition and collect telemetry.

        Args:
            A: (B, n, n) symmetric matrix batch with n >= 1. Any floating
                dtype is accepted; eigenpairs are returned in float32.

        Returns:
            ConduitPacket with eigenpairs + judicial telemetry. Eigenvalues
            are sorted ascending; friction and settle are permuted to match.

        Raises:
            ValueError: if A is not a batch of non-empty square matrices.
        """
        if A.dim() != 3 or A.shape[-1] != A.shape[-2] or A.shape[-1] == 0:
            raise ValueError(
                f"expected A of shape (B, n, n) with n >= 1, got {tuple(A.shape)}"
            )
        B, n, _ = A.shape
        device = A.device

        # ============================================
        # Phase 1: Faddeev-LeVerrier (fp64)
        # ============================================
        # Normalise by Frobenius norm / sqrt(n) so polynomial coefficients
        # stay well-scaled; the clamp guards the all-zero matrix.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]

        Ad = As.double()
        eye_d = torch.eye(n, device=device, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1)
        # c[:, j] is the coefficient of z^j; the leading coefficient is monic.
        c = torch.zeros(B, n + 1, device=device, dtype=torch.float64)
        c[:, n] = 1.0
        # Mstore[k] keeps the k-th FL matrix; Phase 3 reuses them (index 0 unused).
        Mstore = torch.zeros(n + 1, B, n, n, device=device, dtype=torch.float64)
        Mk = torch.zeros(B, n, n, device=device, dtype=torch.float64)
        for k in range(1, n + 1):
            Mk = torch.bmm(Ad, Mk) + c[:, n - k + 1, None, None] * eye_d
            Mstore[k] = Mk
            c[:, n - k] = -(Ad * Mk).sum((-2, -1)) / k

        # Capture characteristic coefficients (omit monic leading 1)
        char_coeffs = c[:, :n].float()  # (B, n)

        # ============================================
        # Phase 2: Laguerre + deflation + Newton polish
        #          WITH TELEMETRY CAPTURE
        # ============================================
        use_f64 = n > 6
        dt = torch.float64 if use_f64 else torch.float32
        # Working copy of the coefficients; deflation overwrites it in place.
        cl = c.to(dt).clone().detach()
        roots = torch.zeros(B, n, device=device, dtype=dt)
        # Initial guesses: sorted diagonal, slightly spread to break ties.
        zi = As.to(dt).diagonal(dim1=-2, dim2=-1).sort(dim=-1).values.detach()
        zi = zi + torch.linspace(-1e-4, 1e-4, n, device=device, dtype=dt).unsqueeze(0)

        # Telemetry buffers
        friction = torch.zeros(B, n, device=device, dtype=torch.float32)
        # settle starts at the sentinel "laguerre_iters" == never settled.
        settle = torch.full((B, n), float(self.laguerre_iters),
                            device=device, dtype=torch.float32)
        extraction_order = torch.zeros(B, n, device=device, dtype=torch.float32)

        # Research buffers
        if self.research:
            z_traj = torch.zeros(B, n, self.laguerre_iters,
                                 device=device, dtype=torch.float32)
            dp_traj = torch.zeros(B, n, self.laguerre_iters,
                                  device=device, dtype=torch.float32)

        for ri in range(n):
            deg = n - ri
            z = zi[:, ri]

            for lag_iter in range(self.laguerre_iters):
                # Horner evaluation of p, p' and p''/2 at z
                pv = cl[:, deg]
                dp = torch.zeros(B, device=device, dtype=dt)
                d2 = torch.zeros(B, device=device, dtype=dt)
                for j in range(deg - 1, -1, -1):
                    d2 = d2 * z + dp
                    dp = dp * z + pv
                    pv = pv * z + cl[:, j]

                # -- Telemetry capture --
                dp_abs = dp.abs().float()
                friction[:, ri] += 1.0 / (dp_abs + self.friction_delta)

                # Settle detection: record the first iteration at which
                # |p(z)| drops below the threshold.
                pv_abs = pv.abs().float()
                just_settled = ((pv_abs < self.settle_threshold)
                                & (settle[:, ri] == float(self.laguerre_iters)))
                settle[:, ri] = torch.where(just_settled,
                                            torch.full_like(settle[:, ri], float(lag_iter)),
                                            settle[:, ri])

                if self.research:
                    z_traj[:, ri, lag_iter] = z.float()
                    dp_traj[:, ri, lag_iter] = dp_abs

                # Laguerre step (unchanged arithmetic)
                ok = pv.abs() > 1e-30
                ps = torch.where(ok, pv, torch.ones_like(pv))
                G = torch.where(ok, dp / ps, torch.zeros_like(dp))
                H = G * G - torch.where(ok, 2.0 * d2 / ps, torch.zeros_like(d2))
                disc = ((deg - 1.0) * (deg * H - G * G)).clamp(min=0.0)
                sq = torch.sqrt(disc)
                gp = G + sq
                gm = G - sq
                # Pick the larger-magnitude denominator for stability.
                den = torch.where(gp.abs() >= gm.abs(), gp, gm)
                dok = den.abs() > 1e-20
                ds = torch.where(dok, den, torch.ones_like(den))
                z = z - torch.where(dok, float(deg) / ds, torch.zeros_like(den))

            roots[:, ri] = z
            # NOTE(review): as written every row ends up [0, 1, ..., n-1];
            # confirm whether this should be permuted like friction/settle.
            extraction_order[:, ri] = float(ri)

            # Synthetic division (deflation) by (x - z)
            b = cl[:, deg]
            for j in range(deg - 1, 0, -1):
                bn = cl[:, j] + z * b
                cl[:, j] = b
                b = bn
            cl[:, 0] = b

        # Newton polish on original (undeflated) polynomial
        roots = roots.double()
        for _ in range(self.polish_iters):
            pv = torch.ones(B, n, device=device, dtype=torch.float64)
            dp = torch.zeros(B, n, device=device, dtype=torch.float64)
            for j in range(n - 1, -1, -1):
                dp = dp * roots + pv
                pv = pv * roots + c[:, j:j + 1]
            ok = dp.abs() > 1e-30
            dps = torch.where(ok, dp, torch.ones_like(dp))
            roots = roots - torch.where(ok, pv / dps, torch.zeros_like(pv))

        # ============================================
        # Phase 3: FL adjugate eigenvectors (fp64)
        # ============================================
        lam = roots
        # Horner evaluation of the matrix polynomial at every root.
        R = Mstore[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            R = R * lam[:, :, None, None] + Mstore[k].unsqueeze(1)
        # Use the largest-norm column of each adjugate as the eigenvector.
        cnorms = R.norm(dim=-2)
        best = cnorms.argmax(dim=-1)
        idx = best.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, n, 1)
        vec = R.gather(-1, idx).squeeze(-1)
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.float().transpose(-2, -1)

        # ============================================
        # Phase 4: Newton-Schulz orthogonalization
        # ============================================
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.transpose(-2, -1), V)
        X = eye_f.clone()
        for _ in range(self.ns_iters):
            T = 3.0 * eye_f - Y
            X = 0.5 * torch.bmm(X, T)
            Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)

        # Telemetry: orthogonality residual after NS
        VtV = torch.bmm(V.transpose(-2, -1), V)
        refinement_residual = (VtV - eye_f).pow(2).sum((-2, -1)).sqrt()  # (B,)

        # ============================================
        # Phase 5: Rayleigh quotient refinement
        # ============================================
        # V is float32; bmm requires matching dtypes, so cast A. This is a
        # no-op for float32 input and avoids a crash for fp64/fp16/bf16.
        AV = torch.bmm(A.to(V.dtype), V)
        evals = (V * AV).sum(dim=-2)

        se, perm = evals.sort(dim=-1)
        sv = V.gather(-1, perm.unsqueeze(-2).expand_as(V))

        # Reorder telemetry to match sorted eigenvalue order
        friction_sorted = friction.gather(-1, perm)
        settle_sorted = settle.gather(-1, perm)
        # extraction_order stays as-is: it records the ORIGINAL extraction sequence

        # ============================================
        # Gauge canonicalization
        # ============================================
        sv = canonicalize_eigenvectors(sv)

        # ============================================
        # Build packet
        # ============================================
        packet = ConduitPacket(
            eigenvalues=se,
            eigenvectors=sv,
            char_coeffs=char_coeffs,
            friction=friction_sorted,
            settle=settle_sorted,
            extraction_order=extraction_order,
            refinement_residual=refinement_residual,
        )

        if self.research:
            packet.mstore = Mstore
            packet.z_trajectory = z_traj
            packet.dp_trajectory = dp_traj

        return packet
322
+
323
+
324
+ # ── Regression parity test ──
325
+
326
def verify_parity(A: Tensor, atol: float = 1e-5) -> bool:
    """Check that FLEighConduit reproduces FLEigh's eigenpairs.

    Args:
        A: (B, n, n) symmetric test matrices
        atol: absolute tolerance for eigenvalues; also the allowed shortfall
            from 1 of each |<v_ref, v_conduit>| inner product

    Returns:
        True if eigenvalues and eigenvectors both match within tolerance.
        Always a plain Python bool, never a 0-dim tensor.
    """
    # Imported lazily so this module does not require FLEigh at import time.
    from geolip_core.linalg.eigh import FLEigh

    ref_evals, ref_evecs = FLEigh()(A)
    packet = FLEighConduit()(A)
    cond_evals, cond_evecs = packet.eigenpairs()

    evals_match = torch.allclose(ref_evals, cond_evals, atol=atol)

    # Eigenvectors may differ by sign, so compare absolute column-wise
    # inner products against 1.
    dots = (ref_evecs * cond_evecs).sum(dim=-2).abs()
    # .all() yields a 0-dim tensor; coerce so the annotated bool is honoured.
    evecs_match = bool((dots > 1.0 - atol).all())

    return evals_match and evecs_match