Jonttup commited on
Commit
2f9ad67
·
verified ·
1 Parent(s): 40c6c49

Upload models/edit_classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/edit_classifier.py +460 -0
models/edit_classifier.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 31-Class Edit Operation Classifier — Neuroswarm Tier 2 Verification Engine
3
+
4
+ Verification stack:
5
+ Tier 1: 33-dim profile cosine similarity (nanoseconds, GPU)
6
+ Tier 2: THIS — edit classifier inference (milliseconds, GPU)
7
+ Tier 3: LLM review (seconds, API call, costs tokens)
8
+
9
+ Pipeline:
10
+ (before_hsl, after_hsl) each (B, H, W, 3)
11
+ → Circular hue encoding: h → (sin(2πh), cos(2πh)), stack with S,L → 4D
12
+ → HSLFeatureExtractor (ViT spatial features)
13
+ → HybridRegionPooler (DETR-style learned queries, no scope markers)
14
+ → Delta computation + fusion
15
+ → Concat: [global_feat, profile_delta_33, oklab_magnitude_1]
16
+ → Hierarchical classifier: level (3) → op (31)
17
+
18
+ Fixes over v1:
19
+ 1. Circular hue encoding (HSLFeatureExtractor) — hue wraparound correct
20
+ 2. HybridRegionPooler — DETR learned queries with iterative refinement
21
+ 3. 33-dim profile delta conditioning — structural direction signal
22
+ 4. OKLab delta magnitude — perceptual change size signal
23
+ """
24
+
25
+ import math
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from typing import Optional, Tuple, Dict, List
30
+
31
+ from .edit_ops import TRAINABLE_OPS, NUM_OPS, OP_TO_IDX, IDX_TO_OP, OpCode, OP_LEVEL
32
+ from .hsl_feature_extractor import HSLFeatureExtractor
33
+ from .hybrid_pooler import HybridRegionPooler
34
+ from .oklab_utils import hsl_to_oklab_batch
35
+
36
+
37
+ class EditOpClassifier(nn.Module):
38
+ """
39
+ Neuroswarm Tier 2: Classifies edit ops from before/after palette pairs.
40
+
41
+ Managers call this thousands of times per cycle to verify sub-agent work
42
+ without spending tokens on LLM review. ~1ms inference on GPU.
43
+
44
+ Input: (before_hsl, after_hsl) each (B, H, W, 3) normalized HSL [0,1]
45
+ Output: (op_logits_31, level_logits_3, global_features)
46
+ """
47
+
48
+ PROFILE_DIM = 33 # Structural profile vector dimensionality
49
+ OKLAB_DIM = 1 # Perceptual delta magnitude (scalar)
50
+
51
+ def __init__(
52
+ self,
53
+ hidden_dim: int = 256,
54
+ vit_layers: int = 4,
55
+ vit_heads: int = 8,
56
+ num_regions: int = 8,
57
+ patch_size: int = 4,
58
+ num_refinement_iters: int = 2,
59
+ dropout: float = 0.1,
60
+ ):
61
+ super().__init__()
62
+ self.hidden_dim = hidden_dim
63
+
64
+ # Fix 1: HSLFeatureExtractor with circular hue encoding
65
+ # h → (sin(2πh), cos(2πh)) handles hue wraparound correctly
66
+ # 359° and 1° are adjacent, not 358 apart
67
+ self.feature_extractor = HSLFeatureExtractor(
68
+ hidden_dim=hidden_dim,
69
+ num_layers=vit_layers,
70
+ num_heads=vit_heads,
71
+ patch_size=patch_size,
72
+ dropout=dropout,
73
+ )
74
+
75
+ # Fix 2: HybridRegionPooler — DETR-style learned queries
76
+ # use_structure=False because HSL palettes have NO scope markers
77
+ # Iterative refinement (Slot Attention style)
78
+ self.region_pooler = HybridRegionPooler(
79
+ hidden_dim=hidden_dim,
80
+ num_learned_queries=num_regions,
81
+ num_heads=vit_heads,
82
+ use_structure=False,
83
+ dropout=dropout,
84
+ num_refinement_iters=num_refinement_iters,
85
+ )
86
+
87
+ # Delta fusion: (before_regions, after_regions, delta) → fused
88
+ self.delta_fusion = nn.Sequential(
89
+ nn.Linear(hidden_dim * 3, hidden_dim * 2),
90
+ nn.GELU(),
91
+ nn.Dropout(dropout),
92
+ nn.Linear(hidden_dim * 2, hidden_dim),
93
+ nn.LayerNorm(hidden_dim),
94
+ )
95
+
96
+ # Global pooling via attention
97
+ self.global_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
98
+ self.global_attn = nn.MultiheadAttention(
99
+ hidden_dim, vit_heads, dropout=dropout, batch_first=True
100
+ )
101
+
102
+ # Fix 3: 33-dim profile delta projection
103
+ # Structural profile captures category distribution, color stats,
104
+ # scope depth, spectral alignment — compressed direction signal
105
+ self.profile_proj = nn.Sequential(
106
+ nn.Linear(self.PROFILE_DIM, hidden_dim // 4),
107
+ nn.GELU(),
108
+ nn.LayerNorm(hidden_dim // 4),
109
+ )
110
+
111
+ # Fix 4: OKLab delta magnitude projection
112
+ # Single scalar — "how big was this change" in perceptual space
113
+ self.oklab_proj = nn.Sequential(
114
+ nn.Linear(self.OKLAB_DIM, hidden_dim // 8),
115
+ nn.GELU(),
116
+ )
117
+
118
+ # Conditioning input size: hidden_dim + profile_proj + oklab_proj
119
+ cond_dim = hidden_dim + hidden_dim // 4 + hidden_dim // 8
120
+
121
+ # Level classifier (primitive / structural / semantic)
122
+ self.level_head = nn.Sequential(
123
+ nn.Linear(cond_dim, hidden_dim // 2),
124
+ nn.GELU(),
125
+ nn.Dropout(dropout),
126
+ nn.Linear(hidden_dim // 2, 3),
127
+ )
128
+
129
+ # Fine-grained op classifier (31 classes)
130
+ # Conditioned on level logits (hierarchical)
131
+ self.op_head = nn.Sequential(
132
+ nn.Linear(cond_dim + 3, hidden_dim), # +3 for level logits
133
+ nn.GELU(),
134
+ nn.Dropout(dropout),
135
+ nn.Linear(hidden_dim, hidden_dim // 2),
136
+ nn.GELU(),
137
+ nn.Dropout(dropout),
138
+ nn.Linear(hidden_dim // 2, NUM_OPS),
139
+ )
140
+
141
+ self._init_weights()
142
+
143
+ def _init_weights(self):
144
+ for m in self.modules():
145
+ if isinstance(m, nn.Linear):
146
+ nn.init.trunc_normal_(m.weight, std=0.02)
147
+ if m.bias is not None:
148
+ nn.init.zeros_(m.bias)
149
+
150
+ def encode_palette(self, hsl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
151
+ """
152
+ Encode HSL palette → region embeddings + importance scores.
153
+
154
+ Args:
155
+ hsl: (B, H, W, 3) normalized HSL [0,1]
156
+
157
+ Returns:
158
+ regions: (B, R, hidden_dim) region embeddings
159
+ importance: (B, R) importance scores
160
+ """
161
+ # HSLFeatureExtractor: circular hue → ViT spatial features
162
+ features = self.feature_extractor(hsl) # (B, H, W, D)
163
+
164
+ # HybridRegionPooler: DETR queries → region embeddings
165
+ regions, importance = self.region_pooler(features) # (B, R, D), (B, R)
166
+
167
+ return regions, importance
168
+
169
+ @staticmethod
170
+ def compute_oklab_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor:
171
+ """
172
+ Compute perceptual change magnitude in OKLab space.
173
+
174
+ Returns:
175
+ (B, 1) scalar — mean DeltaE across all spatial positions
176
+ """
177
+ # Convert to OKLab
178
+ before_oklab = hsl_to_oklab_batch(before_hsl) # (B, H, W, 3)
179
+ after_oklab = hsl_to_oklab_batch(after_hsl) # (B, H, W, 3)
180
+
181
+ # Per-pixel DeltaE
182
+ delta_e = (before_oklab - after_oklab).pow(2).sum(dim=-1).sqrt() # (B, H, W)
183
+
184
+ # Mean across spatial dimensions
185
+ mean_delta_e = delta_e.mean(dim=(1, 2), keepdim=False) # (B,)
186
+
187
+ return mean_delta_e.unsqueeze(-1) # (B, 1)
188
+
189
+ def forward(
190
+ self,
191
+ before_hsl: torch.Tensor,
192
+ after_hsl: torch.Tensor,
193
+ profile_delta: Optional[torch.Tensor] = None,
194
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
195
+ """
196
+ Classify edit operation from before/after palette pair.
197
+
198
+ Args:
199
+ before_hsl: (B, H, W, 3) palette before edit, HSL [0,1]
200
+ after_hsl: (B, H, W, 3) palette after edit, HSL [0,1]
201
+ profile_delta: (B, 33) optional structural profile delta (after - before)
202
+ If None, zeros are used (graceful degradation)
203
+
204
+ Returns:
205
+ op_logits: (B, 31) logits over edit operations
206
+ level_logits: (B, 3) logits over levels
207
+ global_feat: (B, hidden_dim) fused delta representation
208
+ """
209
+ B = before_hsl.shape[0]
210
+ device = before_hsl.device
211
+
212
+ # Encode both palettes through shared feature extractor + pooler
213
+ before_regions, before_imp = self.encode_palette(before_hsl) # (B, R, D)
214
+ after_regions, after_imp = self.encode_palette(after_hsl) # (B, R, D)
215
+
216
+ # Compute delta (importance-weighted)
217
+ imp = (before_imp + after_imp) / 2 # (B, R)
218
+ imp_w = imp.unsqueeze(-1) # (B, R, 1)
219
+ delta = (after_regions - before_regions) * imp_w
220
+
221
+ # Fuse: [before, after, delta] → fused features
222
+ fused = torch.cat([before_regions, after_regions, delta], dim=-1) # (B, R, 3*D)
223
+ fused = self.delta_fusion(fused) # (B, R, D)
224
+
225
+ # Global pool via attention
226
+ query = self.global_query.expand(B, -1, -1)
227
+ global_feat, _ = self.global_attn(query, fused, fused)
228
+ global_feat = global_feat.squeeze(1) # (B, D)
229
+
230
+ # Fix 3: Profile delta conditioning
231
+ if profile_delta is None:
232
+ profile_delta = torch.zeros(B, self.PROFILE_DIM, device=device)
233
+ profile_feat = self.profile_proj(profile_delta) # (B, D//4)
234
+
235
+ # Fix 4: OKLab delta magnitude
236
+ oklab_delta = self.compute_oklab_delta(before_hsl, after_hsl) # (B, 1)
237
+ oklab_feat = self.oklab_proj(oklab_delta) # (B, D//8)
238
+
239
+ # Concatenate all conditioning signals
240
+ conditioned = torch.cat([global_feat, profile_feat, oklab_feat], dim=-1) # (B, D + D//4 + D//8)
241
+
242
+ # Level classification
243
+ level_logits = self.level_head(conditioned) # (B, 3)
244
+
245
+ # Fine op classification (conditioned on level)
246
+ op_input = torch.cat([conditioned, level_logits], dim=-1)
247
+ op_logits = self.op_head(op_input) # (B, 31)
248
+
249
+ return op_logits, level_logits, global_feat
250
+
251
+
252
+ # ====================================================================
253
+ # Tier 1: Profile cosine similarity (nanoseconds)
254
+ # ====================================================================
255
+
256
+ class Tier1ProfileVerifier:
257
+ """
258
+ Neuroswarm Tier 1: Nanosecond verification via 33-dim profile cosine similarity.
259
+
260
+ Usage:
261
+ verifier = Tier1ProfileVerifier()
262
+ result = verifier.verify(expected_delta, actual_delta)
263
+ if result.tier == 'pass': ...
264
+ elif result.tier == 'escalate': ... # → Tier 2
265
+ elif result.tier == 'reject': ... # → retry agent
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ pass_threshold: float = 0.7,
271
+ reject_threshold: float = 0.3,
272
+ ):
273
+ self.pass_threshold = pass_threshold
274
+ self.reject_threshold = reject_threshold
275
+
276
+ def verify(
277
+ self,
278
+ expected_delta: torch.Tensor,
279
+ actual_delta: torch.Tensor,
280
+ ) -> dict:
281
+ """
282
+ Compare expected vs actual structural profile delta.
283
+
284
+ Args:
285
+ expected_delta: (33,) or (B, 33) expected profile change
286
+ actual_delta: (33,) or (B, 33) actual profile change
287
+
288
+ Returns:
289
+ dict with 'alignment', 'tier' ('pass'/'escalate'/'reject')
290
+ """
291
+ if expected_delta.dim() == 1:
292
+ expected_delta = expected_delta.unsqueeze(0)
293
+ actual_delta = actual_delta.unsqueeze(0)
294
+
295
+ # Cosine similarity
296
+ alignment = F.cosine_similarity(expected_delta, actual_delta, dim=-1) # (B,)
297
+
298
+ tiers = []
299
+ for a in alignment:
300
+ a_val = a.item()
301
+ if a_val >= self.pass_threshold:
302
+ tiers.append('pass')
303
+ elif a_val >= self.reject_threshold:
304
+ tiers.append('escalate')
305
+ else:
306
+ tiers.append('reject')
307
+
308
+ return {
309
+ 'alignment': alignment,
310
+ 'tiers': tiers,
311
+ 'mean_alignment': alignment.mean().item(),
312
+ }
313
+
314
+
315
+ # ====================================================================
316
+ # Tier 2: Edit classifier inference wrapper
317
+ # ====================================================================
318
+
319
+ class Tier2EditVerifier:
320
+ """
321
+ Neuroswarm Tier 2: Millisecond verification via edit classifier.
322
+
323
+ Usage:
324
+ verifier = Tier2EditVerifier(model, device='cuda')
325
+ result = verifier.verify(before_hsl, after_hsl, expected_op, profile_delta)
326
+ if result['match']: ... # agent did the right thing
327
+ else: ... # escalate to Tier 3
328
+ """
329
+
330
+ def __init__(
331
+ self,
332
+ model: EditOpClassifier,
333
+ device: str = 'cpu',
334
+ confidence_threshold: float = 0.8,
335
+ ):
336
+ self.model = model.to(device).eval()
337
+ self.device = device
338
+ self.confidence_threshold = confidence_threshold
339
+
340
+ @torch.no_grad()
341
+ def verify(
342
+ self,
343
+ before_hsl: torch.Tensor,
344
+ after_hsl: torch.Tensor,
345
+ expected_op: OpCode,
346
+ profile_delta: Optional[torch.Tensor] = None,
347
+ ) -> dict:
348
+ """
349
+ Verify that an agent performed the expected edit operation.
350
+
351
+ Returns:
352
+ dict with 'match', 'predicted_op', 'confidence', 'escalate'
353
+ """
354
+ before = before_hsl.unsqueeze(0).to(self.device) if before_hsl.dim() == 3 else before_hsl.to(self.device)
355
+ after = after_hsl.unsqueeze(0).to(self.device) if after_hsl.dim() == 3 else after_hsl.to(self.device)
356
+ if profile_delta is not None:
357
+ profile_delta = profile_delta.unsqueeze(0).to(self.device) if profile_delta.dim() == 1 else profile_delta.to(self.device)
358
+
359
+ op_logits, level_logits, _ = self.model(before, after, profile_delta)
360
+
361
+ probs = F.softmax(op_logits, dim=-1)
362
+ pred_idx = probs.argmax(dim=-1).item()
363
+ confidence = probs[0, pred_idx].item()
364
+ predicted_op = IDX_TO_OP[pred_idx]
365
+
366
+ expected_idx = OP_TO_IDX[expected_op]
367
+ match = (pred_idx == expected_idx) and (confidence >= self.confidence_threshold)
368
+ escalate = not match
369
+
370
+ return {
371
+ 'match': match,
372
+ 'predicted_op': predicted_op,
373
+ 'predicted_op_name': predicted_op.name,
374
+ 'expected_op_name': expected_op.name,
375
+ 'confidence': confidence,
376
+ 'escalate': escalate,
377
+ 'op_probs': probs[0].cpu(),
378
+ }
379
+
380
+
381
+ # ====================================================================
382
+ # Loss
383
+ # ====================================================================
384
+
385
+ class EditOpLoss(nn.Module):
386
+ """
387
+ Combined loss for edit op classification.
388
+
389
+ Components:
390
+ - Cross-entropy on 31-class op prediction
391
+ - Cross-entropy on 3-class level prediction (auxiliary)
392
+ - Level-op consistency penalty
393
+ """
394
+
395
+ def __init__(self, level_weight: float = 0.3, consistency_weight: float = 0.1):
396
+ super().__init__()
397
+ self.level_weight = level_weight
398
+ self.consistency_weight = consistency_weight
399
+ self.op_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
400
+ self.level_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
401
+
402
+ # Build op → level mapping
403
+ self._op_to_level = {}
404
+ level_names = ['primitive', 'structural', 'semantic']
405
+ for op in TRAINABLE_OPS:
406
+ level = OP_LEVEL[op]
407
+ self._op_to_level[OP_TO_IDX[op]] = level_names.index(level)
408
+
409
+ def forward(
410
+ self,
411
+ op_logits: torch.Tensor,
412
+ level_logits: torch.Tensor,
413
+ op_labels: torch.Tensor,
414
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
415
+ """
416
+ Args:
417
+ op_logits: (B, 31) predicted op logits
418
+ level_logits: (B, 3) predicted level logits
419
+ op_labels: (B,) integer labels in [0, 30]
420
+
421
+ Returns:
422
+ total_loss, metrics_dict
423
+ """
424
+ op_loss = self.op_loss_fn(op_logits, op_labels)
425
+
426
+ level_labels = torch.tensor(
427
+ [self._op_to_level[l.item()] for l in op_labels],
428
+ device=op_labels.device, dtype=torch.long
429
+ )
430
+ level_loss = self.level_loss_fn(level_logits, level_labels)
431
+
432
+ pred_ops = op_logits.argmax(dim=-1)
433
+ pred_levels = level_logits.argmax(dim=-1)
434
+ expected_levels = torch.tensor(
435
+ [self._op_to_level[p.item()] for p in pred_ops],
436
+ device=op_labels.device, dtype=torch.long
437
+ )
438
+ consistency = (pred_levels == expected_levels).float().mean()
439
+ consistency_loss = 1.0 - consistency
440
+
441
+ total = op_loss + self.level_weight * level_loss + self.consistency_weight * consistency_loss
442
+
443
+ metrics = {
444
+ 'loss': total.item(),
445
+ 'op_loss': op_loss.item(),
446
+ 'level_loss': level_loss.item(),
447
+ 'consistency': consistency.item(),
448
+ 'op_acc': (pred_ops == op_labels).float().mean().item(),
449
+ 'level_acc': (pred_levels == level_labels).float().mean().item(),
450
+ }
451
+
452
+ return total, metrics
453
+
454
+ @staticmethod
455
+ def op_label_from_opcode(opcode: OpCode) -> int:
456
+ return OP_TO_IDX[opcode]
457
+
458
+ @staticmethod
459
+ def opcode_from_label(label: int) -> OpCode:
460
+ return IDX_TO_OP[label]