AbstractPhil commited on
Commit
7e37f31
·
verified ·
1 Parent(s): cd017a6

Create cell2_model_v10.py

Browse files
Files changed (1) hide show
  1. cell2_model_v10.py +350 -0
cell2_model_v10.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Superposition Patch Classifier - Two-Tier Gated Transformer
3
+ =============================================================
4
+ Colab Cell 2 of 3 - depends on Cell 1 (generator.py) namespace.
5
+
6
+ Architecture:
7
+ voxels → patch_embed → e₀
8
+
9
+ Stage 0 (local gates): From raw embeddings, no attention
10
+ e₀ → local_dim_head → dim_soft ─┐
11
+ e₀ → local_curv_head → curv_soft ─┤ LOCAL_GATE_DIM = 11
12
+ e₀ → local_bound_head → bound_soft ─┤
13
+ e₀ → local_axis_head → axis_soft ─┘→ local_gates (detached)
14
+
15
+ Stage 1 (bootstrap): Attention sees local gates
16
+ proj([e₀, local_gates]) → bootstrap_block × N → h
17
+
18
+ Stage 1.5 (structural gates): From h, after cross-patch context
19
+ h → struct_topo_head → topo_soft ─┐
20
+ h → struct_neighbor_head → neighbor_soft ─┤ STRUCTURAL_GATE_DIM = 6
21
+ h → struct_role_head → role_soft ─┘→ structural_gates (detached)
22
+
23
+ Stage 2 (geometric routing): Both gate tiers
24
+ (h, local_gates, structural_gates) → geometric_block × N → h'
25
+
26
+ Stage 3 (classification): Gated shape heads
27
+ [h', local_gates, structural_gates] → shape_heads
28
+ """
29
+
30
+ import math
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ # Cell 1 provides: all constants including LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM
36
+
37
+
38
+ # === Patch Embedding ==========================================================
39
+
40
+ class PatchEmbedding3D(nn.Module):
41
+ def __init__(self, patch_dim=64):
42
+ super().__init__()
43
+ self.proj = nn.Linear(PATCH_VOL, patch_dim)
44
+ pz = torch.arange(MACRO_Z).float() / MACRO_Z
45
+ py = torch.arange(MACRO_Y).float() / MACRO_Y
46
+ px = torch.arange(MACRO_X).float() / MACRO_X
47
+ pos = torch.stack(torch.meshgrid(pz, py, px, indexing='ij'), dim=-1).reshape(MACRO_N, 3)
48
+ self.register_buffer('pos_embed', pos)
49
+ self.pos_proj = nn.Linear(3, patch_dim)
50
+
51
+ def forward(self, x):
52
+ B = x.shape[0]
53
+ patches = x.view(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X)
54
+ patches = patches.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(B, MACRO_N, PATCH_VOL)
55
+ return self.proj(patches) + self.pos_proj(self.pos_embed)
56
+
57
+
58
+ # === Standard Transformer Block ===============================================
59
+
60
+ class TransformerBlock(nn.Module):
61
+ def __init__(self, dim, n_heads, dropout=0.1):
62
+ super().__init__()
63
+ self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
64
+ self.ff = nn.Sequential(
65
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
66
+ nn.Linear(dim * 4, dim), nn.Dropout(dropout)
67
+ )
68
+ self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
69
+
70
+ def forward(self, x):
71
+ x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
72
+ return x + self.ff(self.ln2(x))
73
+
74
+
75
+ # === Geometric Gated Attention ================================================
76
+
77
+ class GatedGeometricAttention(nn.Module):
78
+ """
79
+ Multi-head attention with two-tier gate modulation.
80
+ Q, K see both local and structural gates.
81
+ V modulated by combined gate vector.
82
+ Per-head compatibility bias from gate interactions.
83
+ """
84
+
85
+ def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1):
86
+ super().__init__()
87
+ self.embed_dim = embed_dim
88
+ self.n_heads = n_heads
89
+ self.head_dim = embed_dim // n_heads
90
+
91
+ # Q, K from [h, all_gates]
92
+ self.q_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
93
+ self.k_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
94
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
95
+
96
+ # Per-head gate compatibility
97
+ self.gate_q = nn.Linear(gate_dim, n_heads)
98
+ self.gate_k = nn.Linear(gate_dim, n_heads)
99
+
100
+ # Value modulation by gates
101
+ self.v_gate = nn.Sequential(nn.Linear(gate_dim, embed_dim), nn.Sigmoid())
102
+
103
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
104
+ self.attn_drop = nn.Dropout(dropout)
105
+ self.scale = math.sqrt(self.head_dim)
106
+
107
+ def forward(self, h, gate_features):
108
+ B, N, _ = h.shape
109
+ hg = torch.cat([h, gate_features], dim=-1)
110
+ Q = self.q_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
111
+ K = self.k_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
112
+
113
+ V = self.v_proj(h)
114
+ V = (V * self.v_gate(gate_features)).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
115
+
116
+ content_scores = (Q @ K.transpose(-2, -1)) / self.scale
117
+ gq = self.gate_q(gate_features)
118
+ gk = self.gate_k(gate_features)
119
+ compat = torch.einsum('bih,bjh->bhij', gq, gk)
120
+
121
+ attn = F.softmax(content_scores + compat, dim=-1)
122
+ attn = self.attn_drop(attn)
123
+
124
+ out = (attn @ V).transpose(1, 2).reshape(B, N, self.embed_dim)
125
+ return self.out_proj(out)
126
+
127
+
128
+ class GeometricTransformerBlock(nn.Module):
129
+ def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1, ff_mult=4):
130
+ super().__init__()
131
+ self.ln1 = nn.LayerNorm(embed_dim)
132
+ self.attn = GatedGeometricAttention(embed_dim, gate_dim, n_heads, dropout)
133
+ self.ln2 = nn.LayerNorm(embed_dim)
134
+ self.ff = nn.Sequential(
135
+ nn.Linear(embed_dim, embed_dim * ff_mult), nn.GELU(), nn.Dropout(dropout),
136
+ nn.Linear(embed_dim * ff_mult, embed_dim), nn.Dropout(dropout)
137
+ )
138
+
139
+ def forward(self, h, gate_features):
140
+ h = h + self.attn(self.ln1(h), gate_features)
141
+ h = h + self.ff(self.ln2(h))
142
+ return h
143
+
144
+
145
+ # === Main Classifier ==========================================================
146
+
147
+ class SuperpositionPatchClassifier(nn.Module):
148
+ """
149
+ Two-tier gated transformer for multi-shape superposition.
150
+
151
+ Tier 1 (local): Gates from raw patch embeddings — what IS in this patch
152
+ Tier 2 (structural): Gates from post-attention h — what ROLE this patch plays
153
+
154
+ Both tiers feed into geometric attention and classification.
155
+ """
156
+
157
+ def __init__(self, embed_dim=128, patch_dim=64, n_bootstrap=2, n_geometric=2,
158
+ n_heads=4, dropout=0.1):
159
+ super().__init__()
160
+ self.embed_dim = embed_dim
161
+
162
+ # Patch embedding
163
+ self.patch_embed = PatchEmbedding3D(patch_dim)
164
+
165
+ # === Stage 0: Local encoder + gate heads (pre-attention) ===
166
+ # Shared MLP gives local heads enough capacity to extract
167
+ # dims/curvature/boundary from 32 voxels without cross-patch info
168
+ local_hidden = patch_dim * 2 # 128
169
+ self.local_encoder = nn.Sequential(
170
+ nn.Linear(patch_dim, local_hidden), nn.GELU(), nn.Dropout(dropout),
171
+ nn.Linear(local_hidden, local_hidden), nn.GELU(), nn.Dropout(dropout),
172
+ )
173
+ self.local_dim_head = nn.Linear(local_hidden, NUM_LOCAL_DIMS)
174
+ self.local_curv_head = nn.Linear(local_hidden, NUM_LOCAL_CURVS)
175
+ self.local_bound_head = nn.Linear(local_hidden, NUM_LOCAL_BOUNDARY)
176
+ self.local_axis_head = nn.Linear(local_hidden, NUM_LOCAL_AXES)
177
+
178
+ # Project [embedding, local_gates] → embed_dim for bootstrap
179
+ self.proj = nn.Linear(patch_dim + LOCAL_GATE_DIM, embed_dim)
180
+
181
+ # === Stage 1: Bootstrap blocks (attention with local gate context) ===
182
+ self.bootstrap_blocks = nn.ModuleList([
183
+ TransformerBlock(embed_dim, n_heads, dropout)
184
+ for _ in range(n_bootstrap)
185
+ ])
186
+
187
+ # === Stage 1.5: Structural gate heads (from h, post-attention) ===
188
+ self.struct_topo_head = nn.Linear(embed_dim, NUM_STRUCT_TOPO)
189
+ self.struct_neighbor_head = nn.Linear(embed_dim, NUM_STRUCT_NEIGHBOR)
190
+ self.struct_role_head = nn.Linear(embed_dim, NUM_STRUCT_ROLE)
191
+
192
+ # === Stage 2: Geometric gated blocks (see both gate tiers) ===
193
+ self.geometric_blocks = nn.ModuleList([
194
+ GeometricTransformerBlock(embed_dim, TOTAL_GATE_DIM, n_heads, dropout)
195
+ for _ in range(n_geometric)
196
+ ])
197
+
198
+ # === Stage 3: Gated classification ===
199
+ gated_dim = embed_dim + TOTAL_GATE_DIM
200
+
201
+ self.patch_shape_head = nn.Sequential(
202
+ nn.Linear(gated_dim, embed_dim), nn.GELU(), nn.Dropout(dropout),
203
+ nn.Linear(embed_dim, NUM_CLASSES)
204
+ )
205
+
206
+ self.global_pool = nn.Sequential(
207
+ nn.Linear(gated_dim, embed_dim), nn.GELU(),
208
+ nn.Linear(embed_dim, embed_dim)
209
+ )
210
+ self.global_gate_head = nn.Linear(embed_dim, NUM_GATES)
211
+ self.global_shape_head = nn.Linear(embed_dim, NUM_CLASSES)
212
+
213
+ def forward(self, x):
214
+ # === Raw patch embedding ===
215
+ e = self.patch_embed(x) # (B, 64, patch_dim)
216
+
217
+ # === Stage 0: Local gates from raw embedding via local encoder ===
218
+ e_local = self.local_encoder(e) # (B, 64, local_hidden)
219
+ local_dim_logits = self.local_dim_head(e_local)
220
+ local_curv_logits = self.local_curv_head(e_local)
221
+ local_bound_logits = self.local_bound_head(e_local)
222
+ local_axis_logits = self.local_axis_head(e_local)
223
+
224
+ local_gates = torch.cat([
225
+ F.softmax(local_dim_logits, dim=-1),
226
+ F.softmax(local_curv_logits, dim=-1),
227
+ torch.sigmoid(local_bound_logits),
228
+ torch.sigmoid(local_axis_logits),
229
+ ], dim=-1) # (B, 64, 11)
230
+
231
+ # === Stage 1: Bootstrap with local gate context ===
232
+ h = self.proj(torch.cat([e, local_gates], dim=-1))
233
+ for blk in self.bootstrap_blocks:
234
+ h = blk(h)
235
+
236
+ # === Stage 1.5: Structural gates from h (after cross-patch context) ===
237
+ struct_topo_logits = self.struct_topo_head(h)
238
+ struct_neighbor_logits = self.struct_neighbor_head(h)
239
+ struct_role_logits = self.struct_role_head(h)
240
+
241
+ structural_gates = torch.cat([
242
+ F.softmax(struct_topo_logits, dim=-1),
243
+ torch.sigmoid(struct_neighbor_logits),
244
+ F.softmax(struct_role_logits, dim=-1),
245
+ ], dim=-1) # (B, 64, 6)
246
+
247
+ # === Combined gate vector ===
248
+ all_gates = torch.cat([local_gates, structural_gates], dim=-1) # (B, 64, 17)
249
+
250
+ # === Stage 2: Geometric gated transformer ===
251
+ for blk in self.geometric_blocks:
252
+ h = blk(h, all_gates)
253
+
254
+ # === Stage 3: Classification from gated representations ===
255
+ h_gated = torch.cat([h, all_gates], dim=-1)
256
+ shape_logits = self.patch_shape_head(h_gated)
257
+ g = self.global_pool(h_gated.mean(dim=1))
258
+
259
+ return {
260
+ # Local gate predictions (Stage 0)
261
+ "local_dim_logits": local_dim_logits,
262
+ "local_curv_logits": local_curv_logits,
263
+ "local_bound_logits": local_bound_logits,
264
+ "local_axis_logits": local_axis_logits,
265
+
266
+ # Structural gate predictions (Stage 1.5)
267
+ "struct_topo_logits": struct_topo_logits,
268
+ "struct_neighbor_logits": struct_neighbor_logits,
269
+ "struct_role_logits": struct_role_logits,
270
+
271
+ # Shape predictions (Stage 3)
272
+ "patch_shape_logits": shape_logits,
273
+ "patch_features": h,
274
+ "global_features": g,
275
+ "global_gates": self.global_gate_head(g),
276
+ "global_shapes": self.global_shape_head(g),
277
+ }
278
+
279
+
280
+ # === Loss =====================================================================
281
+
282
+ class SuperpositionLoss(nn.Module):
283
+ def __init__(self, local_weight=1.0, struct_weight=1.0, shape_weight=1.0, global_weight=0.5):
284
+ super().__init__()
285
+ self.lw, self.sw, self.shw, self.gw = local_weight, struct_weight, shape_weight, global_weight
286
+
287
+ def forward(self, outputs, targets):
288
+ occ_mask = targets["patch_occupancy"] > 0.01
289
+ n_occ = occ_mask.sum().clamp(min=1)
290
+
291
+ # --- Local gate losses ---
292
+ dim_loss = F.cross_entropy(
293
+ outputs["local_dim_logits"].view(-1, NUM_LOCAL_DIMS),
294
+ targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1).view(-1),
295
+ reduction='none').view_as(occ_mask)
296
+ curv_loss = F.cross_entropy(
297
+ outputs["local_curv_logits"].view(-1, NUM_LOCAL_CURVS),
298
+ targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1).view(-1),
299
+ reduction='none').view_as(occ_mask)
300
+ bound_loss = F.binary_cross_entropy_with_logits(
301
+ outputs["local_bound_logits"].squeeze(-1),
302
+ targets["patch_boundary"],
303
+ reduction='none')
304
+ axis_loss = F.binary_cross_entropy_with_logits(
305
+ outputs["local_axis_logits"],
306
+ targets["patch_axis_active"],
307
+ reduction='none').mean(dim=-1)
308
+
309
+ local_loss = ((dim_loss + curv_loss + bound_loss + axis_loss) * occ_mask.float()).sum() / n_occ
310
+
311
+ # --- Structural gate losses ---
312
+ topo_loss = F.cross_entropy(
313
+ outputs["struct_topo_logits"].view(-1, NUM_STRUCT_TOPO),
314
+ targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1).view(-1),
315
+ reduction='none').view_as(occ_mask)
316
+ neighbor_loss = F.mse_loss(
317
+ torch.sigmoid(outputs["struct_neighbor_logits"].squeeze(-1)),
318
+ targets["patch_neighbor_count"],
319
+ reduction='none')
320
+ role_loss = F.cross_entropy(
321
+ outputs["struct_role_logits"].view(-1, NUM_STRUCT_ROLE),
322
+ targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1).view(-1),
323
+ reduction='none').view_as(occ_mask)
324
+
325
+ struct_loss = ((topo_loss + neighbor_loss + role_loss) * occ_mask.float()).sum() / n_occ
326
+
327
+ # --- Shape losses ---
328
+ shape_loss = F.binary_cross_entropy_with_logits(
329
+ outputs["patch_shape_logits"],
330
+ targets["patch_shape_membership"],
331
+ reduction='none').mean(dim=-1)
332
+ shape_loss = (shape_loss * occ_mask.float()).sum() / n_occ
333
+
334
+ # --- Global losses ---
335
+ global_gate_loss = F.binary_cross_entropy_with_logits(outputs["global_gates"], targets["global_gates"])
336
+ global_shape_loss = F.binary_cross_entropy_with_logits(outputs["global_shapes"], targets["global_shapes"])
337
+ global_loss = global_gate_loss + global_shape_loss
338
+
339
+ total = self.lw * local_loss + self.sw * struct_loss + self.shw * shape_loss + self.gw * global_loss
340
+
341
+ return {
342
+ "total": total,
343
+ "local": local_loss,
344
+ "struct": struct_loss,
345
+ "shape": shape_loss,
346
+ "global": global_loss,
347
+ }
348
+
349
+
350
+ print("✓ Model ready (Two-Tier Gated Transformer)")