AbstractPhil committed on
Commit
e50ad6c
Β·
verified Β·
1 Parent(s): cc4f091

Update constellation.py

Browse files
Files changed (1) hide show
  1. constellation.py +467 -292
constellation.py CHANGED
@@ -1,39 +1,32 @@
1
  """
2
- Constellation β€” Unified Geometric Observer + Interpreter
3
- ==========================================================
4
- Configurable implementation covering all validated constellation forms.
5
-
6
- PROVEN RESULTS:
7
- Form 1 (Core): 91.5% CIFAR-10 @ 1.6M params, CV=0.2045
8
- Form 5 (Relay): cos_to_orig=0.994 @ depth 16, 8.4Γ— faster than attn @ 131K
9
- Hybrid: 88.0% CIFAR-10 @ 23.5M (conv encoder + constellation)
10
- Scattering v1: 81.9% CIFAR-10 @ 17M (frozen scattering + constellation)
11
-
12
- UNIVERSAL RULES (empirically validated):
13
- - SquaredReLU in all constellation paths, never GELU
14
- - Patchwork: Linear(in, in*2) β†’ SquaredReLU β†’ LN β†’ Linear(in*2, out)
15
- - Gate init: -3.0 (sigmoid β‰ˆ 0.047) for relay/residual forms
16
- - SLERP: acos in fp32, everything else in compute dtype
17
- - Adam, NO weight decay β€” geometry IS regularization
18
- - InfoNCE is alignment FORCE, Procrustes is REGULARIZER
19
- - CV loss on the BOTTLENECK, weight 0.001 or below
20
- - Anchor dropout (30%) prevents collapse in high-anchor configs
21
-
22
- FORMS:
23
- Constellation β€” observation + interpretation, configurable
24
- ConstellationRelay β€” per-token geometric layer with gated residual
25
 
26
  Usage:
27
- from constellation import Constellation, ConstellationRelay
28
 
29
- # Form 1 (Core): single vector per image
30
- c = Constellation(n_anchors=16, dim=16, n_directions=8,
31
- d_comp=64, n_phases=3)
32
- output = c(directions) # (B, 8, 16) β†’ ConstellationOutput
33
-
34
- # Form 5 (Relay): per-token processing
35
- r = ConstellationRelay(dim=256, patch_dim=16, n_anchors=16)
36
- out = r(tokens) # (B, S, 256) β†’ (B, S, 256)
37
  """
38
 
39
  import torch
@@ -41,32 +34,58 @@ import torch.nn as nn
41
  import torch.nn.functional as F
42
  import math
43
  from dataclasses import dataclass
44
- from typing import Optional
45
 
46
 
47
  # ══════════════════════════════════════════════════════════════════
48
- # ACTIVATION
49
  # ══════════════════════════════════════════════════════════════════
50
 
51
  class SquaredReLU(nn.Module):
52
- """x β†’ ReLU(x)Β². Proven superior to GELU in all constellation paths."""
53
  def forward(self, x):
54
  return F.relu(x) ** 2
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ══════════════════════════════════════════════════════════════════
58
  # ANCHOR INITIALIZATION
59
  # ══════════════════════════════════════════════════════════════════
60
 
61
  def init_anchors_xavier(n, d):
62
- """Xavier normal β†’ normalize. Near-orthogonal in high-d. Used in Core."""
63
  w = torch.empty(n, d)
64
  nn.init.xavier_normal_(w)
65
  return F.normalize(w, dim=-1)
66
 
67
 
68
  def init_anchors_orthogonal(n, d):
69
- """QR decomposition β†’ exact orthonormal basis. Used when n <= d."""
70
  if n <= d:
71
  M = torch.randn(d, n)
72
  Q, _ = torch.linalg.qr(M)
@@ -80,7 +99,7 @@ def init_anchors_orthogonal(n, d):
80
 
81
 
82
  def init_anchors_repulsion(n, d, iters=200, lr=0.05):
83
- """QR + iterative repulsion for even coverage beyond d anchors."""
84
  vecs = init_anchors_orthogonal(n, d)
85
  vecs = F.normalize(vecs, dim=-1)
86
  for _ in range(iters):
@@ -99,213 +118,330 @@ INIT_METHODS = {
99
 
100
 
101
  # ══════════════════════════════════════════════════════════════════
102
- # OUTPUT
103
  # ══════════════════════════════════════════════════════════════════
104
 
105
- @dataclass
106
- class ConstellationOutput:
107
- """Full output from constellation forward pass."""
108
- embedding: torch.Tensor # (B, pw_dim) β€” interpreted observation
109
- cosines: torch.Tensor # (B, N, A) or (B, N, A*phases)
110
- distances: torch.Tensor # (B, N, A) or (B, N, A*phases)
111
- nearest: torch.Tensor # (B, N) β€” collapsed anchor assignment
112
- directions: torch.Tensor # (B, N, D) β€” input directions on S^(D-1)
113
- tri_flat: torch.Tensor # (B, tri_dim) β€” flattened triangulation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  # ══════════════════════════════════════════════════════════════════
117
- # CONSTELLATION β€” observation + interpretation
118
  # ══════════════════════════════════════════════════════════════════
119
 
120
- class Constellation(nn.Module):
121
- """Geometric observer with anchor-aligned interpretation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- Anchors on S^(D-1) observe input directions via triangulation.
124
- Compartments interpret per-anchor observations.
125
- SLERP phases provide multi-scale angular measurement.
126
- All coupled through gradient flow.
 
 
 
 
 
 
127
 
128
  Args:
129
- n_anchors: reference directions on S^(D-1)
130
- dim: anchor/direction dimensionality
131
- n_directions: input directions per sample
 
132
  d_comp: hidden dim per compartment
133
- n_phases: SLERP interpolation phases (1=static, 3=proven default)
134
- anchor_init: 'xavier', 'orthogonal', or 'repulsion'
135
- anchor_dropout: fraction of anchors to drop during training (0.3 for soup)
136
- compartment: 'aligned' (one per anchor) or 'flat' (single patchwork)
 
137
  """
138
 
139
  def __init__(
140
  self,
141
- n_anchors: int,
142
- dim: int,
143
- n_directions: int,
144
- d_comp: int = 64,
145
- n_phases: int = 3,
146
- anchor_init: str = 'xavier',
147
- anchor_dropout: float = 0.0,
148
- compartment: str = 'aligned',
 
 
149
  ):
150
  super().__init__()
151
- self.n_anchors = n_anchors
152
  self.dim = dim
153
- self.n_directions = n_directions
154
- self.d_comp = d_comp
155
- self.n_phases = n_phases
156
- self.anchor_dropout = anchor_dropout
157
- self.compartment_type = compartment
158
 
159
- # Anchors: home (frozen) + current (learned)
160
- init_fn = INIT_METHODS[anchor_init]
161
- home = init_fn(n_anchors, dim)
162
- self.register_buffer('home', home)
163
- self.anchors = nn.Parameter(home.clone())
164
-
165
- # Triangulation dimensions
166
- if compartment == 'aligned':
167
- # tri: (B, N, A * phases) β†’ each compartment reads its anchor's column
168
- self.tri_dim = n_directions * n_anchors * n_phases
169
- self.embedding_dim = n_anchors * d_comp
170
-
171
- # One compartment per anchor β€” reads tri[:, :, k] across all phases
172
- # Input: n_directions * n_phases values per anchor
173
- comp_in = n_directions * n_phases
174
- self.compartments = nn.ModuleList([
175
- nn.Sequential(
176
- nn.Linear(comp_in, d_comp * 2),
177
- SquaredReLU(),
178
- nn.Linear(d_comp * 2, d_comp),
179
- nn.LayerNorm(d_comp),
180
- ) for _ in range(n_anchors)
181
- ])
182
- elif compartment == 'flat':
183
- # tri: (B, tri_dim) β†’ single patchwork MLP
184
- self.tri_dim = n_directions * n_anchors * n_phases
185
- self.embedding_dim = dim
186
-
187
- self.patchwork = nn.Sequential(
188
- nn.Linear(self.tri_dim, self.tri_dim * 2),
189
- SquaredReLU(),
190
- nn.LayerNorm(self.tri_dim * 2),
191
- nn.Linear(self.tri_dim * 2, dim),
192
- )
193
- else:
194
- raise ValueError(f"Unknown compartment type: {compartment}")
195
-
196
- self._init_weights()
197
-
198
- def _init_weights(self):
199
- for m in self.modules():
200
- if isinstance(m, nn.Linear):
201
- nn.init.trunc_normal_(m.weight, std=0.02)
202
- if m.bias is not None:
203
- nn.init.zeros_(m.bias)
204
- elif isinstance(m, nn.LayerNorm):
205
- nn.init.ones_(m.weight)
206
- nn.init.zeros_(m.bias)
207
-
208
- def drift(self):
209
- """Geodesic distance between home and learned anchor positions."""
210
- h = F.normalize(self.home.float(), dim=-1)
211
- c = F.normalize(self.anchors.float(), dim=-1)
212
- return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))
213
-
214
- def at_phase(self, t):
215
- """SLERP between home and learned positions at phase t ∈ [0, 1]."""
216
- h = F.normalize(self.home.float(), dim=-1)
217
- c = F.normalize(self.anchors.float(), dim=-1)
218
- omega = self.drift().unsqueeze(-1) # (A, 1)
219
- so = omega.sin().clamp(min=1e-6)
220
- return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c
221
-
222
- def _triangulate(self, directions, anchors):
223
- """(B, N, D) Γ— (A, D) β†’ (B, N, A) cosines and distances."""
224
- cos = torch.einsum('bnd,ad->bna', directions, anchors)
225
- return cos, 1.0 - cos
226
-
227
- def forward(self, directions: torch.Tensor) -> ConstellationOutput:
228
- """Observe and interpret.
229
 
230
  Args:
231
- directions: (B, N, D) β€” L2-normalized to S^(D-1)
232
 
233
  Returns:
234
- ConstellationOutput
235
  """
236
- B, N, D = directions.shape
237
-
238
- # Multi-phase triangulation
239
- phases = torch.linspace(0, 1, self.n_phases, device=directions.device).tolist()
240
- all_cos = []
241
- all_dist = []
242
- for t in phases:
243
- anchors_t = F.normalize(self.at_phase(t), dim=-1).to(directions.dtype)
244
-
245
- # Anchor dropout during training
246
- if self.training and self.anchor_dropout > 0:
247
- mask = torch.rand(anchors_t.shape[0], device=anchors_t.device) > self.anchor_dropout
248
- if mask.sum() < 2:
249
- mask[:2] = True
250
- anchors_t = anchors_t[mask]
251
-
252
- cos, dist = self._triangulate(directions, anchors_t)
253
- all_cos.append(cos)
254
- all_dist.append(dist)
255
-
256
- # Stack phases: (B, N, A*phases) if no dropout, variable if dropout
257
- cos_cat = torch.cat(all_cos, dim=-1)
258
- dist_cat = torch.cat(all_dist, dim=-1)
259
-
260
- # Nearest anchor (from phase 0, no dropout)
261
- anchors_0 = F.normalize(self.at_phase(0.0), dim=-1).to(directions.dtype)
262
- cos_0 = torch.einsum('bnd,ad->bna', directions, anchors_0)
263
- nearest = cos_0.max(dim=-1).indices
264
-
265
- # Interpret
266
- if self.compartment_type == 'aligned' and not (self.training and self.anchor_dropout > 0):
267
- # dist_cat: (B, N, A * n_phases)
268
- # Reshape to (B, N, n_phases, A) then (B, A, N * n_phases)
269
- A = self.n_anchors
270
- dist_reshape = dist_cat.reshape(B, N, self.n_phases, A)
271
- # For compartment k: gather distances to anchor k across all directions and phases
272
- # dist_reshape[:, :, :, k] β†’ (B, N, n_phases) β†’ flatten β†’ (B, N*n_phases)
273
- parts = []
274
- for k in range(A):
275
- comp_input = dist_reshape[:, :, :, k].reshape(B, N * self.n_phases)
276
- parts.append(self.compartments[k](comp_input))
277
- embedding = torch.cat(parts, dim=-1) # (B, A * d_comp)
278
- elif self.compartment_type == 'flat' or (self.training and self.anchor_dropout > 0):
279
- tri_flat = dist_cat.reshape(B, -1)
280
- if self.compartment_type == 'flat':
281
- embedding = self.patchwork(tri_flat)
282
- else:
283
- # Fallback for aligned + dropout: pad and use compartments
284
- # This is a training-only path
285
- embedding = torch.zeros(B, self.embedding_dim,
286
- device=directions.device, dtype=directions.dtype)
287
- # Use flat mean as fallback during dropout
288
- for k in range(self.n_anchors):
289
- comp_in_size = self.n_directions * self.n_phases
290
- if tri_flat.shape[1] >= comp_in_size:
291
- chunk = tri_flat[:, :comp_in_size]
292
- else:
293
- chunk = F.pad(tri_flat, (0, comp_in_size - tri_flat.shape[1]))
294
- embedding[:, k * self.d_comp:(k + 1) * self.d_comp] = self.compartments[k](chunk)
295
- else:
296
- tri_flat = dist_cat.reshape(B, -1)
297
- embedding = self.patchwork(tri_flat)
298
-
299
- tri_flat = dist_cat.reshape(B, -1)
300
-
301
- return ConstellationOutput(
302
- embedding=embedding,
303
- cosines=cos_cat,
304
- distances=dist_cat,
305
- nearest=nearest,
306
- directions=directions,
307
- tri_flat=tri_flat,
308
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
 
311
  # ══════════════════════════════════════════════════════════════════
@@ -313,67 +449,51 @@ class Constellation(nn.Module):
313
  # ══════════════════════════════════════════════════════════════════
314
 
315
  class ConstellationRelay(nn.Module):
316
- """Per-token geometric processing layer with gated residual.
317
 
318
- Replaces attention as a per-token processing layer.
319
- O(S) complexity. No cross-token interaction.
320
- Preserves 99.4% cosine similarity to input at depth 16.
321
 
322
  Pipeline:
323
- LayerNorm β†’ chunk D into patches β†’ L2 norm per patch
324
- β†’ Constellation observation + interpretation
325
- β†’ Project back to D β†’ gated residual
326
 
327
  Args:
328
- dim: token dimension (must be divisible by patch_dim)
329
- patch_dim: dimension per patch subspace (default 16)
330
- n_anchors: anchors per patch subspace
331
  d_comp: hidden dim per compartment
332
- n_phases: SLERP phases
333
- gate_init: initial gate bias (default -3.0 β†’ sigmoid β‰ˆ 0.047)
334
  anchor_init: initialization method
 
335
  """
336
 
337
  def __init__(
338
  self,
339
- dim: int,
340
- patch_dim: int = 16,
341
- n_anchors: int = 16,
342
- d_comp: int = 64,
343
- n_phases: int = 3,
344
- gate_init: float = -3.0,
345
- anchor_init: str = 'xavier',
346
  ):
347
  super().__init__()
348
- assert dim % patch_dim == 0
349
  self.dim = dim
350
- self.patch_dim = patch_dim
351
- self.n_patches = dim // patch_dim
352
-
353
  self.norm = nn.LayerNorm(dim)
354
 
355
- # Constellation operates on (B*S, n_patches, patch_dim)
356
  self.constellation = Constellation(
357
- n_anchors=n_anchors,
358
- dim=patch_dim,
359
- n_directions=self.n_patches,
360
- d_comp=d_comp,
361
- n_phases=n_phases,
362
- anchor_init=anchor_init,
363
- compartment='aligned',
364
- )
365
 
366
- # Project constellation embedding back to token dim
367
- self.proj = nn.Linear(self.constellation.embedding_dim, dim)
368
 
369
- # Gated residual β€” init at -3.0 so gate starts near 0
370
  self.gate = nn.Parameter(torch.full((dim,), gate_init))
371
 
372
- def forward(self, x: torch.Tensor) -> torch.Tensor:
373
- """
374
- x: (B, S, D) or (B, D)
375
- Returns: same shape as input
376
- """
377
  squeeze = False
378
  if x.dim() == 2:
379
  x = x.unsqueeze(1)
@@ -382,21 +502,14 @@ class ConstellationRelay(nn.Module):
382
  B, S, D = x.shape
383
  residual = x
384
 
385
- # Normalize
386
  h = self.norm(x)
387
-
388
- # Chunk into patches and normalize to S^(patch_dim-1)
389
- h_flat = h.reshape(B * S, self.n_patches, self.patch_dim)
390
  h_flat = F.normalize(h_flat, dim=-1)
391
 
392
- # Constellation: observe + interpret
393
- output = self.constellation(h_flat)
394
-
395
- # Project back to token dim
396
- update = self.proj(output.embedding) # (B*S, D)
397
- update = update.reshape(B, S, D)
398
 
399
- # Gated residual
400
  g = torch.sigmoid(self.gate)
401
  out = residual + g * update
402
 
@@ -406,11 +519,11 @@ class ConstellationRelay(nn.Module):
406
 
407
 
408
  # ══════════════════════════════════════════════════════════════════
409
- # GEOMETRIC OPS β€” measurement tools
410
  # ══════════════════════════════════════════════════════════════════
411
 
412
  class GeometricOps:
413
- """Static geometric utilities for constellation monitoring and loss."""
414
 
415
  @staticmethod
416
  def cayley_menger_vol2(points):
@@ -428,6 +541,7 @@ class GeometricOps:
428
  return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))
429
 
430
  @staticmethod
 
431
  def cv_metric(emb, n_samples=200, n_points=5):
432
  """Non-differentiable CV for monitoring. Target band: 0.20–0.23."""
433
  vols = []
@@ -442,35 +556,96 @@ class GeometricOps:
442
  return (vols_t.std() / (vols_t.mean() + 1e-8)).item()
443
 
444
  @staticmethod
445
- def cv_loss(emb, target=0.22, n_samples=100, n_points=5):
446
- """Differentiable CV loss. Weight: 0.001 or below."""
 
 
 
447
  vols = []
448
  for _ in range(n_samples):
449
- idx = torch.randperm(min(emb.shape[0], 512))[:n_points]
450
- v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))
451
- if v2[0] > 1e-20:
452
- vols.append(v2[0].sqrt())
 
 
 
 
 
 
 
 
 
 
453
  if len(vols) < 5:
454
  return torch.tensor(0.0, device=emb.device)
455
- vols_t = torch.stack(vols)
456
- cv = vols_t.std() / (vols_t.mean() + 1e-8)
457
  return (cv - target).pow(2)
458
 
459
  @staticmethod
460
  def anchor_spread_loss(anchors, target_cos=0.0):
461
- """Repulsion loss keeping anchors spread on the sphere."""
462
  a = F.normalize(anchors, dim=-1)
463
  sim = a @ a.T
464
  mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
465
  return F.relu(sim[mask] - target_cos).mean()
466
 
467
  @staticmethod
468
- def diagnostics(output: ConstellationOutput, n_anchors: int) -> dict:
469
- """Compute diagnostic metrics."""
470
- diag = {}
471
- diag['n_active'] = output.nearest.flatten().unique().numel()
472
- counts = torch.bincount(output.nearest.flatten(), minlength=n_anchors).float()
473
- diag['anchor_util_std'] = counts.std().item()
474
- diag['nearest_cos'] = output.cosines[:, :, :n_anchors].max(dim=-1).values.mean().item()
475
- diag['mean_tri'] = output.distances.mean().item()
476
- return diag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Constellation β€” Geometric Observer + Interpreter
3
+ ===================================================
4
+ Aligned to the proven GeoLIP Core trainer (91.2% CIFAR-10 @ 1.65M params).
5
+
6
+ Architecture:
7
+ emb @ anchors.T β†’ 64 distances β†’ 8 round-robin compartments β†’ cat(pw, emb) β†’ classifier
8
+
9
+ Key mechanisms:
10
+ - Round-robin compartments: 8 groups of 8 anchors, diverse measurements per group
11
+ - cat(patchwork, embedding): classifier sees both interpreted distances AND raw position
12
+ - Anchor push: direct centroid placement every N batches (self-distillation across time)
13
+ - Attraction loss: pulls embeddings toward nearest anchor
14
+ - InfoNCE on two views: alignment force
15
+ - Simple triangulation: emb @ anchors.T, no SLERP, no phases
16
+
17
+ Classes:
18
+ Constellation β€” triangulation against anchors on S^(d-1)
19
+ Patchwork β€” round-robin compartmentalized interpretation
20
+ ConstellationCore β€” full pipeline: constellation + patchwork + classifier
21
+ GeometricOps β€” CV, spread, Cayley-Menger utilities
22
+ GeometricAutograd β€” Form 12 manifold-aware gradient correction
 
 
23
 
24
  Usage:
25
+ from constellation import ConstellationCore
26
 
27
+ model = ConstellationCore(num_classes=10, dim=192, n_anchors=64)
28
+ out = model(images) # dict: logits, embedding, triangulation, nearest, patchwork
29
+ loss, ld = model.compute_loss(out, targets, output_aug=out2)
 
 
 
 
 
30
  """
31
 
32
  import torch
 
34
  import torch.nn.functional as F
35
  import math
36
  from dataclasses import dataclass
37
+ from typing import Optional, Dict, Any
38
 
39
 
40
  # ══════════════════════════════════════════════════════════════════
41
+ # ACTIVATIONS
42
  # ══════════════════════════════════════════════════════════════════
43
 
44
class SquaredReLU(nn.Module):
    """Squared ReLU: x -> max(x, 0) ** 2. Proven #1 in bulk activation tests."""

    def forward(self, x):
        # Zero out negatives, then square what remains.
        return F.relu(x).square()
48
 
49
 
50
class StarReLU(nn.Module):
    """StarReLU: ReLU(x)**2 * scale + bias with learnable scale and bias.

    Runner-up in bulk activation tests. Initial values (~0.8944, ~-0.4472)
    normalize the output distribution of the squared ReLU.
    """

    def __init__(self):
        super().__init__()
        # Learnable affine applied after the squared ReLU.
        self.scale = nn.Parameter(torch.full((1,), 0.8944))
        self.bias = nn.Parameter(torch.full((1,), -0.4472))

    def forward(self, x):
        return self.scale * F.relu(x) ** 2 + self.bias
58
+
59
# Registry of activation constructors selectable by name.
ACTIVATIONS = {
    'squared_relu': SquaredReLU,
    'star_relu': StarReLU,
    'gelu': nn.GELU,
    'relu': nn.ReLU,
    'sigmoid': nn.Sigmoid,
}


def make_activation(name='squared_relu'):
    """Instantiate an activation module by registry name.

    Args:
        name: key into ACTIVATIONS.

    Raises:
        ValueError: if *name* is not a registered activation.
    """
    try:
        factory = ACTIVATIONS[name]
    except KeyError:
        raise ValueError(
            f"Unknown activation '{name}'. Choose from: {list(ACTIVATIONS.keys())}"
        ) from None
    return factory()
74
+
75
+
76
  # ══════════════════════════════════════════════════════════════════
77
  # ANCHOR INITIALIZATION
78
  # ══════════════════════════════════════════════════════════════════
79
 
80
  def init_anchors_xavier(n, d):
81
+ """Xavier normal β†’ normalize. Near-orthogonal in high-d."""
82
  w = torch.empty(n, d)
83
  nn.init.xavier_normal_(w)
84
  return F.normalize(w, dim=-1)
85
 
86
 
87
  def init_anchors_orthogonal(n, d):
88
+ """QR decomposition β†’ exact orthonormal basis when n <= d."""
89
  if n <= d:
90
  M = torch.randn(d, n)
91
  Q, _ = torch.linalg.qr(M)
 
99
 
100
 
101
  def init_anchors_repulsion(n, d, iters=200, lr=0.05):
102
+ """QR + iterative repulsion for even coverage. Used in proven Core."""
103
  vecs = init_anchors_orthogonal(n, d)
104
  vecs = F.normalize(vecs, dim=-1)
105
  for _ in range(iters):
 
118
 
119
 
120
  # ══════════════════════════════════════════════════════════════════
121
+ # CONSTELLATION β€” triangulation on S^(d-1)
122
  # ══════════════════════════════════════════════════════════════════
123
 
124
class Constellation(nn.Module):
    """Anchor set on S^(d-1) that triangulates input embeddings.

    Simple observation: emb @ anchors.T -> cosines -> angular distances.
    No SLERP, no phases, no home/learned split.

    Args:
        n_anchors: number of reference points on S^(d-1)
        dim: dimensionality of the sphere
        anchor_drop: fraction of anchors to drop during training (0.15 proven)
        anchor_init: 'repulsion', 'xavier', or 'orthogonal'
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        # Learned anchor directions, seeded by the chosen init scheme.
        self.anchors = nn.Parameter(INIT_METHODS[anchor_init](n_anchors, dim))
        self.anchor_drop = anchor_drop
        self.n_anchors = n_anchors
        self.dim = dim

    def triangulate(self, emb, training=False):
        """Measure *emb* against the anchor set.

        Args:
            emb: (B, D) L2-normalized embeddings.
            training: when True and anchor_drop > 0, a random subset of
                anchors is dropped (at least two are always kept).

        Returns:
            (tri, nearest): tri is (B, A) angular distances (1 - cos),
            where A is the number of surviving anchors; nearest is (B,)
            with GLOBAL anchor indices, even under dropout.
        """
        unit = F.normalize(self.anchors, dim=-1)

        if not (training and self.anchor_drop > 0):
            cos = emb @ unit.T
            return 1.0 - cos, cos.argmax(dim=-1)

        # Anchor dropout: keep a random subset, never fewer than two.
        keep = torch.rand(unit.shape[0], device=unit.device) > self.anchor_drop
        if keep.sum() < 2:
            keep[:2] = True
        cos = emb @ unit[keep].T
        # Translate the local argmax back to global anchor indices.
        nearest = keep.nonzero(as_tuple=True)[0][cos.argmax(dim=-1)]
        return 1.0 - cos, nearest

    def forward(self, emb, training=False):
        return self.triangulate(emb, training=training)
171
 
172
 
173
  # ══════════════════════════════════════════════════════════════════
174
+ # PATCHWORK β€” round-robin compartmentalized interpretation
175
  # ══════════════════════════════════════════════════════════════════
176
 
177
class Patchwork(nn.Module):
    """Round-robin compartmentalized interpretation of triangulation rows.

    With 64 anchors and 8 compartments, each compartment reads 8 anchors.
    Assignment: anchor k goes to compartment (k % n_comp), so every
    compartment sees an interleaved, diverse subset of the anchors.
    Each compartment: Linear(anchors_per, d_comp*2) -> act -> Linear -> LN.

    Args:
        n_anchors: total anchors (must be divisible by n_comp)
        n_comp: number of compartments
        d_comp: output dim per compartment
        activation: activation function name (see ACTIVATIONS)

    Raises:
        ValueError: if n_anchors is not divisible by n_comp — previously
            this surfaced only as a confusing shape mismatch inside the
            compartment Linear layers at forward time.
    """

    def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):
        super().__init__()
        if n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be divisible by n_comp ({n_comp})")
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.output_dim = n_comp * d_comp

        # Round-robin assignment: anchor k -> compartment (k % n_comp).
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp

        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2),
                make_activation(activation),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ) for _ in range(n_comp)
        ])

    def forward(self, tri):
        """tri: (B, n_anchors) angular distances -> (B, n_comp * d_comp)."""
        return torch.cat([
            self.comps[k](tri[:, self.asgn == k])
            for k in range(self.n_comp)
        ], dim=-1)
216
 
217
+
218
+ # ══════════════════════════════════════════════════════════════════
219
+ # CONSTELLATION CORE β€” full pipeline
220
+ # ══════════════════════════════════════════════════════════════════
221
+
222
class ConstellationCore(nn.Module):
    """Constellation + Patchwork + Classifier.

    Forward returns a dict with all outputs for downstream consumers.
    The classifier reads cat(patchwork, embedding), so it sees both the
    interpreted distances AND the raw position on the sphere.

    Args:
        num_classes: classification targets
        dim: embedding dimension (encoder output)
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        anchor_drop: training dropout rate for anchors
        anchor_init: initialization method
        activation: activation for patchwork compartments
        cv_target: target CV for geometric loss
        infonce_temp: temperature for InfoNCE
    """

    def __init__(
        self,
        num_classes=10,
        dim=192,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        anchor_init='repulsion',
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Snapshot constructor arguments for checkpointing/reproduction.
        # NOTE: taken before any other local is bound, so it holds exactly
        # the constructor kwargs.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        self.constellation = Constellation(
            n_anchors, dim, anchor_drop, anchor_init)

        self.patchwork = Patchwork(
            n_anchors, n_comp, d_comp, activation)

        pw_dim = self.patchwork.output_dim

        # Classifier reads cat(patchwork, embedding).
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + dim, pw_dim),
            make_activation(activation),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes),
        )

    def forward(self, emb_normalized):
        """Forward pass on L2-normalized embeddings.

        Args:
            emb_normalized: (B, D) already on S^(d-1)

        Returns:
            dict with: logits, embedding, triangulation, nearest, patchwork
        """
        emb = emb_normalized

        # Full (no-dropout) triangulation feeds the patchwork.
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # Under training, recompute nearest WITH anchor dropout — used for
        # nearest-anchor tracking only; the patchwork still sees all anchors.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        # Classifier sees BOTH patchwork interpretation AND raw position.
        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
            'patchwork': pw,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Compute all losses.

        Args:
            output: dict from forward()
            targets: (B,) class indices
            output_aug: optional dict from forward() on a second view

        Returns:
            (total_loss, loss_dict)
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # Cross-entropy classification.
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE between augmented views (alignment force).
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            ld['nce'] = F.cross_entropy(sim, labels_nce)
            ld['nce_acc'] = (sim.argmax(1) == labels_nce).float().mean().item()

        # Anchor attraction: pull each embedding toward its nearest anchor.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        nearest_cos = (emb @ anchors_n.T).max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # CV regularizer on the embeddings.
        l_cv = GeometricOps.cv_loss(emb, target=self.cv_target)
        ld['cv'] = l_cv

        # Keep anchors spread on the sphere.
        l_spread = GeometricOps.anchor_spread_loss(self.constellation.anchors)
        ld['spread'] = l_spread

        # Weighted total.
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Push anchors toward class centroids — self-distillation across time.

        Phase 1: compute class centroids from labels.
        Phase 2: greedy-assign anchors to classes (round-robin capacity).
        Phase 3: blend each anchor toward its class centroid, with a
        perpendicular perturbation so co-class anchors don't collapse.

        Args:
            emb_buffer: (N, D) accumulated embeddings
            label_buffer: (N,) class labels
            lr: blend rate toward centroid

        Returns:
            number of anchors moved
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)
        device = anchors.device

        # --- Phase 1: per-class centroids on the sphere ---
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            sel = label_buffer == c
            if sel.sum() > 0:
                centroids.append(
                    F.normalize(emb_n[sel].mean(0, keepdim=True), dim=-1))
        if not centroids:
            return 0
        centroids = torch.cat(centroids, dim=0)

        # --- Phase 2: greedy anchor-to-class assignment by affinity ---
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        anchors_per_class = n_a // n_cls
        assigned_class = torch.full((n_a,), -1, dtype=torch.long, device=device)
        class_count = torch.zeros(n_cls, dtype=torch.long, device=device)

        # Walk (anchor, class) pairs from highest cosine down; each anchor
        # takes the best class that still has capacity (+1 slack per class).
        _, flat_idx = cos.flatten().sort(descending=True)
        for idx in flat_idx:
            a = (idx // n_cls).item()
            c = (idx % n_cls).item()
            if assigned_class[a] >= 0:
                continue
            if class_count[c] >= anchors_per_class + 1:
                continue
            assigned_class[a] = c
            class_count[c] += 1
            if (assigned_class >= 0).all():
                break

        # Any leftovers just take their best class, ignoring capacity.
        unassigned = (assigned_class < 0).nonzero(as_tuple=True)[0]
        if len(unassigned) > 0:
            leftover_cos = anchors_n[unassigned] @ centroids.T
            assigned_class[unassigned] = leftover_cos.argmax(dim=1)

        # --- Phase 3: blend toward target with perpendicular jitter ---
        moved = 0
        for a in range(n_a):
            c = assigned_class[a].item()
            target = centroids[c]

            # All but the first anchor of a class get jittered so that
            # co-class anchors don't converge to the same point.
            rank_in_class = (assigned_class[:a] == c).sum().item()
            if anchors_per_class > 1 and rank_in_class > 0:
                noise = torch.randn_like(target) * 0.05
                noise = noise - (noise * target).sum() * target
                target = F.normalize(
                    (target + noise).unsqueeze(0), dim=-1).squeeze(0)

            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1

        return moved
445
 
446
 
447
  # ══════════════════════════════════════════════════════════════════
 
449
  # ══════════════════════════════════════════════════════════════════
450
 
451
  class ConstellationRelay(nn.Module):
452
+ """Per-token geometric processing with gated residual.
453
 
454
+ O(S) complexity. Preserves 99.4% cos similarity at depth 16.
 
 
455
 
456
  Pipeline:
457
+ LayerNorm β†’ L2 normalize β†’ triangulate β†’ patchwork β†’ project β†’ gated residual
 
 
458
 
459
  Args:
460
+ dim: token dimension
461
+ n_anchors: anchors on S^(dim-1)
462
+ n_comp: patchwork compartments
463
  d_comp: hidden dim per compartment
464
+ gate_init: initial gate bias (-3.0 β†’ sigmoid β‰ˆ 0.047)
 
465
  anchor_init: initialization method
466
+ activation: activation function name
467
  """
468
 
469
  def __init__(
470
  self,
471
+ dim,
472
+ n_anchors=16,
473
+ n_comp=8,
474
+ d_comp=64,
475
+ gate_init=-3.0,
476
+ anchor_init='repulsion',
477
+ activation='squared_relu',
478
  ):
479
  super().__init__()
 
480
  self.dim = dim
 
 
 
481
  self.norm = nn.LayerNorm(dim)
482
 
 
483
  self.constellation = Constellation(
484
+ n_anchors, dim, anchor_init=anchor_init)
485
+
486
+ self.patchwork = Patchwork(
487
+ n_anchors, n_comp, d_comp, activation)
 
 
 
 
488
 
489
+ # Project patchwork back to token dim
490
+ self.proj = nn.Linear(self.patchwork.output_dim, dim)
491
 
492
+ # Gated residual
493
  self.gate = nn.Parameter(torch.full((dim,), gate_init))
494
 
495
+ def forward(self, x):
496
+ """x: (B, S, D) or (B, D) β†’ same shape."""
 
 
 
497
  squeeze = False
498
  if x.dim() == 2:
499
  x = x.unsqueeze(1)
 
502
  B, S, D = x.shape
503
  residual = x
504
 
 
505
  h = self.norm(x)
506
+ h_flat = h.reshape(B * S, D)
 
 
507
  h_flat = F.normalize(h_flat, dim=-1)
508
 
509
+ tri, _ = self.constellation.triangulate(h_flat)
510
+ pw = self.patchwork(tri)
511
+ update = self.proj(pw).reshape(B, S, D)
 
 
 
512
 
 
513
  g = torch.sigmoid(self.gate)
514
  out = residual + g * update
515
 
 
519
 
520
 
521
  # ══════════════════════════════════════════════════════════════════
522
+ # GEOMETRIC OPS
523
  # ══════════════════════════════════════════════════════════════════
524
 
525
  class GeometricOps:
526
+ """Static geometric utilities."""
527
 
528
  @staticmethod
529
  def cayley_menger_vol2(points):
 
541
  return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))
542
 
543
  @staticmethod
544
+ @torch.no_grad()
545
  def cv_metric(emb, n_samples=200, n_points=5):
546
  """Non-differentiable CV for monitoring. Target band: 0.20–0.23."""
547
  vols = []
 
556
  return (vols_t.std() / (vols_t.mean() + 1e-8)).item()
557
 
558
  @staticmethod
559
+ def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
560
+ """Differentiable CV loss. Weight: 0.01 or below."""
561
+ B = emb.shape[0]
562
+ if B < n_points:
563
+ return torch.tensor(0.0, device=emb.device)
564
  vols = []
565
  for _ in range(n_samples):
566
+ idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
567
+ pts = emb[idx].unsqueeze(0)
568
+ gram = torch.bmm(pts, pts.transpose(1, 2))
569
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
570
+ d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
571
+ d2 = F.relu(d2)
572
+ N = n_points
573
+ cm = torch.zeros(1, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
574
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
575
+ k = N - 1
576
+ pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
577
+ v2 = pf * torch.linalg.det(cm.float())
578
+ if v2[0].item() > 1e-20:
579
+ vols.append(v2[0].to(emb.dtype).sqrt())
580
  if len(vols) < 5:
581
  return torch.tensor(0.0, device=emb.device)
582
+ vt = torch.stack(vols)
583
+ cv = vt.std() / (vt.mean() + 1e-8)
584
  return (cv - target).pow(2)
585
 
586
  @staticmethod
587
  def anchor_spread_loss(anchors, target_cos=0.0):
588
+ """Repulsion loss keeping anchors spread."""
589
  a = F.normalize(anchors, dim=-1)
590
  sim = a @ a.T
591
  mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
592
  return F.relu(sim[mask] - target_cos).mean()
593
 
594
  @staticmethod
595
+ def diagnostics(constellation, emb):
596
+ """Compute health metrics from a constellation and embeddings."""
597
+ tri, nearest = constellation.triangulate(emb, training=False)
598
+ n_active = nearest.unique().numel()
599
+ anchors_n = F.normalize(constellation.anchors, dim=-1)
600
+ cos_to_anchors = emb @ anchors_n.T
601
+ nearest_cos = cos_to_anchors.max(dim=1).values.mean().item()
602
+ counts = torch.bincount(nearest, minlength=constellation.n_anchors).float()
603
+ return {
604
+ 'n_active': n_active,
605
+ 'nearest_cos': nearest_cos,
606
+ 'anchor_util_std': counts.std().item(),
607
+ 'anchor_util_min': counts.min().item(),
608
+ 'anchor_util_max': counts.max().item(),
609
+ }
610
+
611
+
612
+ # ══════════════════════════════════════════════════════════════════
613
+ # GEOMETRIC AUTOGRAD β€” Form 12
614
+ # ══════════════════════════════════════════════════════════════════
615
+
616
class GeometricAutograd(torch.autograd.Function):
    """Manifold-aware gradient correction on S^(D-1).

    Forward: identity.
    Backward: tangential projection + separation from nearest anchor.

    Proven settings: tang=0.01, sep=1.0
    """

    @staticmethod
    def forward(ctx, emb, anchors, tang_strength, sep_strength):
        # Identity pass-through; stash what backward needs.
        ctx.save_for_backward(emb, anchors)
        ctx.tang = tang_strength
        ctx.sep = sep_strength
        return emb

    @staticmethod
    def backward(ctx, grad):
        emb, anchors = ctx.saved_tensors

        # Split the incoming gradient into radial and tangential parts
        # w.r.t. each embedding, then attenuate the radial part by tang.
        radial_mag = (grad * emb).sum(dim=-1, keepdim=True)
        radial_part = radial_mag * emb
        corrected = (grad - radial_part) + (1.0 - ctx.tang) * radial_part

        # Subtract the positive component of the corrected gradient
        # along each embedding's nearest-anchor direction, scaled by sep.
        if ctx.sep > 0:
            unit_anchors = F.normalize(anchors.detach(), dim=-1)
            nearest_idx = (emb @ unit_anchors.T).argmax(dim=-1)
            nearest = unit_anchors[nearest_idx]
            pull = (corrected * nearest).sum(dim=-1, keepdim=True)
            corrected = corrected - ctx.sep * F.relu(pull) * nearest

        # No gradients for anchors or the two strength scalars.
        return corrected, None, None, None