"""
TRM Solver - NN Executor for ARC-AGI NeuroGolf
Takes a parsed transform (from Kilo/DeepSeek) and executes it as a
tiny neural network. Each transform is implemented as a minimal NN
that can be exported to ONNX.
Architecture:
- Each transform is a PyTorch nn.Module with frozen weights
- Weights encode the transform parameters (not learned; set directly)
- ONNX export produces a tiny model per task
"""
import ast
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
# ─── Data Structures ───────────────────────────────────────────
@dataclass
class TransformSpec:
"""Parsed output from Kilo/DeepSeek."""
name: str
params: Dict
    objects: Optional[List[Dict]] = None
def __post_init__(self):
if self.objects is None:
self.objects = []
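# Example TransformSpec, as might come out of the parser below (illustrative
# values; the real Kilo output determines name/params):
#   TransformSpec(name="rotate", params={"k": 1})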
# ─── Base NN Transform ─────────────────────────────────────────
class BaseTransformNN(nn.Module):
"""Base class for all transform NNs. Subclasses implement _forward_impl."""
def __init__(self, spec: TransformSpec):
super().__init__()
self.spec = spec
self.max_size = 30 # ARC max grid size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input grid [B, 1, H, W] or [1, H, W] β€” values 0-9
Returns:
Output grid [B, 1, H_out, W_out] β€” values 0-9
"""
        # Lift [H, W] or [1, H, W] inputs to [B, 1, H, W]
        if x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 3:
            x = x.unsqueeze(0)
return self._forward_impl(x)
def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def count_params(self) -> int:
return sum(p.numel() for p in self.parameters())
# ─── Identity ──────────────────────────────────────────────────
class IdentityNN(BaseTransformNN):
"""Output equals input."""
def _forward_impl(self, x):
return x
# ─── Color Map ─────────────────────────────────────────────────
class ColorMapNN(BaseTransformNN):
"""Per-pixel color remapping. 100 params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
lut = spec.params.get("color_map", list(range(10)))
self.lut = nn.Conv2d(10, 10, kernel_size=1, bias=False)
weight = torch.zeros(10, 10, 1, 1)
for i, j in enumerate(lut):
weight[j, i, 0, 0] = 1.0
self.lut.weight = nn.Parameter(weight, requires_grad=False)
def _forward_impl(self, x):
x_flat = x.long().squeeze(1).clamp(0, 9)
onehot = F.one_hot(x_flat, num_classes=10).permute(0, 3, 1, 2).float()
out = self.lut(onehot)
return out.argmax(dim=1, keepdim=True).float()
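# Sketch: a spec like {"color_map": [0, 2, 1, 3, 4, 5, 6, 7, 8, 9]} swaps
# colors 1 and 2. The list index is the source color, the value the target
# color, encoded as a frozen 10x10 permutation in the 1x1 conv above.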
# ─── Geometric ─────────────────────────────────────────────────
class FlipNN(BaseTransformNN):
    """Mirror left-right (direction="horizontal") or top-bottom."""
def _forward_impl(self, x):
direction = self.spec.params.get("direction", "horizontal")
dim = 3 if direction == "horizontal" else 2
return torch.flip(x, [dim])
class TransposeNN(BaseTransformNN):
    """Swap rows and columns (reflect across the main diagonal)."""
def _forward_impl(self, x):
return x.transpose(2, 3)
class RotateNN(BaseTransformNN):
    """Rotate by k * 90 degrees (counterclockwise for positive k)."""
def _forward_impl(self, x):
k = self.spec.params.get("k", 1)
return torch.rot90(x, k, [2, 3])
# ─── Upscale ───────────────────────────────────────────────────
class UpscaleNN(BaseTransformNN):
"""Nearest-neighbor upscaling. ~scale**2 params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
self.scale = spec.params.get("scale", 2)
th = spec.params.get("output_shape", [30, 30])[0]
tw = spec.params.get("output_shape", [30, 30])[1]
self.target_h = th
self.target_w = tw
def _forward_impl(self, x):
x = x.repeat_interleave(self.scale, dim=2).repeat_interleave(self.scale, dim=3)
return x[:, :, :self.target_h, :self.target_w]
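# e.g. {"scale": 2, "output_shape": [6, 6]} turns a 3x3 grid into 6x6: each
# cell becomes a 2x2 block, and the result is cropped to output_shape.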
# ─── Kronecker Self-Similar ────────────────────────────────────
class KronSelfSimilarNN(BaseTransformNN):
"""output = kron((input != 0), input). 0 learnable params."""
    def _forward_impl(self, x):
        mask = (x != 0).float()
        B, _, H_in, W_in = x.shape
        inp_e = x.unsqueeze(2).unsqueeze(2)      # [B, 1, 1, 1, H_in, W_in]
        mask_e = mask.unsqueeze(4).unsqueeze(4)  # [B, 1, H_in, W_in, 1, 1]
        # Broadcast product: [B, 1, H_in, W_in, H_in, W_in] (outer x inner grid)
        result = (mask_e * inp_e).float()
        # Interleave outer/inner axes so view() lays out the Kronecker product
        result = result.permute(0, 1, 2, 4, 3, 5).contiguous()
        H_out, W_out = H_in * H_in, W_in * W_in
        return result.view(B, 1, H_out, W_out)
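# Worked example: input [[1, 0], [0, 2]] yields a 4x4 grid where each nonzero
# input cell is replaced by a copy of the whole input:
#   [[1, 0, 0, 0],
#    [0, 2, 0, 0],
#    [0, 0, 1, 0],
#    [0, 0, 0, 2]]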
class TileRepeatNN(BaseTransformNN):
    """Tile the grid h_repeat x w_repeat times. 0 learnable params."""
def _forward_impl(self, x):
hr = self.spec.params.get("h_repeat", 2)
wr = self.spec.params.get("w_repeat", 2)
return x.repeat(1, 1, hr, wr)
# ─── Concat Patterns ───────────────────────────────────────────
class ConcatPatternsNN(BaseTransformNN):
    """Concatenate transformed copies horizontally/vertically."""
    OPS = {
        "flip_h": lambda t: torch.flip(t, [3]),
        "flip_v": lambda t: torch.flip(t, [2]),
        "transpose": lambda t: t.transpose(2, 3),
        "rot90": lambda t: torch.rot90(t, 1, [2, 3]),
        "rot180": lambda t: torch.rot90(t, 2, [2, 3]),
        "rot270": lambda t: torch.rot90(t, 3, [2, 3]),
    }
    def _forward_impl(self, x):
        axis = self.spec.params.get("axis", "horizontal")
        ops = self.spec.params.get("operations", ["identity", "identity"])
        # Unknown op names (including "identity") fall back to the input itself
        pieces = [self.OPS.get(op, lambda t: t)(x) for op in ops]
        dim = 3 if axis == "horizontal" else 2
        return torch.cat(pieces, dim=dim)
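# e.g. {"axis": "horizontal", "operations": ["identity", "flip_h"]} yields a
# left-right mirrored, double-width output.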
# ─── Position Color LUT ────────────────────────────────────────
class PositionColorLUTNN(BaseTransformNN):
"""Per-position color lookup. H*W params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
lut = spec.params.get("lut", {})
self.h_o = spec.params.get("output_shape", [30, 30])[0]
self.w_o = spec.params.get("output_shape", [30, 30])[1]
self.lut = nn.Parameter(torch.zeros(1, 1, self.h_o, self.w_o), requires_grad=False)
        with torch.no_grad():
            # LUT keys are "row,col" strings; values are the output color
            for k, v in lut.items():
                h, w = map(int, k.split(","))
                if h < self.h_o and w < self.w_o:
                    self.lut[0, 0, h, w] = float(v)
def _forward_impl(self, x):
B = x.shape[0]
out = self.lut.expand(B, -1, -1, -1)
mask = (x[:, :, :self.h_o, :self.w_o] != 0).float()
return mask * out
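# e.g. {"lut": {"0,0": 3}, "output_shape": [2, 2]} paints color 3 at (0, 0)
# wherever the input there is nonzero; all other cells stay 0.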
# ─── Spatial Gather ────────────────────────────────────────────
class SpatialGatherNN(BaseTransformNN):
"""Rearrange pixels via gather map. H*W*2 params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
gmap = spec.params.get("gather_map", {})
self.h_o = spec.params.get("output_shape", [30, 30])[0]
self.w_o = spec.params.get("output_shape", [30, 30])[1]
self.gh = nn.Parameter(torch.zeros(self.h_o, self.w_o, dtype=torch.long), requires_grad=False)
self.gw = nn.Parameter(torch.zeros(self.h_o, self.w_o, dtype=torch.long), requires_grad=False)
        with torch.no_grad():
            # gather_map maps "row,col" (output) -> "row,col" (input source)
            for k, v in gmap.items():
                h, w = map(int, k.split(","))
                sh, sw = map(int, v.split(","))
                if h < self.h_o and w < self.w_o:
                    self.gh[h, w] = sh
                    self.gw[h, w] = sw
def _forward_impl(self, x):
B, C, Hi, Wi = x.shape
gh = self.gh.clamp(0, Hi - 1)
gw = self.gw.clamp(0, Wi - 1)
return x[:, :, gh, gw]
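# e.g. {"gather_map": {"0,0": "2,3"}} copies input[2, 3] into output[0, 0];
# unmapped output cells read from input[0, 0] (the zero-initialized default).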
# ─── One-Hot Convolution ───────────────────────────────────────
class OneHotConvNN(BaseTransformNN):
"""One-hot encode, convolve, argmax decode. K^2*100 params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
kh = spec.params.get("kernel_h", 3)
kw = spec.params.get("kernel_w", 3)
self.conv = nn.Conv2d(10, 10, kernel_size=(kh, kw), padding='same', bias=False)
if "weights" in spec.params:
w = torch.tensor(spec.params["weights"], dtype=torch.float32)
self.conv.weight = nn.Parameter(w.view(10, 10, kh, kw), requires_grad=False)
def _forward_impl(self, x):
onehot = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).permute(0, 3, 1, 2).float()
return self.conv(onehot).argmax(dim=1, keepdim=True).float()
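# Note on OneHotConvNN: "weights" must flatten to 10*10*kernel_h*kernel_w
# floats (it is reshaped via view(10, 10, kh, kw)); if absent, the conv keeps
# its random default initialization.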
class OneHotLinearNN(BaseTransformNN):
"""One-hot encode, linear, argmax. 100 params."""
def __init__(self, spec: TransformSpec):
super().__init__(spec)
self.linear = nn.Linear(10, 10, bias=False)
if "weights" in spec.params:
self.linear.weight = nn.Parameter(
torch.tensor(spec.params["weights"], dtype=torch.float32), requires_grad=False)
def _forward_impl(self, x):
onehot = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).float()
return self.linear(onehot).argmax(dim=-1).unsqueeze(1).float()
# ─── Factory & Parser ──────────────────────────────────────────
TRANSFORM_REGISTRY = {
"identity": IdentityNN,
"color_map": ColorMapNN,
"flip": FlipNN,
"transpose": TransposeNN,
"rotate": RotateNN,
"upscale": UpscaleNN,
"kron_self_similar": KronSelfSimilarNN,
"tile_repeat": TileRepeatNN,
"concat_patterns": ConcatPatternsNN,
"pos_color_lut": PositionColorLUTNN,
"spatial_gather": SpatialGatherNN,
"onehot_conv": OneHotConvNN,
"onehot_linear": OneHotLinearNN,
}
def create_transform_nn(spec: TransformSpec) -> BaseTransformNN:
cls = TRANSFORM_REGISTRY.get(spec.name)
if cls is None:
raise ValueError(f"Unknown transform: {spec.name}")
return cls(spec)
def parse_kilo_output(md: str) -> TransformSpec:
"""Parse Kilo markdown into TransformSpec."""
lines = md.strip().split('\n')
name, params, section = None, {}, None
for line in lines:
line = line.strip()
if line.startswith('## '):
section = line[3:].strip().lower()
continue
if section == 'transform' and line.startswith('name:'):
name = line.split(':', 1)[1].strip()
elif section == 'parameters' and line.startswith('- '):
kv = line[2:].split(':', 1)
if len(kv) == 2:
k, v = kv[0].strip(), kv[1].strip()
                try:
                    params[k] = ast.literal_eval(v)
                except (ValueError, SyntaxError):
                    params[k] = v
if not name:
raise ValueError("No transform name in Kilo output")
return TransformSpec(name=name, params=params)
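# Example of the markdown shape parse_kilo_output() expects (an illustrative
# sketch inferred from the parser above; real Kilo output may carry extra
# sections, which are ignored):
#
#   ## Transform
#   name: color_map
#   ## Parameters
#   - color_map: [0, 2, 1, 3, 4, 5, 6, 7, 8, 9]
#
# -> TransformSpec(name="color_map",
#                  params={"color_map": [0, 2, 1, 3, 4, 5, 6, 7, 8, 9]})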
# ─── ONNX Export ───────────────────────────────────────────────
def export_to_onnx(model: BaseTransformNN, input_shape: Tuple[int, int],
output_path: str, opset: int = 17):
model.eval()
H, W = input_shape
dummy = torch.zeros(1, 1, H, W)
torch.onnx.export(model, dummy, output_path,
input_names=["input"], output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
opset_version=opset, do_constant_folding=True)
kb, p = os.path.getsize(output_path) / 1024, model.count_params()
print(f"Exported {output_path}: {kb:.1f} KB, {p} params")
return output_path
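
# ─── Demo ──────────────────────────────────────────────────────
# Minimal end-to-end usage sketch (illustrative only, not part of the solver
# pipeline): build a spec by hand, run it on a toy grid, export to ONNX.
if __name__ == "__main__":
    spec = TransformSpec(name="flip", params={"direction": "horizontal"})
    model = create_transform_nn(spec)
    grid = torch.tensor([[1.0, 0.0, 2.0],
                         [0.0, 3.0, 0.0]])
    out = model(grid)  # forward() lifts [H, W] to [1, 1, H, W]
    print(out.squeeze(0).squeeze(0).tolist())  # [[2.0, 0.0, 1.0], [0.0, 3.0, 0.0]]
    export_to_onnx(model, (2, 3), "flip.onnx")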