""" Parametric Edit Operations Engine — 31 Operations Executes edit operations on palette tensors with pointer-based addressing. Operations preserve scope balance and support region-relative indexing. Three levels: Level 1 — Primitive (8 ops): token-level atomic edits Level 2 — Structural (11 ops): scope-tree mutations Level 3 — Semantic (12 ops): meaning-level transforms Key Features: - Parametric actions with arguments - Pointer arithmetic (region-relative → absolute) - Cross-region operations (MOVE, COPY, EXTRACT, INLINE, MERGE, etc.) - Scope balance verification - Macro pattern transformations - Stateless execution (pure functions) """ import torch from dataclasses import dataclass, field from typing import Tuple, Optional, List, Dict from enum import IntEnum # Import RegionMetadata from scope_pooler from .scope_pooler import RegionMetadata class OpCode(IntEnum): """Operation codes — 31 edit operations across 3 levels""" # === Control (0-99) === NO_OP = 0 MOVE_NEXT = 1 FOCUS_PARENT = 2 DONE = 99 # === Level 1: Primitive — token-level atomic (100-199) === DELETE_RANGE = 100 INSERT_TOKEN = 101 REPLACE_TOKEN = 102 SWAP_TOKENS = 103 MOVE_RANGE = 104 COPY_RANGE = 105 WRAP_SCOPE = 106 UNWRAP_SCOPE = 107 # === Level 2: Structural — scope-tree mutations (200-299) === INDENT = 200 DEDENT = 201 EXTRACT = 202 INLINE = 203 SPLIT_REGION = 204 MERGE_REGIONS = 205 REORDER = 206 NEST_IN_BLOCK = 207 UNNEST_FROM_BLOCK = 208 HOIST = 209 SINK = 210 # === Level 3: Semantic — meaning-level transforms (300-399) === RENAME = 300 RETYPE = 301 CONVERT_CONSTRUCT = 302 SYNC_TO_ASYNC = 303 PARAMETERIZE = 304 SPECIALIZE = 305 GUARD = 306 UNGUARD = 307 SCATTER = 308 GATHER = 309 MIRROR = 310 COMPOSE = 311 # Backward compat aliases for old OpCode values _LEGACY_OPCODES = { 150: OpCode.DELETE_RANGE, 151: OpCode.INSERT_TOKEN, 152: OpCode.REPLACE_TOKEN, 153: OpCode.SWAP_TOKENS, 499: OpCode.DONE, } # ---- Op metadata for classification ---- OP_LEVEL = {} for op in OpCode: v = op.value if v < 100: OP_LEVEL[op] = 'control' elif v < 200: OP_LEVEL[op] = 'primitive' elif v < 300: OP_LEVEL[op] = 'structural' else: OP_LEVEL[op] = 'semantic' # Canonical list of the 31 trainable ops (excludes control) TRAINABLE_OPS: List[OpCode] = [op for op in OpCode if OP_LEVEL[op] != 'control'] NUM_OPS = len(TRAINABLE_OPS) # 31 OP_TO_IDX: Dict[OpCode, int] = {op: i for i, op in enumerate(TRAINABLE_OPS)} IDX_TO_OP: Dict[int, OpCode] = {i: op for i, op in enumerate(TRAINABLE_OPS)} @dataclass class EditAction: """ Parametric edit operation with arguments. Fields: op_id: Operation code from OpCode enum region_id: Which semantic region to operate on [0, R) i_start: Token index within region (relative addressing) i_end: End token index (for range operations, -1 if unused) payload_idx: Palette index to insert/replace (0-4095) confidence: Model confidence in [0, 1] target_region_id: Destination region for cross-region ops (-1 if same region) payload_tokens: Multi-token payload for WRAP, NEST, etc. positions: Multiple target positions for SCATTER """ op_id: int region_id: int i_start: int i_end: int payload_idx: int confidence: float = 1.0 target_region_id: int = -1 payload_tokens: List[int] = field(default_factory=list) positions: List[int] = field(default_factory=list) def __post_init__(self): assert self.op_id >= 0, f"Invalid op_id: {self.op_id}" assert self.region_id >= 0, f"Invalid region_id: {self.region_id}" assert self.i_start >= 0, f"Invalid i_start: {self.i_start}" assert self.i_end >= -1, f"Invalid i_end: {self.i_end}" if self.i_end != -1: assert self.i_end >= self.i_start, f"i_end ({self.i_end}) < i_start ({self.i_start})" assert 0 <= self.payload_idx < 4096, f"Invalid payload_idx: {self.payload_idx}" assert 0 <= self.confidence <= 1, f"Invalid confidence: {self.confidence}" # ---- Exceptions ---- class EditError(Exception): """Base class for edit errors""" pass class ScopeBalanceError(EditError): """Operation would break scope balance""" pass class InvalidPointerError(EditError): """Pointer out of bounds""" pass class RegionNotFoundError(EditError): """region_id invalid""" pass class PatternNotFoundError(EditError): """Macro pattern not found in region""" pass class CrossRegionError(EditError): """Cross-region operation failed""" pass # ---- Main Engine ---- class PaletteEditOps: """ Stateless edit operation executor — 31 operations. All methods are pure functions (no internal state). Thread-safe and deterministic. Constants: START_OF_SCOPE: 0 END_OF_SCOPE: 1 NOOP: 2 """ START_OF_SCOPE = 0 END_OF_SCOPE = 1 NOOP = 2 # Macro pattern definitions for CONVERT_CONSTRUCT MACRO_PATTERNS = { 'py_for_to_js_for': { 'pattern': [20, 220, 220], 'target': [20, 201, 220], 'name': 'Python for → JavaScript for' }, } # ------------------------------------------------------------------ # # Main dispatch # # ------------------------------------------------------------------ # @staticmethod def apply( palette_img: torch.Tensor, action: EditAction, metadata: RegionMetadata ) -> Tuple[torch.Tensor, bool]: """ Apply edit action to palette. Returns (new_palette, success). Guarantees: original unchanged; if success=False, new == original. """ # Resolve legacy OpCode values op_id = _LEGACY_OPCODES.get(action.op_id, action.op_id) if action.region_id >= len(metadata.starts): return palette_img, False palette = palette_img.clone() try: if not PaletteEditOps.verify_scope_balance(palette): raise ScopeBalanceError("Input palette has unbalanced scopes") # --- Control --- if op_id == OpCode.NO_OP: new_palette = palette # --- Level 1: Primitive --- elif op_id == OpCode.DELETE_RANGE: new_palette = PaletteEditOps.delete_range( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.INSERT_TOKEN: new_palette = PaletteEditOps.insert_token( palette, action.region_id, action.i_start, action.payload_idx, metadata) elif op_id == OpCode.REPLACE_TOKEN: new_palette = PaletteEditOps.replace_token( palette, action.region_id, action.i_start, action.payload_idx, metadata) elif op_id == OpCode.SWAP_TOKENS: new_palette = PaletteEditOps.swap_tokens( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.MOVE_RANGE: new_palette = PaletteEditOps.move_range( palette, action.region_id, action.i_start, action.i_end, action.target_region_id, action.payload_idx, metadata) elif op_id == OpCode.COPY_RANGE: new_palette = PaletteEditOps.copy_range( palette, action.region_id, action.i_start, action.i_end, action.target_region_id, action.payload_idx, metadata) elif op_id == OpCode.WRAP_SCOPE: new_palette = PaletteEditOps.wrap_scope( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.UNWRAP_SCOPE: new_palette = PaletteEditOps.unwrap_scope( palette, action.region_id, metadata) # --- Level 2: Structural --- elif op_id == OpCode.INDENT: new_palette = PaletteEditOps.indent( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.DEDENT: new_palette = PaletteEditOps.dedent( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.EXTRACT: new_palette = PaletteEditOps.extract( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.INLINE: new_palette = PaletteEditOps.inline( palette, action.region_id, action.target_region_id, metadata) elif op_id == OpCode.SPLIT_REGION: new_palette = PaletteEditOps.split_region( palette, action.region_id, action.i_start, metadata) elif op_id == OpCode.MERGE_REGIONS: new_palette = PaletteEditOps.merge_regions( palette, action.region_id, action.target_region_id, metadata) elif op_id == OpCode.REORDER: new_palette = PaletteEditOps.reorder( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.NEST_IN_BLOCK: new_palette = PaletteEditOps.nest_in_block( palette, action.region_id, action.i_start, action.i_end, action.payload_idx, metadata) elif op_id == OpCode.UNNEST_FROM_BLOCK: new_palette = PaletteEditOps.unnest_from_block( palette, action.region_id, metadata) elif op_id == OpCode.HOIST: new_palette = PaletteEditOps.hoist( palette, action.region_id, action.i_start, action.i_end, metadata) elif op_id == OpCode.SINK: new_palette = PaletteEditOps.sink( palette, action.region_id, action.i_start, action.i_end, action.target_region_id, metadata) # --- Level 3: Semantic --- elif op_id == OpCode.RENAME: new_palette = PaletteEditOps.rename( palette, action.region_id, action.i_start, action.payload_idx, metadata) elif op_id == OpCode.RETYPE: new_palette = PaletteEditOps.retype( palette, action.region_id, action.i_start, action.i_end, action.payload_tokens, metadata) elif op_id == OpCode.CONVERT_CONSTRUCT: new_palette = PaletteEditOps.convert_construct( palette, action.region_id, action.payload_tokens, metadata) elif op_id == OpCode.SYNC_TO_ASYNC: new_palette = PaletteEditOps.sync_to_async( palette, action.region_id, metadata) elif op_id == OpCode.PARAMETERIZE: new_palette = PaletteEditOps.parameterize( palette, action.region_id, action.i_start, action.payload_idx, metadata) elif op_id == OpCode.SPECIALIZE: new_palette = PaletteEditOps.specialize( palette, action.region_id, action.i_start, action.i_end, action.payload_tokens, metadata) elif op_id == OpCode.GUARD: new_palette = PaletteEditOps.guard( palette, action.region_id, action.i_start, action.i_end, action.payload_idx, metadata) elif op_id == OpCode.UNGUARD: new_palette = PaletteEditOps.unguard( palette, action.region_id, metadata) elif op_id == OpCode.SCATTER: new_palette = PaletteEditOps.scatter( palette, action.region_id, action.payload_idx, action.positions, metadata) elif op_id == OpCode.GATHER: new_palette = PaletteEditOps.gather( palette, action.positions, action.region_id, action.i_start, metadata) elif op_id == OpCode.MIRROR: new_palette = PaletteEditOps.mirror( palette, action.region_id, action.target_region_id, action.i_start, action.i_end, action.payload_idx, metadata) elif op_id == OpCode.COMPOSE: new_palette = PaletteEditOps.compose( palette, action.region_id, action.i_start, action.i_end, metadata) else: return palette_img, False # Post-check balance if not PaletteEditOps.verify_scope_balance(new_palette): raise ScopeBalanceError("Operation broke scope balance") return new_palette, True except EditError: return palette_img, False # ================================================================== # # LEVEL 1 — Primitive (token-level) # # ================================================================== # @staticmethod def delete_range( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Delete tokens [i_start, i_end] within region. Shift left, pad NOOP.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds (size={len(positions)})") if i_end < i_start or i_end >= len(positions): raise InvalidPointerError(f"i_end={i_end} out of bounds") abs_positions = [positions[i] for i in range(i_start, i_end + 1)] palette_flat = palette.flatten() for pos in abs_positions: if palette_flat[pos].item() in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot delete scope markers") delete_mask = torch.zeros(H * W, dtype=torch.bool) for pos in abs_positions: delete_mask[pos] = True kept = palette_flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) @staticmethod def insert_token( palette: torch.Tensor, region_id: int, i_start: int, payload_idx: int, metadata: RegionMetadata ) -> torch.Tensor: """Insert payload_idx at position i_start within region. Shift right, drop last.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_start > len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds") if payload_idx in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot insert unpaired scope marker") abs_pos = positions[i_start] if i_start < len(positions) else (positions[-1] + 1 if positions else 0) flat = palette.flatten() new_flat = torch.zeros(H * W, dtype=palette.dtype) new_flat[:abs_pos] = flat[:abs_pos] new_flat[abs_pos] = payload_idx if abs_pos < H * W - 1: new_flat[abs_pos + 1:] = flat[abs_pos:H * W - 1] return new_flat.view(H, W) @staticmethod def replace_token( palette: torch.Tensor, region_id: int, i_start: int, payload_idx: int, metadata: RegionMetadata ) -> torch.Tensor: """Replace token at i_start with payload_idx.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds") abs_pos = positions[i_start] h, w = abs_pos // W, abs_pos % W old_value = palette[h, w].item() is_old_scope = old_value in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE) is_new_scope = payload_idx in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE) if is_old_scope and not is_new_scope: raise ScopeBalanceError("Cannot replace scope marker with non-marker") if is_old_scope and is_new_scope and old_value != payload_idx: raise ScopeBalanceError("Cannot replace START with END or vice versa") new_palette = palette.clone() new_palette[h, w] = payload_idx return new_palette @staticmethod def swap_tokens( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Swap tokens at i_start and i_end within region.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds") if i_end < 0 or i_end >= len(positions): raise InvalidPointerError(f"i_end={i_end} out of bounds") p1, p2 = positions[i_start], positions[i_end] h1, w1 = p1 // W, p1 % W h2, w2 = p2 // W, p2 % W v1, v2 = palette[h1, w1].item(), palette[h2, w2].item() if {v1, v2} == {PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE}: raise ScopeBalanceError("Cannot swap START ↔ END") new_palette = palette.clone() new_palette[h1, w1], new_palette[h2, w2] = palette[h2, w2], palette[h1, w1] return new_palette @staticmethod def move_range( palette: torch.Tensor, src_region: int, i_start: int, i_end: int, dst_region: int, dst_pos: int, metadata: RegionMetadata ) -> torch.Tensor: """Move tokens [i_start,i_end] from src_region to dst_pos in dst_region. = copy + delete source. Cross-region cut-paste.""" if dst_region < 0: dst_region = src_region # First copy, then delete from source result = PaletteEditOps.copy_range( palette, src_region, i_start, i_end, dst_region, dst_pos, metadata) # After copy, source positions shifted — recalculate metadata on new palette # For correctness, we delete from original positions in the post-copy palette. # The copy inserted (i_end - i_start + 1) tokens into dst, which may shift src positions. # Simplification: if same region, account for shift; if different, positions unchanged. n_copied = i_end - i_start + 1 src_positions = PaletteEditOps._get_content_positions(result, metadata, src_region) # Find the original source tokens by value matching orig_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region) orig_flat = palette.flatten() src_values = [orig_flat[orig_positions[i]].item() for i in range(i_start, i_end + 1)] # Delete from result: find matching tokens in src region result_flat = result.flatten() H, W = result.shape delete_mask = torch.zeros(H * W, dtype=torch.bool) deleted = 0 for pos in src_positions: val = result_flat[pos].item() if deleted < n_copied and val == src_values[deleted]: if val not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): delete_mask[pos] = True deleted += 1 if deleted == 0: raise InvalidPointerError("Could not locate source tokens for deletion") kept = result_flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) @staticmethod def copy_range( palette: torch.Tensor, src_region: int, i_start: int, i_end: int, dst_region: int, dst_pos: int, metadata: RegionMetadata ) -> torch.Tensor: """Copy tokens [i_start,i_end] from src_region, insert at dst_pos in dst_region.""" if dst_region < 0: dst_region = src_region src_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region) if i_start < 0 or i_end >= len(src_positions): raise InvalidPointerError(f"Source range [{i_start},{i_end}] out of bounds") flat = palette.flatten() copied_tokens = [flat[src_positions[i]].item() for i in range(i_start, i_end + 1)] # Validate: no scope markers in copy for t in copied_tokens: if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot copy scope markers without pairing") dst_positions = PaletteEditOps._get_content_positions(palette, metadata, dst_region) if dst_pos < 0 or dst_pos > len(dst_positions): raise InvalidPointerError(f"dst_pos={dst_pos} out of bounds") abs_dst = dst_positions[dst_pos] if dst_pos < len(dst_positions) else ( dst_positions[-1] + 1 if dst_positions else 0) H, W = palette.shape new_flat = torch.full((H * W,), PaletteEditOps.NOOP, dtype=palette.dtype) n = len(copied_tokens) new_flat[:abs_dst] = flat[:abs_dst] for i, t in enumerate(copied_tokens): if abs_dst + i < H * W: new_flat[abs_dst + i] = t remaining = min(H * W - abs_dst - n, H * W - abs_dst) if remaining > 0: new_flat[abs_dst + n:abs_dst + n + remaining] = flat[abs_dst:abs_dst + remaining] return new_flat.view(H, W) @staticmethod def wrap_scope( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Wrap tokens [i_start,i_end] in new scope markers (START...END). Inserts START before i_start, END after i_end. Balanced by construction.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") abs_start = positions[i_start] abs_end = positions[i_end] flat = palette.flatten() # Insert START before abs_start, END after abs_end # Build new sequence: [before_start] + [START] + [start..end] + [END] + [after_end] before = flat[:abs_start].tolist() wrapped = flat[abs_start:abs_end + 1].tolist() after = flat[abs_end + 1:].tolist() new_seq = before + [PaletteEditOps.START_OF_SCOPE] + wrapped + [PaletteEditOps.END_OF_SCOPE] + after # Truncate or pad to H*W if len(new_seq) > H * W: new_seq = new_seq[:H * W] elif len(new_seq) < H * W: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def unwrap_scope( palette: torch.Tensor, region_id: int, metadata: RegionMetadata ) -> torch.Tensor: """Remove the outermost scope markers of a region. Content preserved, scope removed. Removes START at region start and END at region end.""" H, W = palette.shape flat = palette.flatten() start_pos = metadata.starts[region_id] end_pos = metadata.ends[region_id] # Verify markers exist if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE: raise ScopeBalanceError("Region start is not START_OF_SCOPE") if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE: raise ScopeBalanceError("Region end is not END_OF_SCOPE") # Remove both markers delete_mask = torch.zeros(H * W, dtype=torch.bool) delete_mask[start_pos] = True delete_mask[end_pos] = True kept = flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) # ================================================================== # # LEVEL 2 — Structural (scope-tree mutations) # # ================================================================== # @staticmethod def indent( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Increase scope depth: wrap [i_start,i_end] in new scope. Equivalent to wrap_scope — increases nesting by 1.""" return PaletteEditOps.wrap_scope(palette, region_id, i_start, i_end, metadata) @staticmethod def dedent( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Decrease scope depth: remove innermost scope around [i_start,i_end]. Finds the tightest enclosing scope and removes its markers.""" H, W = palette.shape flat = palette.flatten() positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions): raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") abs_start = positions[i_start] abs_end = positions[i_end] # Walk outward from abs_start to find enclosing START enclosing_start = -1 for i in range(abs_start - 1, -1, -1): if flat[i].item() == PaletteEditOps.START_OF_SCOPE: enclosing_start = i break if enclosing_start < 0: raise ScopeBalanceError("No enclosing scope to dedent from") # Find matching END depth = 0 enclosing_end = -1 for i in range(enclosing_start, H * W): v = flat[i].item() if v == PaletteEditOps.START_OF_SCOPE: depth += 1 elif v == PaletteEditOps.END_OF_SCOPE: depth -= 1 if depth == 0: enclosing_end = i break if enclosing_end < 0 or enclosing_end < abs_end: raise ScopeBalanceError("Cannot find matching END for enclosing scope") # Remove the enclosing pair delete_mask = torch.zeros(H * W, dtype=torch.bool) delete_mask[enclosing_start] = True delete_mask[enclosing_end] = True kept = flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) @staticmethod def extract( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Extract tokens [i_start,i_end] into a new scope appended after current region. Source range replaced with a reference token (payload_idx=3 = EXTRACTED_REF). New scope with extracted content appears after current region's END.""" EXTRACTED_REF = 3 # Sentinel: "content was extracted here" H, W = palette.shape flat = palette.flatten() positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") # Grab tokens to extract extracted = [flat[positions[i]].item() for i in range(i_start, i_end + 1)] for t in extracted: if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot extract scope markers") # Replace source range with single ref token abs_start = positions[i_start] abs_end = positions[i_end] before = flat[:abs_start].tolist() after = flat[abs_end + 1:].tolist() middle = [EXTRACTED_REF] # Find insertion point: after region's END marker region_end = metadata.ends[region_id] # Adjust region_end for removed tokens n_removed = (i_end - i_start + 1) adj_end = region_end - n_removed + 1 # +1 for ref token new_scope = [PaletteEditOps.START_OF_SCOPE] + extracted + [PaletteEditOps.END_OF_SCOPE] seq = before + middle + after # Insert new scope after adjusted region end insert_at = min(adj_end + 1, len(seq)) seq = seq[:insert_at] + new_scope + seq[insert_at:] # Truncate/pad if len(seq) > H * W: seq = seq[:H * W] else: seq.extend([PaletteEditOps.NOOP] * (H * W - len(seq))) return torch.tensor(seq, dtype=palette.dtype).view(H, W) @staticmethod def inline( palette: torch.Tensor, src_region: int, target_region: int, metadata: RegionMetadata ) -> torch.Tensor: """Inline: replace a reference in target_region with contents of src_region. Opposite of extract. Removes src_region scope, inserts content at ref position.""" if target_region < 0: raise CrossRegionError("target_region_id required for INLINE") EXTRACTED_REF = 3 H, W = palette.shape flat = palette.flatten() # Get src region content (excluding scope markers) src_content = PaletteEditOps._get_content_positions(palette, metadata, src_region) content_tokens = [flat[pos].item() for pos in src_content] # Find ref token in target region target_positions = PaletteEditOps._get_content_positions(palette, metadata, target_region) ref_pos = -1 for pos in target_positions: if flat[pos].item() == EXTRACTED_REF: ref_pos = pos break if ref_pos < 0: raise PatternNotFoundError("No EXTRACTED_REF found in target region") # Remove src region entirely (with scope markers) src_start = metadata.starts[src_region] src_end = metadata.ends[src_region] seq = flat.tolist() # Replace ref with content ref_idx = seq.index(EXTRACTED_REF) if EXTRACTED_REF in seq else ref_pos seq = seq[:ref_idx] + content_tokens + seq[ref_idx + 1:] # Remove src scope markers and content # Recalculate positions after insertion # Simpler: remove src region range from original, then insert content at ref # Let's rebuild from scratch flat_list = flat.tolist() # Step 1: mark src region for removal remove = set(range(src_start, src_end + 1)) cleaned = [(i, v) for i, v in enumerate(flat_list) if i not in remove] # Step 2: find ref in cleaned sequence and replace new_seq = [] for _, v in cleaned: if v == EXTRACTED_REF: new_seq.extend(content_tokens) else: new_seq.append(v) if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def split_region( palette: torch.Tensor, region_id: int, split_at: int, metadata: RegionMetadata ) -> torch.Tensor: """Split region into two at position split_at. Inserts END + START between positions split_at-1 and split_at.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if split_at <= 0 or split_at >= len(positions): raise InvalidPointerError(f"split_at={split_at} must be in (0, {len(positions)})") abs_split = positions[split_at] flat = palette.flatten() before = flat[:abs_split].tolist() after = flat[abs_split:].tolist() # Insert END then START to create two regions new_seq = before + [PaletteEditOps.END_OF_SCOPE, PaletteEditOps.START_OF_SCOPE] + after if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def merge_regions( palette: torch.Tensor, region_a: int, region_b: int, metadata: RegionMetadata ) -> torch.Tensor: """Merge two adjacent regions by removing the END of A and START of B. Regions must be adjacent (A's END directly before B's START).""" if region_b < 0: raise CrossRegionError("target_region_id required for MERGE") H, W = palette.shape flat = palette.flatten() end_a = metadata.ends[region_a] start_b = metadata.starts[region_b] # Verify adjacency if flat[end_a].item() != PaletteEditOps.END_OF_SCOPE: raise ScopeBalanceError("Region A end is not END_OF_SCOPE") if flat[start_b].item() != PaletteEditOps.START_OF_SCOPE: raise ScopeBalanceError("Region B start is not START_OF_SCOPE") # Remove both markers delete_mask = torch.zeros(H * W, dtype=torch.bool) delete_mask[end_a] = True delete_mask[start_b] = True kept = flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) @staticmethod def reorder( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Reverse the order of tokens [i_start,i_end] within region. Generalizable to arbitrary permutations via payload_tokens.""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") flat = palette.flatten() new_palette = palette.clone() new_flat = new_palette.flatten() # Reverse the range vals = [flat[positions[i]].item() for i in range(i_start, i_end + 1)] vals.reverse() for i, val in enumerate(vals): pos = positions[i_start + i] new_flat[pos] = val return new_flat.view(H, W) @staticmethod def nest_in_block( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, block_type_hue: int, metadata: RegionMetadata ) -> torch.Tensor: """Wrap [i_start,i_end] in a new control block (if/for/try/function). Inserts: START + block_type_hue + [content] + END. The block_type_hue identifies the construct type (20=for, 24=if, etc.).""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") abs_start = positions[i_start] abs_end = positions[i_end] flat = palette.flatten() before = flat[:abs_start].tolist() content = flat[abs_start:abs_end + 1].tolist() after = flat[abs_end + 1:].tolist() new_seq = (before + [PaletteEditOps.START_OF_SCOPE, block_type_hue] + content + [PaletteEditOps.END_OF_SCOPE] + after) if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def unnest_from_block( palette: torch.Tensor, region_id: int, metadata: RegionMetadata ) -> torch.Tensor: """Remove control block scope: remove START, block_type token, and matching END. Content is preserved and lifted to parent scope.""" H, W = palette.shape flat = palette.flatten() start_pos = metadata.starts[region_id] end_pos = metadata.ends[region_id] if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE: raise ScopeBalanceError("Region start is not START_OF_SCOPE") if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE: raise ScopeBalanceError("Region end is not END_OF_SCOPE") # Remove START, the token immediately after START (block type), and END delete_mask = torch.zeros(H * W, dtype=torch.bool) delete_mask[start_pos] = True if start_pos + 1 < H * W: delete_mask[start_pos + 1] = True # block type hue delete_mask[end_pos] = True kept = flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) @staticmethod def hoist( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Hoist: move tokens [i_start,i_end] from current region to before region's START. Declaration moves to higher scope.""" H, W = palette.shape flat = palette.flatten() positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") # Extract tokens hoisted = [flat[positions[i]].item() for i in range(i_start, i_end + 1)] for t in hoisted: if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot hoist scope markers") # Remove from current positions abs_positions = [positions[i] for i in range(i_start, i_end + 1)] delete_mask = torch.zeros(H * W, dtype=torch.bool) for pos in abs_positions: delete_mask[pos] = True cleaned = flat[~delete_mask].tolist() # Insert before region's START marker region_start = metadata.starts[region_id] # Adjust for deletions before region_start adj = sum(1 for p in abs_positions if p < region_start) insert_at = region_start - adj new_seq = cleaned[:insert_at] + hoisted + cleaned[insert_at:] if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def sink( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, target_region: int, metadata: RegionMetadata ) -> torch.Tensor: """Sink: move tokens from current region into a deeper (child) region. Opposite of hoist. Tokens move from parent scope to target child scope.""" if target_region < 0: raise CrossRegionError("target_region_id required for SINK") # This is a move from region_id to target_region return PaletteEditOps.move_range( palette, region_id, i_start, i_end, target_region, 0, metadata) # ================================================================== # # LEVEL 3 — Semantic (meaning-level transforms) # # ================================================================== # @staticmethod def rename( palette: torch.Tensor, region_id: int, i_start: int, new_hue: int, metadata: RegionMetadata ) -> torch.Tensor: """Rename: replace identifier hue at i_start with new_hue. Same as REPLACE but semantically constrained to identifiers.""" return PaletteEditOps.replace_token(palette, region_id, i_start, new_hue, metadata) @staticmethod def retype( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, new_type_tokens: List[int], metadata: RegionMetadata ) -> torch.Tensor: """Retype: replace type annotation range [i_start,i_end] with new tokens. Handles type annotations that may change length (int → List[int]).""" H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) if i_start < 0 or i_end >= len(positions) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") abs_start = positions[i_start] abs_end = positions[i_end] flat = palette.flatten() before = flat[:abs_start].tolist() after = flat[abs_end + 1:].tolist() new_seq = before + list(new_type_tokens) + after if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def convert_construct( palette: torch.Tensor, region_id: int, pattern_target: List[int], metadata: RegionMetadata ) -> torch.Tensor: """Convert construct: pattern match and replace within region. pattern_target = [*pattern_tokens, -1, *target_tokens] where -1 is separator. If empty, falls back to built-in MACRO_PATTERNS.""" if not pattern_target: # Use first built-in pattern macro = list(PaletteEditOps.MACRO_PATTERNS.values())[0] pattern = macro['pattern'] target = macro['target'] else: if -1 not in pattern_target: raise PatternNotFoundError("pattern_target must contain -1 separator") sep = pattern_target.index(-1) pattern = pattern_target[:sep] target = pattern_target[sep + 1:] positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) flat = palette.flatten() region_tokens = [flat[pos].item() for pos in positions] plen = len(pattern) found = False new_palette = palette.clone() for i in range(len(region_tokens) - plen + 1): if region_tokens[i:i + plen] == pattern: # Replace with target (may differ in length) if len(target) == plen: # Same length: direct replacement for j, t in enumerate(target): pos = positions[i + j] h, w = pos // palette.shape[1], pos % palette.shape[1] new_palette[h, w] = t else: # Different length: rebuild sequence abs_start = positions[i] abs_end = positions[i + plen - 1] H, W = palette.shape flat_list = flat.tolist() new_seq = flat_list[:abs_start] + target + flat_list[abs_end + 1:] if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) new_palette = torch.tensor(new_seq, dtype=palette.dtype).view(H, W) found = True break if not found: raise PatternNotFoundError(f"Pattern {pattern} not found in region") return new_palette @staticmethod def sync_to_async( palette: torch.Tensor, region_id: int, metadata: RegionMetadata ) -> torch.Tensor: """Add async/await markers to region. Inserts async hue (hue 46) before region's first function-def token (hue 12), and await hue (hue 47) before call tokens (hue 60).""" ASYNC_HUE = 46 AWAIT_HUE = 47 FUNC_DEF_HUE = 12 CALL_HUE = 60 H, W = palette.shape positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) flat = palette.flatten() insertions = [] # (abs_pos, hue_to_insert) for pos in positions: val = flat[pos].item() if val == FUNC_DEF_HUE: insertions.append((pos, ASYNC_HUE)) elif val == CALL_HUE: insertions.append((pos, AWAIT_HUE)) if not insertions: raise PatternNotFoundError("No function defs or calls found to make async") # Build new sequence with insertions seq = flat.tolist() offset = 0 for abs_pos, hue in sorted(insertions): seq.insert(abs_pos + offset, hue) offset += 1 if len(seq) > H * W: seq = seq[:H * W] else: seq.extend([PaletteEditOps.NOOP] * (H * W - len(seq))) return torch.tensor(seq, dtype=palette.dtype).view(H, W) @staticmethod def parameterize( palette: torch.Tensor, region_id: int, i_start: int, param_hue: int, metadata: RegionMetadata ) -> torch.Tensor: """Replace a hardcoded literal at i_start with a parameter reference (param_hue). The literal hue becomes a variable/parameter hue.""" return PaletteEditOps.replace_token(palette, region_id, i_start, param_hue, metadata) @staticmethod def specialize( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, concrete_tokens: List[int], metadata: RegionMetadata ) -> torch.Tensor: """Replace generic type tokens [i_start,i_end] with concrete specialization. Opposite of parameterize for types: List[T] → List[int].""" return PaletteEditOps.retype(palette, region_id, i_start, i_end, concrete_tokens, metadata) @staticmethod def guard( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, guard_hue: int, metadata: RegionMetadata ) -> torch.Tensor: """Wrap [i_start,i_end] in a conditional guard (if/try/etc). Like nest_in_block with a guard-specific hue.""" return PaletteEditOps.nest_in_block( palette, region_id, i_start, i_end, guard_hue, metadata) @staticmethod def unguard( palette: torch.Tensor, region_id: int, metadata: RegionMetadata ) -> torch.Tensor: """Remove conditional guard from region. Content lifted to parent scope. Like unnest_from_block.""" return PaletteEditOps.unnest_from_block(palette, region_id, metadata) @staticmethod def scatter( palette: torch.Tensor, region_id: int, new_hue: int, target_positions: List[int], metadata: RegionMetadata ) -> torch.Tensor: """Replace token at multiple positions with new_hue. Same change applied to N locations (rename-all, update-all-call-sites).""" positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) new_palette = palette.clone() H, W = palette.shape for pos_idx in target_positions: if pos_idx < 0 or pos_idx >= len(positions): raise InvalidPointerError(f"Position {pos_idx} out of bounds") abs_pos = positions[pos_idx] h, w = abs_pos // W, abs_pos % W val = new_palette[h, w].item() if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot scatter over scope markers") new_palette[h, w] = new_hue return new_palette @staticmethod def gather( palette: torch.Tensor, source_positions: List[int], target_region: int, target_pos: int, metadata: RegionMetadata ) -> torch.Tensor: """Gather: collect tokens from multiple positions into a single location. Tokens at source_positions are removed and concatenated at target_pos in target_region. Opposite of scatter.""" H, W = palette.shape flat = palette.flatten() # Collect values from source positions (these are region-relative in first region) # source_positions are absolute flat indices for simplicity gathered_vals = [] for pos in source_positions: if pos < 0 or pos >= H * W: raise InvalidPointerError(f"Source position {pos} out of bounds") val = flat[pos].item() if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot gather scope markers") gathered_vals.append(val) # Remove source positions delete_mask = torch.zeros(H * W, dtype=torch.bool) for pos in source_positions: delete_mask[pos] = True cleaned = flat[~delete_mask].tolist() # Insert gathered values at target position in target region target_positions_list = PaletteEditOps._get_content_positions(palette, metadata, target_region) if target_pos < 0 or target_pos > len(target_positions_list): raise InvalidPointerError(f"target_pos={target_pos} out of bounds") # Adjust target pos for deletions before it abs_target = target_positions_list[target_pos] if target_pos < len(target_positions_list) else ( target_positions_list[-1] + 1 if target_positions_list else 0) adj = sum(1 for p in source_positions if p < abs_target) abs_target -= adj new_seq = cleaned[:abs_target] + gathered_vals + cleaned[abs_target:] if len(new_seq) > H * W: new_seq = new_seq[:H * W] else: new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) @staticmethod def mirror( palette: torch.Tensor, region_a: int, region_b: int, i_start: int, i_end: int, new_hue: int, metadata: RegionMetadata ) -> torch.Tensor: """Apply symmetric change to paired regions A and B. Replace tokens at [i_start,i_end] in BOTH regions with new_hue. For getter/setter pairs, request/response symmetry, etc.""" if region_b < 0: raise CrossRegionError("target_region_id required for MIRROR") new_palette = palette.clone() H, W = palette.shape for rid in [region_a, region_b]: positions = PaletteEditOps._get_content_positions(new_palette, metadata, rid) if i_start < 0 or i_end >= len(positions): raise InvalidPointerError(f"Range [{i_start},{i_end}] out of bounds in region {rid}") for i in range(i_start, i_end + 1): abs_pos = positions[i] h, w = abs_pos // W, abs_pos % W val = new_palette[h, w].item() if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): raise ScopeBalanceError("Cannot mirror over scope markers") new_palette[h, w] = new_hue return new_palette @staticmethod def compose( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """Compose: fuse sequential statements [i_start,i_end] into a single expression. Removes intermediate scope boundaries within the range. Tokens are kept, internal START/END pairs are removed.""" H, W = palette.shape flat = palette.flatten() # Get ALL positions in region (including scope markers for this op) mask = metadata.masks[region_id] all_positions = mask.nonzero(as_tuple=False) all_flat = sorted((all_positions[:, 0] * W + all_positions[:, 1]).tolist()) if i_start < 0 or i_end >= len(all_flat) or i_end < i_start: raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") # Within range, remove internal START/END pairs (not outermost) range_positions = all_flat[i_start:i_end + 1] delete_mask = torch.zeros(H * W, dtype=torch.bool) # Find internal scope markers (not the first START or last END) depth = 0 for pos in range_positions: val = flat[pos].item() if val == PaletteEditOps.START_OF_SCOPE: depth += 1 if depth > 1: # Internal delete_mask[pos] = True elif val == PaletteEditOps.END_OF_SCOPE: if depth > 1: # Internal delete_mask[pos] = True depth -= 1 kept = flat[~delete_mask] pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) return torch.cat([kept, pad]).view(H, W) # ================================================================== # # Helpers # # ================================================================== # @staticmethod def verify_scope_balance(palette: torch.Tensor) -> bool: """Check START_OF_SCOPE count == END_OF_SCOPE count.""" num_starts = (palette == PaletteEditOps.START_OF_SCOPE).sum().item() num_ends = (palette == PaletteEditOps.END_OF_SCOPE).sum().item() return num_starts == num_ends @staticmethod def _get_content_positions( palette: torch.Tensor, metadata: RegionMetadata, region_id: int ) -> List[int]: """Get flattened positions of content tokens in region (excluding scope markers).""" H, W = palette.shape mask = metadata.masks[region_id] positions = mask.nonzero(as_tuple=False) flat_positions = (positions[:, 0] * W + positions[:, 1]).tolist() filtered = [] for pos in flat_positions: h, w = pos // W, pos % W token = palette[h, w].item() if token not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): filtered.append(pos) return sorted(filtered) @staticmethod def _get_region_positions(mask: torch.Tensor, W: int, palette: torch.Tensor = None) -> List[int]: """Legacy helper — use _get_content_positions instead.""" positions = mask.nonzero(as_tuple=False) flat_positions = (positions[:, 0] * W + positions[:, 1]).tolist() if palette is not None: filtered = [] for pos in flat_positions: h, w = pos // W, pos % W token = palette[h, w].item() if token not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): filtered.append(pos) return sorted(filtered) return sorted(flat_positions)