""" TRM Solver — NN Executor for ARC-AGI NeuroGolf Takes a parsed transform (from Kilo/DeepSeek) and executes it as a tiny neural network. Each transform is implemented as a minimal NN that can be exported to ONNX. Architecture: - Each transform is a PyTorch nn.Module with frozen weights - Weights encode the transform parameters (not learned — set directly) - ONNX export produces a tiny model per task """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Dict, List, Tuple, Optional from dataclasses import dataclass import json # ─── Data Structures ─────────────────────────────────────────── @dataclass class TransformSpec: """Parsed output from Kilo/DeepSeek.""" name: str params: Dict objects: List[Dict] = None def __post_init__(self): if self.objects is None: self.objects = [] # ─── Base NN Transform ───────────────────────────────────────── class BaseTransformNN(nn.Module): """Base class for all transform NNs. Subclasses implement _forward_impl.""" def __init__(self, spec: TransformSpec): super().__init__() self.spec = spec self.max_size = 30 # ARC max grid size def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Input grid [B, 1, H, W] or [1, H, W] — values 0-9 Returns: Output grid [B, 1, H_out, W_out] — values 0-9 """ if x.dim() == 3: x = x.unsqueeze(0) if x.dim() == 2: x = x.unsqueeze(0).unsqueeze(0) return self._forward_impl(x) def _forward_impl(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError def count_params(self) -> int: return sum(p.numel() for p in self.parameters()) # ─── Identity ────────────────────────────────────────────────── class IdentityNN(BaseTransformNN): """Output equals input.""" def _forward_impl(self, x): return x # ─── Color Map ───────────────────────────────────────────────── class ColorMapNN(BaseTransformNN): """Per-pixel color remapping. 100 params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) lut = spec.params.get("color_map", list(range(10))) self.lut = nn.Conv2d(10, 10, kernel_size=1, bias=False) weight = torch.zeros(10, 10, 1, 1) for i, j in enumerate(lut): weight[j, i, 0, 0] = 1.0 self.lut.weight = nn.Parameter(weight, requires_grad=False) def _forward_impl(self, x): B, _, H, W = x.shape x_flat = x.long().squeeze(1).clamp(0, 9) onehot = F.one_hot(x_flat, num_classes=10).permute(0, 3, 1, 2).float() out = self.lut(onehot) return out.argmax(dim=1, keepdim=True).float() # ─── Geometric ───────────────────────────────────────────────── class FlipNN(BaseTransformNN): def _forward_impl(self, x): direction = self.spec.params.get("direction", "horizontal") dim = 3 if direction == "horizontal" else 2 return torch.flip(x, [dim]) class TransposeNN(BaseTransformNN): def _forward_impl(self, x): return x.transpose(2, 3) class RotateNN(BaseTransformNN): def _forward_impl(self, x): k = self.spec.params.get("k", 1) return torch.rot90(x, k, [2, 3]) # ─── Upscale ─────────────────────────────────────────────────── class UpscaleNN(BaseTransformNN): """Nearest-neighbor upscaling. ~scale**2 params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) self.scale = spec.params.get("scale", 2) th = spec.params.get("output_shape", [30, 30])[0] tw = spec.params.get("output_shape", [30, 30])[1] self.target_h = th self.target_w = tw def _forward_impl(self, x): x = x.repeat_interleave(self.scale, dim=2).repeat_interleave(self.scale, dim=3) return x[:, :, :self.target_h, :self.target_w] # ─── Kronecker Self-Similar ──────────────────────────────────── class KronSelfSimilarNN(BaseTransformNN): """output = kron((input != 0), input). 0 learnable params.""" def _forward_impl(self, x): mask = (x != 0).float() B, _, H_in, W_in = x.shape inp_e = x.unsqueeze(2).unsqueeze(2) mask_e = mask.unsqueeze(4).unsqueeze(4) result = (mask_e * inp_e).float() result = result.permute(0, 1, 2, 4, 3, 5).contiguous() H_out, W_out = H_in * H_in, W_in * W_in return result.view(B, 1, H_out, W_out) class TileRepeatNN(BaseTransformNN): def _forward_impl(self, x): hr = self.spec.params.get("h_repeat", 2) wr = self.spec.params.get("w_repeat", 2) return x.repeat(1, 1, hr, wr) # ─── Concat Patterns ─────────────────────────────────────────── class ConcatPatternsNN(BaseTransformNN): """Concatenate transformed copies horizontally/vertically.""" def _forward_impl(self, x): axis = self.spec.params.get("axis", "horizontal") ops = self.spec.params.get("operations", ["identity", "identity"]) pieces = [] for op in ops: if op == "flip_h": pieces.append(torch.flip(x, [3])) elif op == "flip_v": pieces.append(torch.flip(x, [2])) elif op == "transpose": pieces.append(x.transpose(2, 3)) elif op == "rot90": pieces.append(torch.rot90(x, 1, [2, 3])) elif op == "rot180": pieces.append(torch.rot90(x, 2, [2, 3])) elif op == "rot270": pieces.append(torch.rot90(x, 3, [2, 3])) else: pieces.append(x) dim = 3 if axis == "horizontal" else 2 return torch.cat(pieces, dim=dim) # ─── Position Color LUT ──────────────────────────────────────── class PositionColorLUTNN(BaseTransformNN): """Per-position color lookup. H*W params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) lut = spec.params.get("lut", {}) self.h_o = spec.params.get("output_shape", [30, 30])[0] self.w_o = spec.params.get("output_shape", [30, 30])[1] self.lut = nn.Parameter(torch.zeros(1, 1, self.h_o, self.w_o), requires_grad=False) with torch.no_grad(): for k, v in lut.items(): h, w = map(int, k.split(",")) if h < self.h_o and w < self.w_o: self.lut[0, 0, h, w] = float(v) def _forward_impl(self, x): B = x.shape[0] out = self.lut.expand(B, -1, -1, -1) mask = (x[:, :, :self.h_o, :self.w_o] != 0).float() return mask * out # ─── Spatial Gather ──────────────────────────────────────────── class SpatialGatherNN(BaseTransformNN): """Rearrange pixels via gather map. H*W*2 params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) gmap = spec.params.get("gather_map", {}) self.h_o = spec.params.get("output_shape", [30, 30])[0] self.w_o = spec.params.get("output_shape", [30, 30])[1] self.gh = nn.Parameter(torch.zeros(self.h_o, self.w_o, dtype=torch.long), requires_grad=False) self.gw = nn.Parameter(torch.zeros(self.h_o, self.w_o, dtype=torch.long), requires_grad=False) with torch.no_grad(): for k, v in gmap.items(): h, w = map(int, k.split(",")) sh, sw = map(int, v.split(",")) if h < self.h_o and w < self.w_o: self.gh[h, w] = sh self.gw[h, w] = sw def _forward_impl(self, x): B, C, Hi, Wi = x.shape gh = self.gh.clamp(0, Hi - 1) gw = self.gw.clamp(0, Wi - 1) return x[:, :, gh, gw] # ─── One-Hot Convolution ─────────────────────────────────────── class OneHotConvNN(BaseTransformNN): """One-hot encode, convolve, argmax decode. K^2*100 params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) kh = spec.params.get("kernel_h", 3) kw = spec.params.get("kernel_w", 3) self.conv = nn.Conv2d(10, 10, kernel_size=(kh, kw), padding='same', bias=False) if "weights" in spec.params: w = torch.tensor(spec.params["weights"], dtype=torch.float32) self.conv.weight = nn.Parameter(w.view(10, 10, kh, kw), requires_grad=False) def _forward_impl(self, x): B, _, H, W = x.shape onehot = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).permute(0, 3, 1, 2).float() return self.conv(onehot).argmax(dim=1, keepdim=True).float() class OneHotLinearNN(BaseTransformNN): """One-hot encode, linear, argmax. 100 params.""" def __init__(self, spec: TransformSpec): super().__init__(spec) self.linear = nn.Linear(10, 10, bias=False) if "weights" in spec.params: self.linear.weight = nn.Parameter( torch.tensor(spec.params["weights"], dtype=torch.float32), requires_grad=False) def _forward_impl(self, x): onehot = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).float() return self.linear(onehot).argmax(dim=-1).unsqueeze(1).float() # ─── Factory & Parser ────────────────────────────────────────── TRANSFORM_REGISTRY = { "identity": IdentityNN, "color_map": ColorMapNN, "flip": FlipNN, "transpose": TransposeNN, "rotate": RotateNN, "upscale": UpscaleNN, "kron_self_similar": KronSelfSimilarNN, "tile_repeat": TileRepeatNN, "concat_patterns": ConcatPatternsNN, "pos_color_lut": PositionColorLUTNN, "spatial_gather": SpatialGatherNN, "onehot_conv": OneHotConvNN, "onehot_linear": OneHotLinearNN, } def create_transform_nn(spec: TransformSpec) -> BaseTransformNN: cls = TRANSFORM_REGISTRY.get(spec.name) if cls is None: raise ValueError(f"Unknown transform: {spec.name}") return cls(spec) def parse_kilo_output(md: str) -> TransformSpec: """Parse Kilo markdown into TransformSpec.""" lines = md.strip().split('\n') name, params, section = None, {}, None for line in lines: line = line.strip() if line.startswith('## '): section = line[3:].strip().lower() continue if section == 'transform' and line.startswith('name:'): name = line.split(':', 1)[1].strip() elif section == 'parameters' and line.startswith('- '): kv = line[2:].split(':', 1) if len(kv) == 2: k, v = kv[0].strip(), kv[1].strip() try: import ast params[k] = ast.literal_eval(v) except (ValueError, SyntaxError): params[k] = v if not name: raise ValueError("No transform name in Kilo output") return TransformSpec(name=name, params=params) # ─── ONNX Export ─────────────────────────────────────────────── def export_to_onnx(model: BaseTransformNN, input_shape: Tuple[int, int], output_path: str, opset: int = 17): model.eval() H, W = input_shape dummy = torch.zeros(1, 1, H, W) torch.onnx.export(model, dummy, output_path, input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}, opset_version=opset, do_constant_folding=True) import os kb, p = os.path.getsize(output_path) / 1024, model.count_params() print(f"Exported {output_path}: {kb:.1f} KB, {p} params") return output_path