| """ |
| TRM Solver β NN Executor for ARC-AGI NeuroGolf |
| |
| Takes a parsed transform (from Kilo/DeepSeek) and executes it as a |
| tiny neural network. Each transform is implemented as a minimal NN |
| that can be exported to ONNX. |
| |
| Architecture: |
| - Each transform is a PyTorch nn.Module with frozen weights |
| - Weights encode the transform parameters (not learned β set directly) |
| - ONNX export produces a tiny model per task |
| """ |
|
|
import ast
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
| |
|
|
@dataclass
class TransformSpec:
    """Parsed transform description emitted by Kilo/DeepSeek.

    Attributes:
        name: Registry key selecting the transform implementation.
        params: Transform-specific parameters (e.g. "direction", "scale").
        objects: Optional per-object metadata; normalized to [] when omitted.
    """
    name: str
    params: Dict
    # Annotated Optional because the default really is None; a fresh list is
    # substituted in __post_init__ to avoid the shared-mutable-default pitfall.
    objects: Optional[List[Dict]] = None

    def __post_init__(self):
        if self.objects is None:
            self.objects = []
|
|
|
|
| |
|
|
class BaseTransformNN(nn.Module):
    """Common scaffolding for all transform networks.

    Subclasses implement the actual grid transformation in
    ``_forward_impl``; this base class only normalizes the input rank
    and provides a parameter-count helper.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__()
        self.spec = spec
        self.max_size = 30  # ARC grids never exceed 30x30

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the grid to [B, 1, H, W] (values 0-9) and dispatch.

        Accepts [H, W], [1, H, W] or [B, 1, H, W] input; missing leading
        dimensions are prepended before calling the subclass hook.
        """
        rank = x.dim()
        if rank == 3:
            x = x[None]
        elif rank == 2:
            x = x[None, None]
        return self._forward_impl(x)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        # Subclass responsibility.
        raise NotImplementedError

    def count_params(self) -> int:
        """Total element count across all registered parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
| |
|
|
class IdentityNN(BaseTransformNN):
    """No-op transform: the (rank-normalized) input is returned as-is."""

    def _forward_impl(self, x):
        # Nothing to compute — the grid is already the answer.
        return x
|
|
|
|
| |
|
|
class ColorMapNN(BaseTransformNN):
    """Per-pixel color remapping via a frozen 1x1 convolution. 100 params."""

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        mapping = spec.params.get("color_map", list(range(10)))
        self.lut = nn.Conv2d(10, 10, kernel_size=1, bias=False)
        # Routing weight: input channel `src` feeds output channel `mapping[src]`.
        w = torch.zeros(10, 10, 1, 1)
        for src, dst in enumerate(mapping):
            w[dst, src, 0, 0] = 1.0
        self.lut.weight = nn.Parameter(w, requires_grad=False)

    def _forward_impl(self, x):
        idx = x.long().squeeze(1).clamp(0, 9)
        # One-hot over colors -> channels, remap, decode with argmax.
        channels = F.one_hot(idx, num_classes=10).permute(0, 3, 1, 2).float()
        mapped = self.lut(channels)
        return mapped.argmax(dim=1, keepdim=True).float()
|
|
|
|
| |
|
|
class FlipNN(BaseTransformNN):
    """Mirror the grid horizontally (default) or vertically."""

    def _forward_impl(self, x):
        direction = self.spec.params.get("direction", "horizontal")
        if direction == "horizontal":
            flip_dim = 3  # mirror along the width axis
        else:
            flip_dim = 2  # anything else mirrors along the height axis
        return torch.flip(x, [flip_dim])
|
|
|
|
class TransposeNN(BaseTransformNN):
    """Swap the grid's height and width axes (matrix transpose)."""

    def _forward_impl(self, x):
        # permute(0, 1, 3, 2) is equivalent to transpose(2, 3).
        return x.permute(0, 1, 3, 2)
|
|
|
|
class RotateNN(BaseTransformNN):
    """Rotate the grid by k quarter turns (counter-clockwise)."""

    def _forward_impl(self, x):
        quarter_turns = self.spec.params.get("k", 1)
        return torch.rot90(x, quarter_turns, dims=[2, 3])
|
|
|
|
| |
|
|
class UpscaleNN(BaseTransformNN):
    """Nearest-neighbor upscaling followed by a crop to the target shape.

    Parameters (from spec.params):
        scale: Integer magnification factor (default 2).
        output_shape: [H, W] crop applied after upscaling (default [30, 30]).
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        self.scale = spec.params.get("scale", 2)
        # Single lookup instead of two redundant params.get() calls.
        self.target_h, self.target_w = spec.params.get("output_shape", [30, 30])[:2]

    def _forward_impl(self, x):
        # repeat_interleave on H then W == nearest-neighbor upsample by `scale`.
        x = x.repeat_interleave(self.scale, dim=2).repeat_interleave(self.scale, dim=3)
        # Crop in case the upscaled grid overshoots the requested output size.
        return x[:, :, :self.target_h, :self.target_w]
|
|
|
|
| |
|
|
class KronSelfSimilarNN(BaseTransformNN):
    """Fractal self-similarity: output = kron(input != 0, input). 0 params.

    Every nonzero cell is replaced by a copy of the whole grid; zero
    cells become all-zero blocks. An HxW input yields an H^2 x W^2 output.
    """

    def _forward_impl(self, x):
        B, _, H, W = x.shape
        occupancy = (x != 0).float()
        # blocks[b, c, i, k, j, l] = occupancy[b, c, i, j] * x[b, c, k, l]
        # so output[i*H + k, j*W + l] is the Kronecker product layout.
        blocks = torch.einsum("bcij,bckl->bcikjl", occupancy, x.float())
        return blocks.reshape(B, 1, H * H, W * W)
|
|
|
|
class TileRepeatNN(BaseTransformNN):
    """Tile the grid in an h_repeat x w_repeat arrangement."""

    def _forward_impl(self, x):
        rows = self.spec.params.get("h_repeat", 2)
        cols = self.spec.params.get("w_repeat", 2)
        reps = (1, 1, rows, cols)  # batch/channel untouched
        return x.repeat(*reps)
|
|
|
|
| |
|
|
class ConcatPatternsNN(BaseTransformNN):
    """Concatenate transformed copies of the grid along one axis.

    spec.params["operations"] names a per-copy transform; anything
    unrecognized (including "identity") leaves that copy unchanged.
    """

    def _forward_impl(self, x):
        axis = self.spec.params.get("axis", "horizontal")
        ops = self.spec.params.get("operations", ["identity", "identity"])
        # Dispatch table instead of an if/elif chain.
        transforms = {
            "flip_h": lambda g: torch.flip(g, [3]),
            "flip_v": lambda g: torch.flip(g, [2]),
            "transpose": lambda g: g.transpose(2, 3),
            "rot90": lambda g: torch.rot90(g, 1, [2, 3]),
            "rot180": lambda g: torch.rot90(g, 2, [2, 3]),
            "rot270": lambda g: torch.rot90(g, 3, [2, 3]),
        }
        pieces = [transforms.get(op, lambda g: g)(x) for op in ops]
        return torch.cat(pieces, dim=3 if axis == "horizontal" else 2)
|
|
|
|
| |
|
|
class PositionColorLUTNN(BaseTransformNN):
    """Per-position color lookup gated by input occupancy. H*W params."""

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        table = spec.params.get("lut", {})
        out_shape = spec.params.get("output_shape", [30, 30])
        self.h_o, self.w_o = out_shape[0], out_shape[1]
        # Build the lookup grid as a plain tensor, then freeze it.
        grid = torch.zeros(1, 1, self.h_o, self.w_o)
        for key, color in table.items():
            row, col = map(int, key.split(","))
            if row < self.h_o and col < self.w_o:
                grid[0, 0, row, col] = float(color)
        self.lut = nn.Parameter(grid, requires_grad=False)

    def _forward_impl(self, x):
        batch = x.shape[0]
        colors = self.lut.expand(batch, -1, -1, -1)
        # Emit the stored color only where the input cell is nonzero.
        occupied = (x[:, :, :self.h_o, :self.w_o] != 0).float()
        return occupied * colors
|
|
|
|
| |
|
|
class SpatialGatherNN(BaseTransformNN):
    """Rearrange pixels by a per-output-cell source lookup. H*W*2 params."""

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        gather_map = spec.params.get("gather_map", {})
        out_shape = spec.params.get("output_shape", [30, 30])
        self.h_o, self.w_o = out_shape[0], out_shape[1]
        # rows[h, w] / cols[h, w] hold the source coordinate for output (h, w);
        # unmapped cells default to source (0, 0).
        rows = torch.zeros(self.h_o, self.w_o, dtype=torch.long)
        cols = torch.zeros(self.h_o, self.w_o, dtype=torch.long)
        for dst, src in gather_map.items():
            dh, dw = map(int, dst.split(","))
            sh, sw = map(int, src.split(","))
            if dh < self.h_o and dw < self.w_o:
                rows[dh, dw] = sh
                cols[dh, dw] = sw
        # Kept as Parameters (not buffers) so state_dict/count_params match.
        self.gh = nn.Parameter(rows, requires_grad=False)
        self.gw = nn.Parameter(cols, requires_grad=False)

    def _forward_impl(self, x):
        in_h, in_w = x.shape[2], x.shape[3]
        # Clamp so out-of-range sources read the nearest valid border cell.
        src_rows = self.gh.clamp(0, in_h - 1)
        src_cols = self.gw.clamp(0, in_w - 1)
        return x[:, :, src_rows, src_cols]
|
|
|
|
| |
|
|
class OneHotConvNN(BaseTransformNN):
    """One-hot encode, apply a 10->10 'same' convolution, argmax decode.

    Weight count is kernel_h * kernel_w * 100.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        kh = spec.params.get("kernel_h", 3)
        kw = spec.params.get("kernel_w", 3)
        self.conv = nn.Conv2d(10, 10, kernel_size=(kh, kw), padding='same', bias=False)
        if "weights" in spec.params:
            # Externally supplied flat weights encode the task-specific rule.
            flat = torch.tensor(spec.params["weights"], dtype=torch.float32)
            self.conv.weight = nn.Parameter(flat.view(10, 10, kh, kw), requires_grad=False)

    def _forward_impl(self, x):
        idx = x.long().squeeze(1).clamp(0, 9)
        channels = F.one_hot(idx, 10).permute(0, 3, 1, 2).float()
        scores = self.conv(channels)
        return scores.argmax(dim=1, keepdim=True).float()
|
|
|
|
class OneHotLinearNN(BaseTransformNN):
    """One-hot encode, per-pixel linear map over colors, argmax. 100 params."""

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        self.linear = nn.Linear(10, 10, bias=False)
        if "weights" in spec.params:
            w = torch.tensor(spec.params["weights"], dtype=torch.float32)
            self.linear.weight = nn.Parameter(w, requires_grad=False)

    def _forward_impl(self, x):
        # [B, H, W, 10]: one-hot colors in the trailing dim feed nn.Linear.
        channels = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).float()
        scores = self.linear(channels)
        return scores.argmax(dim=-1).unsqueeze(1).float()
|
|
|
|
| |
|
|
# Maps transform names to their NN classes. create_transform_nn() resolves
# through this table, so keys must match the `name:` field parsed from
# Kilo/DeepSeek output (see parse_kilo_output).
TRANSFORM_REGISTRY = {
    "identity": IdentityNN,
    "color_map": ColorMapNN,
    "flip": FlipNN,
    "transpose": TransposeNN,
    "rotate": RotateNN,
    "upscale": UpscaleNN,
    "kron_self_similar": KronSelfSimilarNN,
    "tile_repeat": TileRepeatNN,
    "concat_patterns": ConcatPatternsNN,
    "pos_color_lut": PositionColorLUTNN,
    "spatial_gather": SpatialGatherNN,
    "onehot_conv": OneHotConvNN,
    "onehot_linear": OneHotLinearNN,
}
|
|
|
|
def create_transform_nn(spec: TransformSpec) -> BaseTransformNN:
    """Instantiate the NN class registered under spec.name.

    Args:
        spec: Parsed transform specification.

    Returns:
        A freshly constructed transform network.

    Raises:
        ValueError: If spec.name is not registered; the message lists the
            valid names to make parser/LLM output easier to debug.
    """
    cls = TRANSFORM_REGISTRY.get(spec.name)
    if cls is None:
        known = ", ".join(sorted(TRANSFORM_REGISTRY))
        raise ValueError(f"Unknown transform: {spec.name!r} (known: {known})")
    return cls(spec)
|
|
|
|
| def parse_kilo_output(md: str) -> TransformSpec: |
| """Parse Kilo markdown into TransformSpec.""" |
| lines = md.strip().split('\n') |
| name, params, section = None, {}, None |
| for line in lines: |
| line = line.strip() |
| if line.startswith('## '): |
| section = line[3:].strip().lower() |
| continue |
| if section == 'transform' and line.startswith('name:'): |
| name = line.split(':', 1)[1].strip() |
| elif section == 'parameters' and line.startswith('- '): |
| kv = line[2:].split(':', 1) |
| if len(kv) == 2: |
| k, v = kv[0].strip(), kv[1].strip() |
| try: |
| import ast |
| params[k] = ast.literal_eval(v) |
| except (ValueError, SyntaxError): |
| params[k] = v |
| if not name: |
| raise ValueError("No transform name in Kilo output") |
| return TransformSpec(name=name, params=params) |
|
|
|
|
| |
|
|
def export_to_onnx(model: BaseTransformNN, input_shape: Tuple[int, int],
                   output_path: str, opset: int = 17):
    """Export a transform NN to ONNX with a dynamic batch dimension.

    Args:
        model: The transform network to export (frozen weights).
        input_shape: (H, W) of the dummy grid used for tracing.
        output_path: Destination .onnx file path.
        opset: ONNX opset version to target (default 17).

    Returns:
        output_path, for convenient chaining.
    """
    model.eval()
    H, W = input_shape
    # Tracing input: single batch, single channel, all zeros.
    dummy = torch.zeros(1, 1, H, W)
    torch.onnx.export(model, dummy, output_path,
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
                      opset_version=opset, do_constant_folding=True)
    # `os` is imported at module level (hoisted out of the function body).
    size_kb = os.path.getsize(output_path) / 1024
    n_params = model.count_params()
    print(f"Exported {output_path}: {size_kb:.1f} KB, {n_params} params")
    return output_path