Jonttup committed on
Commit
f68c3ae
·
verified ·
1 Parent(s): 14ed7e4

Upload train_edit_classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_edit_classifier.py +656 -0
train_edit_classifier.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Train 31-class Edit Operation Classifier — Neuroswarm Tier 2

Pipeline:
    Code → HueAI → HSL (H,W,3)
    → Circular hue encoding (sin/cos) → ViT → HybridRegionPooler (DETR)
    → Delta fusion + profile_delta(33) + oklab_magnitude(1)
    → Hierarchical classifier → 31 ops

Usage:
    python train_edit_classifier.py --epochs 50 --batch-size 128 --lr 3e-4
    python train_edit_classifier.py --device cuda --fp16
"""
14
+
15
+ import argparse
16
+ import json
17
+ import math
18
+ import os
19
+ import sys
20
+ import time
21
+ import random
22
+ from pathlib import Path
23
+ from typing import List, Tuple, Dict
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.utils.data import Dataset, DataLoader
29
+
30
+ sys.path.insert(0, str(Path(__file__).parent))
31
+
32
+ from models.edit_ops import (
33
+ PaletteEditOps, EditAction, OpCode, TRAINABLE_OPS, NUM_OPS,
34
+ OP_TO_IDX, IDX_TO_OP, OP_LEVEL
35
+ )
36
+ from models.edit_classifier import EditOpClassifier, EditOpLoss
37
+ from models.scope_pooler import ScopePooler
38
+
39
+
40
+ # ============================================================
41
+ # Synthetic Dataset Generator
42
+ # ============================================================
43
+
44
class EditOpDatasetGenerator:
    """
    Generates (before_palette, after_palette, label) triples by
    applying each of the 31 ops to random palettes.

    This is the bootstrapping approach — generate synthetic pairs
    to pre-train, then fine-tune on real git diff pairs.

    Attributes:
        H, W: palette height and width in cells.
        vocab_size: exclusive upper bound of random content token values.
        ops: PaletteEditOps instance used to apply an EditAction.
        pooler: ScopePooler used to derive per-palette metadata.
        fallback_count: number of generate_pair() calls that exhausted
            all attempts and returned an identical before/after pair.
    """

    START = PaletteEditOps.START_OF_SCOPE
    END = PaletteEditOps.END_OF_SCOPE
    NOOP = PaletteEditOps.NOOP

    # Last dimension of the random feature tensor handed to the pooler;
    # the pooler is built with the same value as its hidden_dim.
    FEATURE_DIM = 64
    # How many times generate_pair() retries before using the fallback.
    MAX_ATTEMPTS = 10

    def __init__(self, palette_h: int = 8, palette_w: int = 32, vocab_size: int = 256):
        self.H = palette_h
        self.W = palette_w
        self.vocab_size = vocab_size
        self.ops = PaletteEditOps()
        self.pooler = ScopePooler(hidden_dim=self.FEATURE_DIM)
        self.fallback_count = 0

    # ------------------------------------------------------------------
    # Token-list builders
    # ------------------------------------------------------------------

    def _random_token(self) -> int:
        """Return one random content token in [3, vocab_size - 1]."""
        return random.randint(3, self.vocab_size - 1)

    def _random_region_tokens(self, min_len: int = 3, max_len: int = 12) -> List[int]:
        """Generate random content tokens (excluding 0, 1, 2)."""
        length = random.randint(min_len, max_len)
        return [self._random_token() for _ in range(length)]

    def _pad_tokens(self, tokens: List[int]) -> List[int]:
        """Return a new list of exactly H*W tokens: NOOP-padded, then truncated."""
        total = self.H * self.W
        # A negative repeat count yields an empty list, so over-long inputs
        # are simply truncated by the slice.
        return (tokens + [self.NOOP] * (total - len(tokens)))[:total]

    def _make_palette(self, tokens: List[int]) -> Tuple[torch.Tensor, object]:
        """Create palette and metadata from flat token list.

        Returns:
            (palette, metadata): palette is an (H, W) long tensor; metadata
            is the first element of the pooler's second return value.
        """
        tokens = self._pad_tokens(tokens)
        palette = torch.tensor([tokens], dtype=torch.long).view(1, self.H, self.W)
        # The pooler is called only for its metadata output; the features
        # themselves are random noise.
        features = torch.randn(1, self.H, self.W, self.FEATURE_DIM)
        _, metadata = self.pooler(features, palette)
        return palette[0], metadata[0]

    def _make_single_region(self) -> Tuple[List[int], int]:
        """Create a single-region palette token list.

        Returns:
            (tokens, n): the padded token list and the number of content
            tokens generated for the region.
        """
        content = self._random_region_tokens(5, 20)
        tokens = [self.START] + content + [self.END]
        return self._pad_tokens(tokens), len(content)

    def _make_two_regions(self) -> List[int]:
        """Create two adjacent region token list."""
        c1 = self._random_region_tokens(3, 10)
        c2 = self._random_region_tokens(3, 10)
        tokens = [self.START] + c1 + [self.END, self.START] + c2 + [self.END]
        return self._pad_tokens(tokens)

    def _make_nested_scope(self) -> List[int]:
        """Create nested scope: outer [inner [content] content]."""
        inner = self._random_region_tokens(3, 8)
        outer = self._random_region_tokens(2, 5)
        block_hue = random.choice([20, 24, 28, 32])  # for/if/while/with hues
        tokens = [self.START] + outer + [self.START, block_hue] + inner + [self.END] + [self.END]
        return self._pad_tokens(tokens)

    def _make_func_palette(self) -> List[int]:
        """Create palette with function def (hue 12) and call (hue 60) for async ops."""
        content = self._random_region_tokens(3, 8)
        tokens = [self.START, 12] + content + [60] + self._random_region_tokens(2, 4) + [self.END]
        return self._pad_tokens(tokens)

    # ------------------------------------------------------------------
    # Pair generation
    # ------------------------------------------------------------------

    def generate_pair(self, op: OpCode) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Generate a (before, after) palette pair for a specific op.

        Up to MAX_ATTEMPTS random scenarios are tried. If none succeeds the
        method falls back to an identical before/after pair (still labelled
        with ``op``) and increments ``fallback_count`` so callers can see
        how much label noise was injected.

        Returns:
            before_hsl: (H, W, 3) float tensor (normalized HSL)
            after_hsl: (H, W, 3) float tensor (normalized HSL)
            label: OP_TO_IDX[op]
        """
        label = OP_TO_IDX[op]

        for _ in range(self.MAX_ATTEMPTS):
            try:
                before_tokens, action = self._create_op_scenario(op)
                palette, metadata = self._make_palette(before_tokens)

                after_palette, success = self.ops.apply(palette, action, metadata)
                if not success:
                    continue

                # Convert int palettes to fake HSL (for now: map token → hue/sat/light)
                before_hsl = self._palette_to_hsl(palette)
                after_hsl = self._palette_to_hsl(after_palette)
                return before_hsl, after_hsl, label
            except Exception:
                # Deliberate best-effort: any failure while building or
                # applying a random scenario just triggers another attempt.
                continue

        # Fallback: return identical palettes (will be NO_OP-like, model must learn)
        self.fallback_count += 1
        tokens, _ = self._make_single_region()
        palette, _ = self._make_palette(tokens)
        hsl = self._palette_to_hsl(palette)
        return hsl, hsl, label

    @staticmethod
    def compute_profile_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor:
        """
        Compute a 33-dim structural profile delta from HSL tensors.

        Mirrors PaletteStructuralProfile dimensions:
        [0:10] Category distribution delta (hue bands)
        [10:19] Color stats delta (mean/std/entropy of H,S,L)
        [19:25] Structural metrics delta (scope, density, etc.)
        [25:33] Spectral alignment delta (placeholder zeros)

        This is an approximation for synthetic data. Real training
        will use PaletteProfiler.profile_file() on actual source code.

        Args:
            before_hsl, after_hsl: tensors whose last axis holds (H, S, L).

        Returns:
            1-D float tensor of length 33 (after minus before).
        """
        PROFILE_DIM = 33
        delta = torch.zeros(PROFILE_DIM)

        def _hist_entropy(channel: torch.Tensor) -> torch.Tensor:
            # 16-bin histogram over [0, 1]; the epsilon keeps log() finite
            # for empty bins.
            hist = torch.histc(channel, bins=16, min=0, max=1) + 1e-8
            prob = hist / hist.sum()
            return -(prob * prob.log()).sum()

        # Category distribution via hue bands (10 equal-width bins over [0, 1))
        before_h = before_hsl[..., 0].flatten()
        after_h = after_hsl[..., 0].flatten()

        for i in range(10):
            lo, hi = i / 10.0, (i + 1) / 10.0
            before_count = ((before_h >= lo) & (before_h < hi)).float().mean()
            after_count = ((after_h >= lo) & (after_h < hi)).float().mean()
            delta[i] = after_count - before_count

        # Color stats: mean/std/entropy of H,S,L
        for ch in range(3):
            before_ch = before_hsl[..., ch].flatten()
            after_ch = after_hsl[..., ch].flatten()
            base = 10 + ch * 3
            delta[base] = after_ch.mean() - before_ch.mean()
            delta[base + 1] = after_ch.std() - before_ch.std()
            # Entropy approximation: histogram entropy
            delta[base + 2] = _hist_entropy(after_ch) - _hist_entropy(before_ch)

        # Structural metrics: scope marker changes, density changes
        before_s = before_hsl[..., 1].flatten()
        after_s = after_hsl[..., 1].flatten()
        before_lit = before_hsl[..., 2] > 0.01
        after_lit = after_hsl[..., 2] > 0.01
        # Scope markers have S=1.0 — count them
        delta[19] = (after_s > 0.95).float().mean() - (before_s > 0.95).float().mean()
        # Content density (non-zero L)
        delta[20] = after_lit.float().mean() - before_lit.float().mean()
        # Mean saturation change
        delta[21] = after_s.mean() - before_s.mean()
        # Mean lightness change
        delta[22] = after_hsl[..., 2].flatten().mean() - before_hsl[..., 2].flatten().mean()
        # Unique hue ratio change
        before_unique = before_h[before_h > 0].unique().numel() / max(1, (before_h > 0).sum().item())
        after_unique = after_h[after_h > 0].unique().numel() / max(1, (after_h > 0).sum().item())
        delta[23] = after_unique - before_unique
        # Token count change (non-NOOP)
        delta[24] = after_lit.float().sum() - before_lit.float().sum()

        # [25:33] spectral alignment — zeros for synthetic, real data fills these
        return delta

    def _palette_to_hsl(self, palette: torch.Tensor) -> torch.Tensor:
        """Convert integer palette to normalized HSL float tensor (H, W, 3).

        Mapping per cell:
            NOOP    -> (0.0, 0.0, 0.0)
            START   -> (0.0, 1.0, 0.1)
            END     -> (0.5, 1.0, 0.1)
            content -> (token / vocab_size, 0.7, 0.5)

        Implemented with boolean masks instead of a per-cell Python loop;
        the values produced are the same.
        """
        H, W = palette.shape
        cells = palette.detach().to("cpu")

        # Same precedence as an if/elif chain: NOOP, then START, then END.
        is_noop = cells == self.NOOP
        is_start = (cells == self.START) & ~is_noop
        is_end = (cells == self.END) & ~is_noop & ~is_start
        is_content = ~(is_noop | is_start | is_end)

        hsl = torch.zeros(H, W, 3)
        hsl[is_start] = torch.tensor([0.0, 1.0, 0.1])
        hsl[is_end] = torch.tensor([0.5, 1.0, 0.1])
        hsl[is_content] = torch.tensor([0.0, 0.7, 0.5])
        # Hue of content cells is the token value scaled by the vocabulary size.
        hue = cells.to(hsl.dtype) / self.vocab_size
        hsl[..., 0] = torch.where(is_content, hue, hsl[..., 0])
        return hsl

    def _create_op_scenario(self, op: OpCode) -> Tuple[List[int], EditAction]:
        """Create appropriate palette and EditAction for a given op.

        Ops that use the same scenario are handled by one branch. Within a
        branch the order of random draws is fixed (token list first, then
        positions, then payloads).

        Raises:
            ValueError: if ``op`` has no scenario defined here.
        """

        # === LEVEL 1: Primitive ===
        if op == OpCode.DELETE_RANGE:
            tokens, n = self._make_single_region()
            i_end = min(random.randint(0, 2), n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        if op == OpCode.INSERT_TOKEN:
            tokens, n = self._make_single_region()
            pos = random.randint(0, n)  # inserting at position n is allowed
            payload = self._random_token()
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=payload)

        # REPLACE_TOKEN (level 1), RENAME and PARAMETERIZE (level 3):
        # one position inside the region plus a random payload token.
        if op in (OpCode.REPLACE_TOKEN, OpCode.RENAME, OpCode.PARAMETERIZE):
            tokens, n = self._make_single_region()
            pos = random.randint(0, n - 1)
            payload = self._random_token()
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=payload)

        if op == OpCode.SWAP_TOKENS:
            tokens, n = self._make_single_region()
            i_start = random.randint(0, max(0, n - 2))
            i_end = random.randint(i_start + 1, n - 1) if i_start < n - 1 else i_start
            return tokens, EditAction(op_id=op, region_id=0, i_start=i_start, i_end=i_end, payload_idx=0)

        # MOVE_RANGE, COPY_RANGE (level 1) and SINK (level 2):
        # first token of region 0 targeted at region 1.
        if op in (OpCode.MOVE_RANGE, OpCode.COPY_RANGE, OpCode.SINK):
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=0, target_region_id=1)

        if op == OpCode.WRAP_SCOPE:
            tokens, n = self._make_single_region()
            i_end = min(random.randint(1, 3), n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        # UNWRAP_SCOPE (level 1), UNNEST_FROM_BLOCK (level 2), UNGUARD (level 3):
        # nested scope with an open-ended range.
        if op in (OpCode.UNWRAP_SCOPE, OpCode.UNNEST_FROM_BLOCK, OpCode.UNGUARD):
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        # === LEVEL 2: Structural ===
        if op in (OpCode.INDENT, OpCode.EXTRACT):
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        if op in (OpCode.DEDENT, OpCode.HOIST):
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0, payload_idx=0)

        if op == OpCode.INLINE:
            # Need a palette with ref token and source region
            c1 = self._random_region_tokens(3, 6)
            c2 = self._random_region_tokens(3, 6)
            tokens = [self.START, 3] + c1[1:] + [self.END, self.START] + c2 + [self.END]
            tokens = self._pad_tokens(tokens)
            return tokens, EditAction(op_id=op, region_id=1, i_start=0, i_end=-1,
                                      payload_idx=0, target_region_id=0)

        if op == OpCode.SPLIT_REGION:
            tokens, n = self._make_single_region()
            split_at = max(1, min(n // 2, n - 1))
            return tokens, EditAction(op_id=op, region_id=0, i_start=split_at, i_end=-1, payload_idx=0)

        if op == OpCode.MERGE_REGIONS:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=0, target_region_id=1)

        if op == OpCode.REORDER:
            tokens, n = self._make_single_region()
            i_end = min(3, n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        if op == OpCode.NEST_IN_BLOCK:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            block_hue = random.choice([20, 24, 28])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=block_hue)

        # === LEVEL 3: Semantic ===
        if op in (OpCode.RETYPE, OpCode.SPECIALIZE):
            tokens, n = self._make_single_region()
            i_end = min(1, n - 1)
            replacement = [self._random_token() for _ in range(3)]
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=0, payload_tokens=replacement)

        if op == OpCode.CONVERT_CONSTRUCT:
            # Use built-in macro pattern
            content = [20, 220, 220] + self._random_region_tokens(2, 5)
            tokens = self._pad_tokens([self.START] + content + [self.END])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        if op == OpCode.SYNC_TO_ASYNC:
            tokens = self._make_func_palette()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        if op == OpCode.GUARD:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            guard_hue = random.choice([24, 28, 32])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=guard_hue)

        if op == OpCode.SCATTER:
            tokens, n = self._make_single_region()
            # Pick up to 3 distinct positions to scatter to
            positions = random.sample(range(n), min(3, n))
            payload = self._random_token()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=payload, positions=positions)

        if op == OpCode.GATHER:
            tokens, n = self._make_single_region()
            # Build the palette here to look up region 0's content positions.
            palette, metadata = self._make_palette(tokens)
            positions = PaletteEditOps._get_content_positions(palette, metadata, 0)
            abs_positions = positions[:min(3, len(positions))]
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=0, positions=abs_positions)

        if op == OpCode.MIRROR:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=self._random_token(),
                                      target_region_id=1)

        if op == OpCode.COMPOSE:
            tokens = self._make_nested_scope()
            palette, metadata = self._make_palette(tokens)
            # The range covers every cell in the first mask of the metadata.
            mask = metadata.masks[0]
            n_positions = mask.sum().item()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0,
                                      i_end=max(0, int(n_positions) - 1), payload_idx=0)

        raise ValueError(f"Unknown op: {op}")
429
+
430
+
431
class EditOpDataset(Dataset):
    """PyTorch Dataset for edit op classification training.

    All samples are generated eagerly in the constructor: for every op in
    TRAINABLE_OPS, ``num_samples // NUM_OPS`` pairs are produced, and the
    resulting list is shuffled once. Each stored item is
    ``(before, after, profile_delta, label)``.
    """

    def __init__(self, num_samples: int = 10000, palette_h: int = 8, palette_w: int = 32):
        self.generator = EditOpDatasetGenerator(palette_h, palette_w)
        self.num_samples = num_samples
        self.samples_per_op = num_samples // NUM_OPS

        # Pre-generate balanced dataset with profile deltas
        print(f"Generating {num_samples} training pairs ({self.samples_per_op} per op)...")
        samples: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = []
        for op in TRAINABLE_OPS:
            produced = 0
            while produced < self.samples_per_op:
                samples.append(self._build_sample(op))
                produced += 1
        self.data = samples

        # Shuffle
        random.shuffle(self.data)
        print(f"Generated {len(self.data)} pairs across {NUM_OPS} ops")

    def _build_sample(self, op):
        """Generate one (before, after, profile_delta, label) tuple for ``op``."""
        before, after, label = self.generator.generate_pair(op)
        delta = self.generator.compute_profile_delta(before, after)
        return (before, after, delta, label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # The label is stored as a plain int and wrapped in a long tensor
        # on access.
        sample = self.data[idx]
        target = torch.tensor(sample[3], dtype=torch.long)
        return sample[0], sample[1], sample[2], target
458
+
459
+
460
+ # ============================================================
461
+ # Training Loop
462
+ # ============================================================
463
+
464
def train(args):
    """Train the edit-op classifier on synthetic data and save checkpoints.

    Builds train/validation EditOpDataset instances, an EditOpClassifier and
    an EditOpLoss, then runs ``args.epochs`` epochs of AdamW with cosine LR
    annealing and gradient-norm clipping at 1.0. The checkpoint with the
    best validation op accuracy is written to
    ``trained_models/edit_classifier_best.pt`` and the final weights to
    ``trained_models/edit_classifier_final.pt``.

    Args:
        args: namespace providing device, epochs, batch_size, lr, hidden_dim,
            vit_layers, vit_heads, num_regions, patch_size, dropout,
            train_samples, val_samples, fp16, palette_h and palette_w.

    Returns:
        The best validation op accuracy seen (0.0 if no epoch ran).
    """
    device = torch.device(args.device)
    print(f"Device: {device}")
    print(f"Training {NUM_OPS}-class edit op classifier")
    print(f"Ops: {[op.name for op in TRAINABLE_OPS]}")

    # Create datasets
    train_dataset = EditOpDataset(args.train_samples, args.palette_h, args.palette_w)
    val_dataset = EditOpDataset(args.val_samples, args.palette_h, args.palette_w)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just produces a DataLoader warning.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=0, pin_memory=pin_memory)

    # Model
    model = EditOpClassifier(
        hidden_dim=args.hidden_dim,
        vit_layers=args.vit_layers,
        vit_heads=args.vit_heads,
        num_regions=args.num_regions,
        patch_size=args.patch_size,
        dropout=args.dropout,
    ).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    # Loss
    criterion = EditOpLoss().to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # FP16 support (only active when --fp16 is given AND the device is CUDA)
    scaler = torch.amp.GradScaler('cuda') if args.fp16 and device.type == 'cuda' else None

    best_val_acc = 0.0
    # Defined up front so the final checkpoint can still be written when
    # args.epochs == 0 and the loop body never runs.
    val_op_acc = 0.0
    save_dir = Path("trained_models")
    save_dir.mkdir(exist_ok=True)

    for epoch in range(args.epochs):
        model.train()
        epoch_metrics = {'loss': 0, 'op_acc': 0, 'level_acc': 0, 'batches': 0}
        t0 = time.time()

        for before, after, profile_delta, labels in train_loader:
            before = before.to(device)
            after = after.to(device)
            profile_delta = profile_delta.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if scaler:
                with torch.amp.autocast('cuda'):
                    op_logits, level_logits, _ = model(before, after, profile_delta)
                    loss, metrics = criterion(op_logits, level_logits, labels)
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the
                # true gradient magnitudes.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                op_logits, level_logits, _ = model(before, after, profile_delta)
                loss, metrics = criterion(op_logits, level_logits, labels)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            epoch_metrics['loss'] += metrics['loss']
            epoch_metrics['op_acc'] += metrics['op_acc']
            epoch_metrics['level_acc'] += metrics['level_acc']
            epoch_metrics['batches'] += 1

        scheduler.step()

        # Metrics are averaged per batch (a short last batch has the same
        # weight as a full one). max(..., 1) avoids ZeroDivisionError when
        # a loader yields no batches.
        n = max(epoch_metrics['batches'], 1)
        train_loss = epoch_metrics['loss'] / n
        train_op_acc = epoch_metrics['op_acc'] / n
        train_level_acc = epoch_metrics['level_acc'] / n
        elapsed = time.time() - t0

        # Validation
        model.eval()
        val_metrics = {'loss': 0, 'op_acc': 0, 'level_acc': 0, 'consistency': 0, 'batches': 0}
        per_op_correct = {i: 0 for i in range(NUM_OPS)}
        per_op_total = {i: 0 for i in range(NUM_OPS)}

        with torch.no_grad():
            for before, after, profile_delta, labels in val_loader:
                before = before.to(device)
                after = after.to(device)
                profile_delta = profile_delta.to(device)
                labels = labels.to(device)

                op_logits, level_logits, _ = model(before, after, profile_delta)
                _, metrics = criterion(op_logits, level_logits, labels)

                preds = op_logits.argmax(dim=-1)
                # One bulk transfer per batch instead of an .item() call per element.
                for predicted, target in zip(preds.tolist(), labels.tolist()):
                    per_op_total[target] += 1
                    if predicted == target:
                        per_op_correct[target] += 1

                val_metrics['loss'] += metrics['loss']
                val_metrics['op_acc'] += metrics['op_acc']
                val_metrics['level_acc'] += metrics['level_acc']
                val_metrics['consistency'] += metrics['consistency']
                val_metrics['batches'] += 1

        vn = max(val_metrics['batches'], 1)
        val_loss = val_metrics['loss'] / vn
        val_op_acc = val_metrics['op_acc'] / vn
        val_level_acc = val_metrics['level_acc'] / vn
        val_consistency = val_metrics['consistency'] / vn

        print(f"Epoch {epoch+1:3d}/{args.epochs} "
              f"[{elapsed:.1f}s] "
              f"train: loss={train_loss:.4f} op={train_op_acc:.1%} level={train_level_acc:.1%} | "
              f"val: loss={val_loss:.4f} op={val_op_acc:.1%} level={val_level_acc:.1%} "
              f"consist={val_consistency:.1%}")

        # Per-op breakdown every 10 epochs (and on the last epoch)
        if (epoch + 1) % 10 == 0 or epoch == args.epochs - 1:
            print(" Per-op accuracy:")
            for level in ['primitive', 'structural', 'semantic']:
                ops_in_level = [op for op in TRAINABLE_OPS if OP_LEVEL[op] == level]
                print(f" {level.upper()}:")
                for op in ops_in_level:
                    idx = OP_TO_IDX[op]
                    total = per_op_total[idx]
                    correct = per_op_correct[idx]
                    acc = correct / total if total > 0 else 0
                    print(f" {op.name:25s} {correct:3d}/{total:3d} = {acc:.1%}")

        # Save best
        if val_op_acc > best_val_acc:
            best_val_acc = val_op_acc
            checkpoint = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'val_op_acc': val_op_acc,
                'val_level_acc': val_level_acc,
                'val_consistency': val_consistency,
                'args': vars(args),
                'num_ops': NUM_OPS,
                'op_names': [op.name for op in TRAINABLE_OPS],
            }
            torch.save(checkpoint, save_dir / 'edit_classifier_best.pt')
            print(f" -> Saved best model (op_acc={val_op_acc:.1%})")

    # Save final
    torch.save({
        'epoch': args.epochs,
        'model_state': model.state_dict(),
        'val_op_acc': val_op_acc,
        'best_val_acc': best_val_acc,
        'args': vars(args),
        'num_ops': NUM_OPS,
    }, save_dir / 'edit_classifier_final.pt')

    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.1%}")
    return best_val_acc
631
+
632
+
633
def main():
    """Parse command-line options and start training.

    Options are registered from a table, in the order they appear in
    ``--help``. ``--device`` defaults to ``cuda`` when
    ``torch.cuda.is_available()`` is true, otherwise ``cpu``.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, add_argument keyword arguments)
    option_table = [
        ('--device', {'default': default_device}),
        ('--epochs', {'type': int, 'default': 50}),
        ('--batch-size', {'type': int, 'default': 128}),
        ('--lr', {'type': float, 'default': 3e-4}),
        ('--hidden-dim', {'type': int, 'default': 256}),
        ('--vit-layers', {'type': int, 'default': 4}),
        ('--vit-heads', {'type': int, 'default': 8}),
        ('--num-regions', {'type': int, 'default': 8}),
        ('--patch-size', {'type': int, 'default': 4}),
        ('--dropout', {'type': float, 'default': 0.1}),
        ('--train-samples', {'type': int, 'default': 31000}),
        ('--val-samples', {'type': int, 'default': 6200}),
        ('--fp16', {'action': 'store_true'}),
        ('--palette-h', {'type': int, 'default': 8}),
        ('--palette-w', {'type': int, 'default': 32}),
    ]

    parser = argparse.ArgumentParser(description="Train 31-class Edit Op Classifier")
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    train(parser.parse_args())


# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()