""" 31-Class Edit Operation Classifier — Neuroswarm Tier 2 Verification Engine Verification stack: Tier 1: 33-dim profile cosine similarity (nanoseconds, GPU) Tier 2: THIS — edit classifier inference (milliseconds, GPU) Tier 3: LLM review (seconds, API call, costs tokens) Pipeline: (before_hsl, after_hsl) each (B, H, W, 3) → Circular hue encoding: h → (sin(2πh), cos(2πh)), stack with S,L → 4D → HSLFeatureExtractor (ViT spatial features) → HybridRegionPooler (DETR-style learned queries, no scope markers) → Delta computation + fusion → Concat: [global_feat, profile_delta_33, oklab_magnitude_1] → Hierarchical classifier: level (3) → op (31) Fixes over v1: 1. Circular hue encoding (HSLFeatureExtractor) — hue wraparound correct 2. HybridRegionPooler — DETR learned queries with iterative refinement 3. 33-dim profile delta conditioning — structural direction signal 4. OKLab delta magnitude — perceptual change size signal """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Dict, List from .edit_ops import TRAINABLE_OPS, NUM_OPS, OP_TO_IDX, IDX_TO_OP, OpCode, OP_LEVEL from .hsl_feature_extractor import HSLFeatureExtractor from .hybrid_pooler import HybridRegionPooler from .oklab_utils import hsl_to_oklab_batch class EditOpClassifier(nn.Module): """ Neuroswarm Tier 2: Classifies edit ops from before/after palette pairs. Managers call this thousands of times per cycle to verify sub-agent work without spending tokens on LLM review. ~1ms inference on GPU. Input: (before_hsl, after_hsl) each (B, H, W, 3) normalized HSL [0,1] Output: (op_logits_31, level_logits_3, global_features) """ PROFILE_DIM = 33 # Structural profile vector dimensionality OKLAB_DIM = 1 # Perceptual delta magnitude (scalar) def __init__( self, hidden_dim: int = 256, vit_layers: int = 4, vit_heads: int = 8, num_regions: int = 8, patch_size: int = 4, num_refinement_iters: int = 2, dropout: float = 0.1, ): super().__init__() self.hidden_dim = hidden_dim # Fix 1: HSLFeatureExtractor with circular hue encoding # h → (sin(2πh), cos(2πh)) handles hue wraparound correctly # 359° and 1° are adjacent, not 358 apart self.feature_extractor = HSLFeatureExtractor( hidden_dim=hidden_dim, num_layers=vit_layers, num_heads=vit_heads, patch_size=patch_size, dropout=dropout, ) # Fix 2: HybridRegionPooler — DETR-style learned queries # use_structure=False because HSL palettes have NO scope markers # Iterative refinement (Slot Attention style) self.region_pooler = HybridRegionPooler( hidden_dim=hidden_dim, num_learned_queries=num_regions, num_heads=vit_heads, use_structure=False, dropout=dropout, num_refinement_iters=num_refinement_iters, ) # Delta fusion: (before_regions, after_regions, delta) → fused self.delta_fusion = nn.Sequential( nn.Linear(hidden_dim * 3, hidden_dim * 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim), ) # Global pooling via attention self.global_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02) self.global_attn = nn.MultiheadAttention( hidden_dim, vit_heads, dropout=dropout, batch_first=True ) # Fix 3: 33-dim profile delta projection # Structural profile captures category distribution, color stats, # scope depth, spectral alignment — compressed direction signal self.profile_proj = nn.Sequential( nn.Linear(self.PROFILE_DIM, hidden_dim // 4), nn.GELU(), nn.LayerNorm(hidden_dim // 4), ) # Fix 4: OKLab delta magnitude projection # Single scalar — "how big was this change" in perceptual space self.oklab_proj = nn.Sequential( nn.Linear(self.OKLAB_DIM, hidden_dim // 8), nn.GELU(), ) # Conditioning input size: hidden_dim + profile_proj + oklab_proj cond_dim = hidden_dim + hidden_dim // 4 + hidden_dim // 8 # Level classifier (primitive / structural / semantic) self.level_head = nn.Sequential( nn.Linear(cond_dim, hidden_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim // 2, 3), ) # Fine-grained op classifier (31 classes) # Conditioned on level logits (hierarchical) self.op_head = nn.Sequential( nn.Linear(cond_dim + 3, hidden_dim), # +3 for level logits nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, hidden_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim // 2, NUM_OPS), ) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) def encode_palette(self, hsl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Encode HSL palette → region embeddings + importance scores. Args: hsl: (B, H, W, 3) normalized HSL [0,1] Returns: regions: (B, R, hidden_dim) region embeddings importance: (B, R) importance scores """ # HSLFeatureExtractor: circular hue → ViT spatial features features = self.feature_extractor(hsl) # (B, H, W, D) # HybridRegionPooler: DETR queries → region embeddings regions, importance = self.region_pooler(features) # (B, R, D), (B, R) return regions, importance @staticmethod def compute_oklab_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor: """ Compute perceptual change magnitude in OKLab space. Returns: (B, 1) scalar — mean DeltaE across all spatial positions """ # Convert to OKLab before_oklab = hsl_to_oklab_batch(before_hsl) # (B, H, W, 3) after_oklab = hsl_to_oklab_batch(after_hsl) # (B, H, W, 3) # Per-pixel DeltaE delta_e = (before_oklab - after_oklab).pow(2).sum(dim=-1).sqrt() # (B, H, W) # Mean across spatial dimensions mean_delta_e = delta_e.mean(dim=(1, 2), keepdim=False) # (B,) return mean_delta_e.unsqueeze(-1) # (B, 1) def forward( self, before_hsl: torch.Tensor, after_hsl: torch.Tensor, profile_delta: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Classify edit operation from before/after palette pair. Args: before_hsl: (B, H, W, 3) palette before edit, HSL [0,1] after_hsl: (B, H, W, 3) palette after edit, HSL [0,1] profile_delta: (B, 33) optional structural profile delta (after - before) If None, zeros are used (graceful degradation) Returns: op_logits: (B, 31) logits over edit operations level_logits: (B, 3) logits over levels global_feat: (B, hidden_dim) fused delta representation """ B = before_hsl.shape[0] device = before_hsl.device # Encode both palettes through shared feature extractor + pooler before_regions, before_imp = self.encode_palette(before_hsl) # (B, R, D) after_regions, after_imp = self.encode_palette(after_hsl) # (B, R, D) # Compute delta (importance-weighted) imp = (before_imp + after_imp) / 2 # (B, R) imp_w = imp.unsqueeze(-1) # (B, R, 1) delta = (after_regions - before_regions) * imp_w # Fuse: [before, after, delta] → fused features fused = torch.cat([before_regions, after_regions, delta], dim=-1) # (B, R, 3*D) fused = self.delta_fusion(fused) # (B, R, D) # Global pool via attention query = self.global_query.expand(B, -1, -1) global_feat, _ = self.global_attn(query, fused, fused) global_feat = global_feat.squeeze(1) # (B, D) # Fix 3: Profile delta conditioning if profile_delta is None: profile_delta = torch.zeros(B, self.PROFILE_DIM, device=device) profile_feat = self.profile_proj(profile_delta) # (B, D//4) # Fix 4: OKLab delta magnitude oklab_delta = self.compute_oklab_delta(before_hsl, after_hsl) # (B, 1) oklab_feat = self.oklab_proj(oklab_delta) # (B, D//8) # Concatenate all conditioning signals conditioned = torch.cat([global_feat, profile_feat, oklab_feat], dim=-1) # (B, D + D//4 + D//8) # Level classification level_logits = self.level_head(conditioned) # (B, 3) # Fine op classification (conditioned on level) op_input = torch.cat([conditioned, level_logits], dim=-1) op_logits = self.op_head(op_input) # (B, 31) return op_logits, level_logits, global_feat # ==================================================================== # Tier 1: Profile cosine similarity (nanoseconds) # ==================================================================== class Tier1ProfileVerifier: """ Neuroswarm Tier 1: Nanosecond verification via 33-dim profile cosine similarity. Usage: verifier = Tier1ProfileVerifier() result = verifier.verify(expected_delta, actual_delta) if result.tier == 'pass': ... elif result.tier == 'escalate': ... # → Tier 2 elif result.tier == 'reject': ... # → retry agent """ def __init__( self, pass_threshold: float = 0.7, reject_threshold: float = 0.3, ): self.pass_threshold = pass_threshold self.reject_threshold = reject_threshold def verify( self, expected_delta: torch.Tensor, actual_delta: torch.Tensor, ) -> dict: """ Compare expected vs actual structural profile delta. Args: expected_delta: (33,) or (B, 33) expected profile change actual_delta: (33,) or (B, 33) actual profile change Returns: dict with 'alignment', 'tier' ('pass'/'escalate'/'reject') """ if expected_delta.dim() == 1: expected_delta = expected_delta.unsqueeze(0) actual_delta = actual_delta.unsqueeze(0) # Cosine similarity alignment = F.cosine_similarity(expected_delta, actual_delta, dim=-1) # (B,) tiers = [] for a in alignment: a_val = a.item() if a_val >= self.pass_threshold: tiers.append('pass') elif a_val >= self.reject_threshold: tiers.append('escalate') else: tiers.append('reject') return { 'alignment': alignment, 'tiers': tiers, 'mean_alignment': alignment.mean().item(), } # ==================================================================== # Tier 2: Edit classifier inference wrapper # ==================================================================== class Tier2EditVerifier: """ Neuroswarm Tier 2: Millisecond verification via edit classifier. Usage: verifier = Tier2EditVerifier(model, device='cuda') result = verifier.verify(before_hsl, after_hsl, expected_op, profile_delta) if result['match']: ... # agent did the right thing else: ... # escalate to Tier 3 """ def __init__( self, model: EditOpClassifier, device: str = 'cpu', confidence_threshold: float = 0.8, ): self.model = model.to(device).eval() self.device = device self.confidence_threshold = confidence_threshold @torch.no_grad() def verify( self, before_hsl: torch.Tensor, after_hsl: torch.Tensor, expected_op: OpCode, profile_delta: Optional[torch.Tensor] = None, ) -> dict: """ Verify that an agent performed the expected edit operation. Returns: dict with 'match', 'predicted_op', 'confidence', 'escalate' """ before = before_hsl.unsqueeze(0).to(self.device) if before_hsl.dim() == 3 else before_hsl.to(self.device) after = after_hsl.unsqueeze(0).to(self.device) if after_hsl.dim() == 3 else after_hsl.to(self.device) if profile_delta is not None: profile_delta = profile_delta.unsqueeze(0).to(self.device) if profile_delta.dim() == 1 else profile_delta.to(self.device) op_logits, level_logits, _ = self.model(before, after, profile_delta) probs = F.softmax(op_logits, dim=-1) pred_idx = probs.argmax(dim=-1).item() confidence = probs[0, pred_idx].item() predicted_op = IDX_TO_OP[pred_idx] expected_idx = OP_TO_IDX[expected_op] match = (pred_idx == expected_idx) and (confidence >= self.confidence_threshold) escalate = not match return { 'match': match, 'predicted_op': predicted_op, 'predicted_op_name': predicted_op.name, 'expected_op_name': expected_op.name, 'confidence': confidence, 'escalate': escalate, 'op_probs': probs[0].cpu(), } # ==================================================================== # Loss # ==================================================================== class EditOpLoss(nn.Module): """ Combined loss for edit op classification. Components: - Cross-entropy on 31-class op prediction - Cross-entropy on 3-class level prediction (auxiliary) - Level-op consistency penalty """ def __init__(self, level_weight: float = 0.3, consistency_weight: float = 0.1): super().__init__() self.level_weight = level_weight self.consistency_weight = consistency_weight self.op_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05) self.level_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05) # Build op → level mapping self._op_to_level = {} level_names = ['primitive', 'structural', 'semantic'] for op in TRAINABLE_OPS: level = OP_LEVEL[op] self._op_to_level[OP_TO_IDX[op]] = level_names.index(level) def forward( self, op_logits: torch.Tensor, level_logits: torch.Tensor, op_labels: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, float]]: """ Args: op_logits: (B, 31) predicted op logits level_logits: (B, 3) predicted level logits op_labels: (B,) integer labels in [0, 30] Returns: total_loss, metrics_dict """ op_loss = self.op_loss_fn(op_logits, op_labels) level_labels = torch.tensor( [self._op_to_level[l.item()] for l in op_labels], device=op_labels.device, dtype=torch.long ) level_loss = self.level_loss_fn(level_logits, level_labels) pred_ops = op_logits.argmax(dim=-1) pred_levels = level_logits.argmax(dim=-1) expected_levels = torch.tensor( [self._op_to_level[p.item()] for p in pred_ops], device=op_labels.device, dtype=torch.long ) consistency = (pred_levels == expected_levels).float().mean() consistency_loss = 1.0 - consistency total = op_loss + self.level_weight * level_loss + self.consistency_weight * consistency_loss metrics = { 'loss': total.item(), 'op_loss': op_loss.item(), 'level_loss': level_loss.item(), 'consistency': consistency.item(), 'op_acc': (pred_ops == op_labels).float().mean().item(), 'level_acc': (pred_levels == level_labels).float().mean().item(), } return total, metrics @staticmethod def op_label_from_opcode(opcode: OpCode) -> int: return OP_TO_IDX[opcode] @staticmethod def opcode_from_label(label: int) -> OpCode: return IDX_TO_OP[label]