| """ |
| Train 31-class Edit Operation Classifier — Neuroswarm Tier 2 |
| |
| Pipeline: |
| Code → HueAI → HSL (H,W,3) |
| → Circular hue encoding (sin/cos) → ViT → HybridRegionPooler (DETR) |
| → Delta fusion + profile_delta(33) + oklab_magnitude(1) |
| → Hierarchical classifier → 31 ops |
| |
| Usage: |
| python train_edit_classifier.py --epochs 50 --batch-size 128 --lr 3e-4 |
| python train_edit_classifier.py --device cuda --fp16 |
| """ |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import sys |
| import time |
| import random |
| from pathlib import Path |
| from typing import List, Tuple, Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from models.edit_ops import ( |
| PaletteEditOps, EditAction, OpCode, TRAINABLE_OPS, NUM_OPS, |
| OP_TO_IDX, IDX_TO_OP, OP_LEVEL |
| ) |
| from models.edit_classifier import EditOpClassifier, EditOpLoss |
| from models.scope_pooler import ScopePooler |
|
|
|
|
| |
| |
| |
|
|
class EditOpDatasetGenerator:
    """
    Generates (before_palette, after_palette, label) triples by
    applying each of the 31 ops to random palettes.

    This is the bootstrapping approach — generate synthetic pairs
    to pre-train, then fine-tune on real git diff pairs.
    """

    # Reserved marker tokens, re-exported from PaletteEditOps for brevity.
    START = PaletteEditOps.START_OF_SCOPE
    END = PaletteEditOps.END_OF_SCOPE
    NOOP = PaletteEditOps.NOOP

    def __init__(self, palette_h: int = 8, palette_w: int = 32, vocab_size: int = 256):
        """
        Args:
            palette_h: Palette grid height (rows).
            palette_w: Palette grid width (columns).
            vocab_size: Exclusive upper bound for random content tokens; also
                the divisor used to normalise a token into a hue.
        """
        self.H = palette_h
        self.W = palette_w
        self.vocab_size = vocab_size
        self.ops = PaletteEditOps()
        self.pooler = ScopePooler(hidden_dim=64)

    def _pad_tokens(self, tokens: List[int]) -> List[int]:
        """Return a new list holding ``tokens`` NOOP-padded/truncated to exactly H * W entries."""
        total = self.H * self.W
        # A negative pad count multiplies to an empty list, so over-long
        # inputs are simply truncated by the slice.
        return (tokens + [self.NOOP] * (total - len(tokens)))[:total]

    def _random_region_tokens(self, min_len: int = 3, max_len: int = 12) -> List[int]:
        """Generate random content tokens (excluding 0, 1, 2)."""
        length = random.randint(min_len, max_len)
        return [random.randint(3, self.vocab_size - 1) for _ in range(length)]

    def _make_palette(self, tokens: List[int]) -> Tuple[torch.Tensor, object]:
        """
        Create palette and metadata from flat token list.

        The tokens are padded/truncated to H * W, reshaped to (1, H, W) and
        passed through the scope pooler together with random (1, H, W, 64)
        features; only the pooler's metadata output is used.

        Returns:
            (palette, metadata): the (H, W) long tensor and the first
            metadata entry returned by the pooler.
        """
        tokens = self._pad_tokens(tokens)
        palette = torch.tensor([tokens], dtype=torch.long).view(1, self.H, self.W)
        features = torch.randn(1, self.H, self.W, 64)
        _, metadata = self.pooler(features, palette)
        return palette[0], metadata[0]

    def _make_single_region(self) -> Tuple[List[int], int]:
        """
        Create a single-region palette token list.

        Returns:
            (tokens, n): the padded token list and the number of content
            tokens between the START and END markers.
        """
        content = self._random_region_tokens(5, 20)
        tokens = [self.START] + content + [self.END]
        return self._pad_tokens(tokens), len(content)

    def _make_two_regions(self) -> List[int]:
        """Create two adjacent region token list."""
        c1 = self._random_region_tokens(3, 10)
        c2 = self._random_region_tokens(3, 10)
        tokens = [self.START] + c1 + [self.END, self.START] + c2 + [self.END]
        return self._pad_tokens(tokens)

    def _make_nested_scope(self) -> List[int]:
        """
        Create a nested scope.

        Layout: START, outer content, START, block hue, inner content, END, END.
        """
        inner = self._random_region_tokens(3, 8)
        outer = self._random_region_tokens(2, 5)
        block_hue = random.choice([20, 24, 28, 32])
        tokens = [self.START] + outer + [self.START, block_hue] + inner + [self.END] + [self.END]
        return self._pad_tokens(tokens)

    def _make_func_palette(self) -> List[int]:
        """Create palette with function def (hue 12) and call (hue 60) for async ops."""
        content = self._random_region_tokens(3, 8)
        tokens = [self.START, 12] + content + [60] + self._random_region_tokens(2, 4) + [self.END]
        return self._pad_tokens(tokens)

    def generate_pair(self, op: OpCode) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Generate a (before, after) palette pair for a specific op.

        Up to 10 random scenarios are tried; an attempt is discarded when
        the op reports failure or anything in the attempt raises.

        Returns:
            before_hsl: (H, W, 3) float tensor (normalized HSL)
            after_hsl: (H, W, 3) float tensor (normalized HSL)
            label: OP_TO_IDX[op]
        """
        label = OP_TO_IDX[op]
        max_attempts = 10

        for _ in range(max_attempts):
            try:
                before_palette, action = self._create_op_scenario(op)
                palette, metadata = self._make_palette(before_palette)

                after_palette, success = self.ops.apply(palette, action, metadata)
                if not success:
                    continue

                before_hsl = self._palette_to_hsl(palette)
                after_hsl = self._palette_to_hsl(after_palette)

                return before_hsl, after_hsl, label

            except Exception:
                # Deliberate best-effort: a failed random scenario is simply
                # retried. This also hides the ValueError raised for an
                # unknown op, which then falls through to the fallback below.
                continue

        # Fallback after all attempts failed: an identity pair.
        # NOTE(review): before == after is still labelled with `op`, which
        # injects label noise for ops that fail often — confirm this is intended.
        tokens, _ = self._make_single_region()
        palette, _ = self._make_palette(tokens)
        hsl = self._palette_to_hsl(palette)
        return hsl, hsl, label

    @staticmethod
    def compute_profile_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor:
        """
        Compute a 33-dim structural profile delta from HSL tensors.

        Mirrors PaletteStructuralProfile dimensions:
        [0:10] Category distribution delta (hue bands)
        [10:19] Color stats delta (mean/std/entropy of H,S,L)
        [19:25] Structural metrics delta (scope, density, etc.)
        [25:33] Spectral alignment delta (placeholder zeros)

        This is an approximation for synthetic data. Real training
        will use PaletteProfiler.profile_file() on actual source code.
        """
        PROFILE_DIM = 33
        delta = torch.zeros(PROFILE_DIM)

        # [0:10] fraction of cells whose hue falls in each band [i/10, (i+1)/10).
        before_h = before_hsl[..., 0].flatten()
        after_h = after_hsl[..., 0].flatten()

        for i in range(10):
            lo, hi = i / 10.0, (i + 1) / 10.0
            before_count = ((before_h >= lo) & (before_h < hi)).float().mean()
            after_count = ((after_h >= lo) & (after_h < hi)).float().mean()
            delta[i] = after_count - before_count

        # [10:19] per-channel (H, S, L) mean, std and 16-bin histogram entropy.
        for ch in range(3):
            before_ch = before_hsl[..., ch].flatten()
            after_ch = after_hsl[..., ch].flatten()
            delta[10 + ch * 3] = after_ch.mean() - before_ch.mean()
            delta[11 + ch * 3] = after_ch.std() - before_ch.std()

            # The 1e-8 offset keeps log() finite for empty bins.
            before_hist = torch.histc(before_ch, bins=16, min=0, max=1) + 1e-8
            after_hist = torch.histc(after_ch, bins=16, min=0, max=1) + 1e-8
            before_p = before_hist / before_hist.sum()
            after_p = after_hist / after_hist.sum()
            before_ent = -(before_p * before_p.log()).sum()
            after_ent = -(after_p * after_p.log()).sum()
            delta[12 + ch * 3] = after_ent - before_ent

        before_s = before_hsl[..., 1].flatten()
        after_s = after_hsl[..., 1].flatten()

        # [19] fraction of cells with saturation above 0.95.
        delta[19] = (after_s > 0.95).float().mean() - (before_s > 0.95).float().mean()
        # [20] fraction of cells with lightness above 0.01.
        delta[20] = (after_hsl[..., 2] > 0.01).float().mean() - (before_hsl[..., 2] > 0.01).float().mean()
        # [21] mean saturation.
        delta[21] = after_s.mean() - before_s.mean()
        # [22] mean lightness.
        delta[22] = after_hsl[..., 2].flatten().mean() - before_hsl[..., 2].flatten().mean()
        # [23] distinct non-zero hues divided by the number of non-zero-hue cells.
        before_unique = before_h[before_h > 0].unique().numel() / max(1, (before_h > 0).sum().item())
        after_unique = after_h[after_h > 0].unique().numel() / max(1, (after_h > 0).sum().item())
        delta[23] = after_unique - before_unique
        # [24] raw (unnormalised) count of cells with lightness above 0.01.
        delta[24] = (after_hsl[..., 2] > 0.01).float().sum() - (before_hsl[..., 2] > 0.01).float().sum()

        # [25:33] intentionally left at zero.
        return delta

    def _palette_to_hsl(self, palette: torch.Tensor) -> torch.Tensor:
        """
        Convert integer palette to normalized HSL float tensor (H, W, 3).

        Mapping per cell:
            NOOP  -> (0.0, 0.0, 0.0)
            START -> (0.0, 1.0, 0.1)
            END   -> (0.5, 1.0, 0.1)
            other -> (token / vocab_size, 0.7, 0.5)
        """
        H, W = palette.shape
        values = palette.float().cpu()

        # Start from the content mapping everywhere, then overwrite markers.
        hsl = torch.zeros(H, W, 3)
        hsl[..., 0] = values / self.vocab_size
        hsl[..., 1] = 0.7
        hsl[..., 2] = 0.5

        # Written lowest-priority first so that NOOP wins over START, which
        # wins over END, should the marker values ever coincide.
        hsl[values == self.END] = torch.tensor([0.5, 1.0, 0.1])
        hsl[values == self.START] = torch.tensor([0.0, 1.0, 0.1])
        hsl[values == self.NOOP] = 0.0
        return hsl

    def _create_op_scenario(self, op: OpCode) -> Tuple[List[int], EditAction]:
        """
        Create appropriate palette and EditAction for a given op.

        Returns:
            (tokens, action): a padded flat token list of length H * W and
            the EditAction to apply to it.

        Raises:
            ValueError: if ``op`` matches none of the handled op codes.
        """
        if op == OpCode.DELETE_RANGE:
            tokens, n = self._make_single_region()
            i_end = min(random.randint(0, 2), n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        elif op == OpCode.INSERT_TOKEN:
            tokens, n = self._make_single_region()
            pos = random.randint(0, n)
            payload = random.randint(3, self.vocab_size - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=payload)

        elif op == OpCode.REPLACE_TOKEN:
            tokens, n = self._make_single_region()
            pos = random.randint(0, n - 1)
            payload = random.randint(3, self.vocab_size - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=payload)

        elif op == OpCode.SWAP_TOKENS:
            tokens, n = self._make_single_region()
            i_start = random.randint(0, max(0, n - 2))
            i_end = random.randint(i_start + 1, n - 1) if i_start < n - 1 else i_start
            return tokens, EditAction(op_id=op, region_id=0, i_start=i_start, i_end=i_end, payload_idx=0)

        elif op == OpCode.MOVE_RANGE:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=0, target_region_id=1)

        elif op == OpCode.COPY_RANGE:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=0, target_region_id=1)

        elif op == OpCode.WRAP_SCOPE:
            tokens, n = self._make_single_region()
            i_end = min(random.randint(1, 3), n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        elif op == OpCode.UNWRAP_SCOPE:
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        elif op == OpCode.INDENT:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        elif op == OpCode.DEDENT:
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0, payload_idx=0)

        elif op == OpCode.EXTRACT:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        elif op == OpCode.INLINE:
            # Two regions; the first one starts with the fixed token 3 in
            # place of c1's first random token.
            c1 = self._random_region_tokens(3, 6)
            c2 = self._random_region_tokens(3, 6)
            tokens = [self.START, 3] + c1[1:] + [self.END, self.START] + c2 + [self.END]
            tokens = self._pad_tokens(tokens)
            return tokens, EditAction(op_id=op, region_id=1, i_start=0, i_end=-1,
                                      payload_idx=0, target_region_id=0)

        elif op == OpCode.SPLIT_REGION:
            tokens, n = self._make_single_region()
            split_at = max(1, min(n // 2, n - 1))
            return tokens, EditAction(op_id=op, region_id=0, i_start=split_at, i_end=-1, payload_idx=0)

        elif op == OpCode.MERGE_REGIONS:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=0, target_region_id=1)

        elif op == OpCode.REORDER:
            tokens, n = self._make_single_region()
            i_end = min(3, n - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end, payload_idx=0)

        elif op == OpCode.NEST_IN_BLOCK:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            block_hue = random.choice([20, 24, 28])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=block_hue)

        elif op == OpCode.UNNEST_FROM_BLOCK:
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        elif op == OpCode.HOIST:
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0, payload_idx=0)

        elif op == OpCode.SINK:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=0, target_region_id=1)

        elif op == OpCode.RENAME:
            tokens, n = self._make_single_region()
            pos = random.randint(0, n - 1)
            payload = random.randint(3, self.vocab_size - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=payload)

        elif op == OpCode.RETYPE:
            tokens, n = self._make_single_region()
            i_end = min(1, n - 1)
            new_types = [random.randint(3, self.vocab_size - 1) for _ in range(3)]
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=0, payload_tokens=new_types)

        elif op == OpCode.CONVERT_CONSTRUCT:
            # Region content begins with the fixed tokens 20, 220, 220.
            content = [20, 220, 220] + self._random_region_tokens(2, 5)
            tokens = self._pad_tokens([self.START] + content + [self.END])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        elif op == OpCode.SYNC_TO_ASYNC:
            tokens = self._make_func_palette()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        elif op == OpCode.PARAMETERIZE:
            tokens, n = self._make_single_region()
            pos = random.randint(0, n - 1)
            param_hue = random.randint(3, self.vocab_size - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=pos, i_end=-1, payload_idx=param_hue)

        elif op == OpCode.SPECIALIZE:
            tokens, n = self._make_single_region()
            i_end = min(1, n - 1)
            concrete = [random.randint(3, self.vocab_size - 1) for _ in range(3)]
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=0, payload_tokens=concrete)

        elif op == OpCode.GUARD:
            tokens, n = self._make_single_region()
            i_end = min(2, n - 1)
            guard_hue = random.choice([24, 28, 32])
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=i_end,
                                      payload_idx=guard_hue)

        elif op == OpCode.UNGUARD:
            tokens = self._make_nested_scope()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1, payload_idx=0)

        elif op == OpCode.SCATTER:
            tokens, n = self._make_single_region()
            # Up to three distinct content indices within the region.
            positions = random.sample(range(n), min(3, n))
            payload = random.randint(3, self.vocab_size - 1)
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=payload, positions=positions)

        elif op == OpCode.GATHER:
            tokens, n = self._make_single_region()
            # Positions come from PaletteEditOps itself, so the palette has
            # to be built here to obtain its metadata.
            palette, metadata = self._make_palette(tokens)
            positions = PaletteEditOps._get_content_positions(palette, metadata, 0)
            abs_positions = positions[:min(3, len(positions))]
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=-1,
                                      payload_idx=0, positions=abs_positions)

        elif op == OpCode.MIRROR:
            tokens = self._make_two_regions()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0, i_end=0,
                                      payload_idx=random.randint(3, self.vocab_size - 1),
                                      target_region_id=1)

        elif op == OpCode.COMPOSE:
            tokens = self._make_nested_scope()
            palette, metadata = self._make_palette(tokens)
            # Span the action over every position set in region 0's mask.
            mask = metadata.masks[0]
            n_positions = mask.sum().item()
            return tokens, EditAction(op_id=op, region_id=0, i_start=0,
                                      i_end=max(0, int(n_positions) - 1), payload_idx=0)

        raise ValueError(f"Unknown op: {op}")
|
|
|
|
class EditOpDataset(Dataset):
    """In-memory dataset of synthetic edit-op pairs, generated eagerly at construction."""

    def __init__(self, num_samples: int = 10000, palette_h: int = 8, palette_w: int = 32):
        """
        Args:
            num_samples: Requested total; each op receives
                ``num_samples // NUM_OPS`` samples.
            palette_h: Palette height forwarded to the generator.
            palette_w: Palette width forwarded to the generator.
        """
        self.generator = EditOpDatasetGenerator(palette_h, palette_w)
        self.num_samples = num_samples
        self.samples_per_op = num_samples // NUM_OPS

        print(f"Generating {num_samples} training pairs ({self.samples_per_op} per op)...")

        # Each entry is (before_hsl, after_hsl, profile_delta, label_index).
        samples: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = []
        for op in TRAINABLE_OPS:
            for _ in range(self.samples_per_op):
                pair_before, pair_after, op_index = self.generator.generate_pair(op)
                delta = self.generator.compute_profile_delta(pair_before, pair_after)
                samples.append((pair_before, pair_after, delta, op_index))
        self.data = samples

        # Samples were produced op by op; mix them in place.
        random.shuffle(self.data)
        print(f"Generated {len(self.data)} pairs across {NUM_OPS} ops")

    def __len__(self):
        """Number of generated samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (before, after, profile_delta, label) with the label as a 0-dim long tensor."""
        sample = self.data[idx]
        label_tensor = torch.tensor(sample[3], dtype=torch.long)
        return sample[0], sample[1], sample[2], label_tensor
|
|
|
|
| |
| |
| |
|
|
def train(args):
    """
    Train and validate the edit-op classifier on synthetic data.

    Reads from ``args``: device, train_samples, val_samples, palette_h,
    palette_w, batch_size, hidden_dim, vit_layers, vit_heads, num_regions,
    patch_size, dropout, lr, epochs, fp16.

    Side effects: creates ``trained_models/`` if needed, writes
    ``edit_classifier_best.pt`` whenever validation op accuracy improves,
    and writes ``edit_classifier_final.pt`` after the last epoch.

    Returns:
        The best validation op accuracy seen (0.0 if it never improved).
    """
    device = torch.device(args.device)
    use_cuda = device.type == 'cuda'
    print(f"Device: {device}")
    print(f"Training {NUM_OPS}-class edit op classifier")
    print(f"Ops: {[op.name for op in TRAINABLE_OPS]}")

    # Synthetic datasets are generated up front, independently for each split.
    train_dataset = EditOpDataset(args.train_samples, args.palette_h, args.palette_w)
    val_dataset = EditOpDataset(args.val_samples, args.palette_h, args.palette_w)

    # Pinned host memory only helps host-to-GPU copies; requesting it
    # without an accelerator just produces a warning.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=0, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=0, pin_memory=use_cuda)

    model = EditOpClassifier(
        hidden_dim=args.hidden_dim,
        vit_layers=args.vit_layers,
        vit_heads=args.vit_heads,
        num_regions=args.num_regions,
        patch_size=args.patch_size,
        dropout=args.dropout,
    ).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    criterion = EditOpLoss().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Mixed precision is only enabled when requested AND running on CUDA.
    scaler = torch.amp.GradScaler('cuda') if args.fp16 and use_cuda else None

    best_val_acc = 0.0
    # Pre-bound so the final checkpoint below is well-defined even when
    # args.epochs == 0 and the loop body never runs.
    val_op_acc = 0.0
    save_dir = Path("trained_models")
    save_dir.mkdir(exist_ok=True)

    for epoch in range(args.epochs):
        # ---- training pass ----
        model.train()
        epoch_metrics = {'loss': 0, 'op_acc': 0, 'level_acc': 0, 'batches': 0}
        t0 = time.time()

        for before, after, profile_delta, labels in train_loader:
            before = before.to(device)
            after = after.to(device)
            profile_delta = profile_delta.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    op_logits, level_logits, _ = model(before, after, profile_delta)
                    loss, metrics = criterion(op_logits, level_logits, labels)
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping so the norm
                # threshold applies to their true magnitude.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                op_logits, level_logits, _ = model(before, after, profile_delta)
                loss, metrics = criterion(op_logits, level_logits, labels)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            epoch_metrics['loss'] += metrics['loss']
            epoch_metrics['op_acc'] += metrics['op_acc']
            epoch_metrics['level_acc'] += metrics['level_acc']
            epoch_metrics['batches'] += 1

        scheduler.step()

        # max(1, ...) avoids ZeroDivisionError when the loader yields no batches.
        n = max(1, epoch_metrics['batches'])
        train_loss = epoch_metrics['loss'] / n
        train_op_acc = epoch_metrics['op_acc'] / n
        train_level_acc = epoch_metrics['level_acc'] / n
        elapsed = time.time() - t0

        # ---- validation pass ----
        model.eval()
        val_metrics = {'loss': 0, 'op_acc': 0, 'level_acc': 0, 'consistency': 0, 'batches': 0}
        per_op_correct = torch.zeros(NUM_OPS, dtype=torch.long)
        per_op_total = torch.zeros(NUM_OPS, dtype=torch.long)

        with torch.no_grad():
            for before, after, profile_delta, labels in val_loader:
                before = before.to(device)
                after = after.to(device)
                profile_delta = profile_delta.to(device)
                labels = labels.to(device)

                op_logits, level_logits, _ = model(before, after, profile_delta)
                _, metrics = criterion(op_logits, level_logits, labels)

                # Tally per-class totals and hits with one transfer per
                # batch instead of a host sync per sample.
                preds = op_logits.argmax(dim=-1)
                labels_cpu = labels.cpu()
                hit_labels = labels_cpu[(preds == labels).cpu()]
                per_op_total += torch.bincount(labels_cpu, minlength=NUM_OPS)
                per_op_correct += torch.bincount(hit_labels, minlength=NUM_OPS)

                val_metrics['loss'] += metrics['loss']
                val_metrics['op_acc'] += metrics['op_acc']
                val_metrics['level_acc'] += metrics['level_acc']
                val_metrics['consistency'] += metrics['consistency']
                val_metrics['batches'] += 1

        vn = max(1, val_metrics['batches'])
        val_loss = val_metrics['loss'] / vn
        val_op_acc = val_metrics['op_acc'] / vn
        val_level_acc = val_metrics['level_acc'] / vn
        val_consistency = val_metrics['consistency'] / vn

        print(f"Epoch {epoch+1:3d}/{args.epochs} "
              f"[{elapsed:.1f}s] "
              f"train: loss={train_loss:.4f} op={train_op_acc:.1%} level={train_level_acc:.1%} | "
              f"val: loss={val_loss:.4f} op={val_op_acc:.1%} level={val_level_acc:.1%} "
              f"consist={val_consistency:.1%}")

        # Detailed per-op breakdown every 10th epoch and on the last epoch.
        if (epoch + 1) % 10 == 0 or epoch == args.epochs - 1:
            print(" Per-op accuracy:")
            for level in ['primitive', 'structural', 'semantic']:
                ops_in_level = [op for op in TRAINABLE_OPS if OP_LEVEL[op] == level]
                print(f" {level.upper()}:")
                for op in ops_in_level:
                    idx = OP_TO_IDX[op]
                    total = int(per_op_total[idx])
                    correct = int(per_op_correct[idx])
                    acc = correct / total if total > 0 else 0
                    print(f" {op.name:25s} {correct:3d}/{total:3d} = {acc:.1%}")

        # Checkpoint only on a strict improvement in validation op accuracy.
        if val_op_acc > best_val_acc:
            best_val_acc = val_op_acc
            checkpoint = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'val_op_acc': val_op_acc,
                'val_level_acc': val_level_acc,
                'val_consistency': val_consistency,
                'args': vars(args),
                'num_ops': NUM_OPS,
                'op_names': [op.name for op in TRAINABLE_OPS],
            }
            torch.save(checkpoint, save_dir / 'edit_classifier_best.pt')
            print(f" -> Saved best model (op_acc={val_op_acc:.1%})")

    # Always persist the last-epoch weights, whether or not they are the best.
    torch.save({
        'epoch': args.epochs,
        'model_state': model.state_dict(),
        'val_op_acc': val_op_acc,
        'best_val_acc': best_val_acc,
        'args': vars(args),
        'num_ops': NUM_OPS,
    }, save_dir / 'edit_classifier_final.pt')

    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.1%}")
    return best_val_acc
|
|
|
|
def main():
    """Build the command-line interface, parse the arguments and run training."""
    parser = argparse.ArgumentParser(description="Train 31-class Edit Op Classifier")

    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = [
        ('--device', {'default': 'cuda' if torch.cuda.is_available() else 'cpu'}),
        ('--epochs', {'type': int, 'default': 50}),
        ('--batch-size', {'type': int, 'default': 128}),
        ('--lr', {'type': float, 'default': 3e-4}),
        ('--hidden-dim', {'type': int, 'default': 256}),
        ('--vit-layers', {'type': int, 'default': 4}),
        ('--vit-heads', {'type': int, 'default': 8}),
        ('--num-regions', {'type': int, 'default': 8}),
        ('--patch-size', {'type': int, 'default': 4}),
        ('--dropout', {'type': float, 'default': 0.1}),
        ('--train-samples', {'type': int, 'default': 31000}),
        ('--val-samples', {'type': int, 'default': 6200}),
        ('--fp16', {'action': 'store_true'}),
        ('--palette-h', {'type': int, 'default': 8}),
        ('--palette-w', {'type': int, 'default': 32}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    train(parser.parse_args())
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|