rogermt committed on
Commit
eea0011
·
verified ·
1 Parent(s): 3d9221a

Add TRM solver NN executor with 14 transform types and ONNX export

Browse files
Files changed (1) hide show
  1. trm_solver/executor.py +330 -0
trm_solver/executor.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TRM Solver — NN Executor for ARC-AGI NeuroGolf
3
+
4
+ Takes a parsed transform (from Kilo/DeepSeek) and executes it as a
5
+ tiny neural network. Each transform is implemented as a minimal NN
6
+ that can be exported to ONNX.
7
+
8
+ Architecture:
9
+ - Each transform is a PyTorch nn.Module with frozen weights
10
+ - Weights encode the transform parameters (not learned — set directly)
11
+ - ONNX export produces a tiny model per task
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from typing import Dict, List, Tuple, Optional
19
+ from dataclasses import dataclass
20
+ import json
21
+
22
+
23
+ # ─── Data Structures ───────────────────────────────────────────
24
+
25
@dataclass
class TransformSpec:
    """Parsed output from Kilo/DeepSeek.

    Holds the transform ``name``, its ``params`` dictionary and an optional
    list of ``objects``.
    """
    name: str
    params: Dict
    # ``None`` is the declared default, so the annotation must admit it; it
    # is normalised to a fresh empty list in ``__post_init__``.
    objects: Optional[List[Dict]] = None

    def __post_init__(self):
        """Replace a missing ``objects`` value with a new empty list."""
        if self.objects is None:
            self.objects = []
35
+
36
+
37
+ # ─── Base NN Transform ─────────────────────────────────────────
38
+
39
class BaseTransformNN(nn.Module):
    """Common parent of every transform network.

    Concrete transforms override ``_forward_impl``; ``forward`` only
    normalises the rank of the input before delegating to it.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__()
        self.spec = spec
        self.max_size = 30  # ARC max grid size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Promote ``x`` to 4-D and run the transform.

        Args:
            x: Input grid [B, 1, H, W] or [1, H, W] - values 0-9. A 2-D
                tensor is also accepted and gains two leading axes.
        Returns:
            The result of ``_forward_impl`` on the 4-D input
            (documented convention: [B, 1, H_out, W_out] - values 0-9).
        """
        rank = x.dim()
        if rank == 3:
            # Missing batch axis only.
            x = x.unsqueeze(0)
        elif rank == 2:
            # Bare [H, W] grid: add batch and channel axes.
            x = x.unsqueeze(0).unsqueeze(0)
        return self._forward_impl(x)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Transform a 4-D grid; must be provided by each subclass."""
        raise NotImplementedError

    def count_params(self) -> int:
        """Return the total number of elements over all registered parameters."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
65
+
66
+
67
+ # ─── Identity ──────────────────────────────────────────────────
68
+
69
class IdentityNN(BaseTransformNN):
    """Pass-through transform: the grid is returned untouched."""

    def _forward_impl(self, x):
        # Nothing to compute; hand back the very same tensor.
        return x
73
+
74
+
75
+ # ─── Color Map ─────────────────────────────────────────────────
76
+
77
class ColorMapNN(BaseTransformNN):
    """Per-pixel color remapping. 100 params.

    The grid is one-hot encoded over the 10 colors, passed through a frozen
    1x1 convolution whose 10x10 weight matrix encodes the mapping, and
    decoded with argmax.

    ``params["color_map"]`` may be a sequence (index = source color, value =
    target color) or a mapping ``{source: target}``. Colors that are not
    mentioned keep their own value, so a partial map no longer collapses the
    unmapped colors to 0.

    Raises:
        IndexError: if a source or target color lies outside ``0..9``.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        color_map = spec.params.get("color_map")
        if color_map is None:
            pairs = []
        elif isinstance(color_map, dict):
            pairs = color_map.items()
        else:
            pairs = enumerate(color_map)

        # Start from the identity so every source color has exactly one target.
        targets = list(range(10))
        for src, dst in pairs:
            src, dst = int(src), int(dst)
            if not (0 <= src < 10 and 0 <= dst < 10):
                # Negative values would otherwise wrap around silently.
                raise IndexError(
                    f"color_map entry {src} -> {dst} is outside the 0-9 color range")
            targets[src] = dst

        weight = torch.zeros(10, 10, 1, 1)
        for src, dst in enumerate(targets):
            # Output channel ``dst`` fires when input channel ``src`` is hot.
            weight[dst, src, 0, 0] = 1.0

        self.lut = nn.Conv2d(10, 10, kernel_size=1, bias=False)
        self.lut.weight = nn.Parameter(weight, requires_grad=False)

    def _forward_impl(self, x):
        # squeeze(1) drops the channel axis, which is expected to have size 1.
        colors = x.long().squeeze(1).clamp(0, 9)
        onehot = F.one_hot(colors, num_classes=10).permute(0, 3, 1, 2).float()
        out = self.lut(onehot)
        return out.argmax(dim=1, keepdim=True).float()
95
+
96
+
97
+ # ─── Geometric ─────────────────────────────────────────────────
98
+
99
class FlipNN(BaseTransformNN):
    """Mirror the grid along one spatial axis.

    ``params["direction"]`` equal to ``"horizontal"`` (the default) reverses
    axis 3; any other value reverses axis 2.
    """

    def _forward_impl(self, x):
        if self.spec.params.get("direction", "horizontal") == "horizontal":
            axis = 3
        else:
            axis = 2
        return torch.flip(x, [axis])
104
+
105
+
106
class TransposeNN(BaseTransformNN):
    """Swap the two spatial axes (2 and 3) of the grid."""

    def _forward_impl(self, x):
        return torch.transpose(x, 2, 3)
109
+
110
+
111
class RotateNN(BaseTransformNN):
    """Rotate the grid in the plane of axes 2 and 3 by ``k`` quarter turns.

    ``k`` is read from ``params["k"]`` and defaults to 1.
    """

    def _forward_impl(self, x):
        quarter_turns = self.spec.params.get("k", 1)
        return torch.rot90(x, k=quarter_turns, dims=[2, 3])
115
+
116
+
117
+ # ─── Upscale ───────────────────────────────────────────────────
118
+
119
class UpscaleNN(BaseTransformNN):
    """Nearest-neighbor upscaling by an integer factor, cropped to a target size.

    Reads ``scale`` (default 2) and ``output_shape`` (default ``[30, 30]``)
    from the spec. This class registers no tensors as parameters.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        params = spec.params
        self.scale = params.get("scale", 2)
        out_shape = params.get("output_shape", [30, 30])
        self.target_h = out_shape[0]
        self.target_w = out_shape[1]

    def _forward_impl(self, x):
        # Repeat every row, then every column, ``scale`` times.
        stretched_rows = x.repeat_interleave(self.scale, dim=2)
        stretched = stretched_rows.repeat_interleave(self.scale, dim=3)
        # Keep at most target_h x target_w cells from the top-left corner.
        return stretched[:, :, :self.target_h, :self.target_w]
133
+
134
+
135
+ # ─── Kronecker Self-Similar ────────────────────────────────────
136
+
137
class KronSelfSimilarNN(BaseTransformNN):
    """output = kron((input != 0), input). 0 learnable params.

    Each non-zero cell of the input is replaced by a full copy of the input
    and each zero cell by an all-zero block of the same size.
    """

    def _forward_impl(self, x):
        batch, _, h_in, w_in = x.shape
        occupancy = (x != 0).float()
        # occupancy -> [B, 1, H, W, 1, 1]; x -> [B, 1, 1, 1, H, W]
        occupancy_b = occupancy.unsqueeze(-1).unsqueeze(-1)
        tile_b = x.unsqueeze(2).unsqueeze(3)
        blocks = (occupancy_b * tile_b).float()
        # Interleave (outer row, inner row) and (outer col, inner col) so the
        # flattened view lays the blocks out as a 2-D Kronecker product.
        blocks = blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        return blocks.view(batch, 1, h_in * h_in, w_in * w_in)
149
+
150
+
151
class TileRepeatNN(BaseTransformNN):
    """Tile the grid ``h_repeat`` times along axis 2 and ``w_repeat`` along axis 3.

    Both counts come from the spec params and default to 2.
    """

    def _forward_impl(self, x):
        params = self.spec.params
        reps = (1, 1, params.get("h_repeat", 2), params.get("w_repeat", 2))
        return x.repeat(*reps)
156
+
157
+
158
+ # ─── Concat Patterns ───────────────────────────────────────────
159
+
160
class ConcatPatternsNN(BaseTransformNN):
    """Concatenate transformed copies horizontally/vertically.

    ``params["operations"]`` lists one operation per copy (default: two
    ``"identity"`` copies). ``params["axis"]`` equal to ``"horizontal"``
    (the default) joins along axis 3; any other value joins along axis 2.
    """

    @staticmethod
    def _variant(op, x):
        """Return the copy of ``x`` requested by ``op``; unknown ops yield ``x``."""
        if op == "flip_h":
            return torch.flip(x, [3])
        if op == "flip_v":
            return torch.flip(x, [2])
        if op == "transpose":
            return x.transpose(2, 3)
        if op == "rot90":
            return torch.rot90(x, 1, [2, 3])
        if op == "rot180":
            return torch.rot90(x, 2, [2, 3])
        if op == "rot270":
            return torch.rot90(x, 3, [2, 3])
        return x

    def _forward_impl(self, x):
        params = self.spec.params
        axis = params.get("axis", "horizontal")
        ops = params.get("operations", ["identity", "identity"])
        pieces = [self._variant(op, x) for op in ops]
        if axis == "horizontal":
            return torch.cat(pieces, dim=3)
        return torch.cat(pieces, dim=2)
177
+
178
+
179
+ # ─── Position Color LUT ────────────────────────────────────────
180
+
181
class PositionColorLUTNN(BaseTransformNN):
    """Per-position color lookup. H*W params.

    ``params["lut"]`` maps ``"row,col"`` strings to a color value and
    ``params["output_shape"]`` (default ``[30, 30]``) sets the table size.
    The output holds the table value wherever the input cell is non-zero and
    0 elsewhere.

    Entries whose row or column falls outside the table are ignored. This
    now includes negative indices, which used to wrap to the far edge.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        entries = spec.params.get("lut", {})
        out_shape = spec.params.get("output_shape", [30, 30])
        self.h_o = out_shape[0]
        self.w_o = out_shape[1]

        table = torch.zeros(1, 1, self.h_o, self.w_o)
        for key, color in entries.items():
            row, col = map(int, key.split(","))
            if 0 <= row < self.h_o and 0 <= col < self.w_o:
                table[0, 0, row, col] = float(color)
        self.lut = nn.Parameter(table, requires_grad=False)

    def _forward_impl(self, x):
        mask = (x[:, :, :self.h_o, :self.w_o] != 0).float()
        # The input may be smaller than the table (the default table is
        # 30x30). Crop the table to the mask so the product is well defined
        # instead of failing with a broadcast error.
        rows, cols = mask.shape[2], mask.shape[3]
        return mask * self.lut[:, :, :rows, :cols]
201
+
202
+
203
+ # ─── Spatial Gather ────────────────────────────────────────────
204
+
205
class SpatialGatherNN(BaseTransformNN):
    """Rearrange pixels via gather map. H*W*2 params.

    ``params["gather_map"]`` maps an output position ``"row,col"`` to the
    input position ``"row,col"`` it copies from. Output positions without an
    entry read from input cell (0, 0). ``params["output_shape"]`` (default
    ``[30, 30]``) sets the output size.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        mapping = spec.params.get("gather_map", {})
        out_shape = spec.params.get("output_shape", [30, 30])
        self.h_o = out_shape[0]
        self.w_o = out_shape[1]

        src_rows = torch.zeros(self.h_o, self.w_o, dtype=torch.long)
        src_cols = torch.zeros(self.h_o, self.w_o, dtype=torch.long)
        for dst, src in mapping.items():
            h, w = map(int, dst.split(","))
            sh, sw = map(int, src.split(","))
            if h >= self.h_o or w >= self.w_o:
                # Destination lies beyond the output grid; ignore the entry.
                continue
            src_rows[h, w] = sh
            src_cols[h, w] = sw

        # Integer index tables are stored as frozen parameters.
        self.gh = nn.Parameter(src_rows, requires_grad=False)
        self.gw = nn.Parameter(src_cols, requires_grad=False)

    def _forward_impl(self, x):
        _, _, in_h, in_w = x.shape
        # Keep every source index inside the actual input grid.
        row_idx = self.gh.clamp(0, in_h - 1)
        col_idx = self.gw.clamp(0, in_w - 1)
        return x[:, :, row_idx, col_idx]
228
+
229
+
230
+ # ─── One-Hot Convolution ───────────────────────────────────────
231
+
232
class OneHotConvNN(BaseTransformNN):
    """One-hot encode, convolve, argmax decode. kernel_h*kernel_w*100 params.

    ``params["kernel_h"]`` / ``params["kernel_w"]`` (default 3) set the
    kernel size and ``params["weights"]`` supplies the ``10*10*kh*kw`` kernel
    values. The weights are always frozen. When ``weights`` is absent the
    kernel is the identity (each color maps to itself), not PyTorch's random
    initialisation, so the module is deterministic.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        kh = spec.params.get("kernel_h", 3)
        kw = spec.params.get("kernel_w", 3)
        self.conv = nn.Conv2d(10, 10, kernel_size=(kh, kw), padding='same', bias=False)
        if "weights" in spec.params:
            w = torch.tensor(spec.params["weights"], dtype=torch.float32)
            w = w.view(10, 10, kh, kw)
        else:
            w = torch.zeros(10, 10, kh, kw)
            # With 'same' padding the left/top pad is (k - 1) // 2, so this
            # tap is the one aligned with the output position.
            cy, cx = (kh - 1) // 2, (kw - 1) // 2
            for color in range(10):
                w[color, color, cy, cx] = 1.0
        self.conv.weight = nn.Parameter(w, requires_grad=False)

    def _forward_impl(self, x):
        # squeeze(1) drops the channel axis, which is expected to have size 1.
        colors = x.long().squeeze(1).clamp(0, 9)
        onehot = F.one_hot(colors, 10).permute(0, 3, 1, 2).float()
        return self.conv(onehot).argmax(dim=1, keepdim=True).float()
248
+
249
+
250
class OneHotLinearNN(BaseTransformNN):
    """One-hot encode, linear, argmax. 100 params.

    ``params["weights"]`` supplies the 10x10 weight matrix. The weights are
    always frozen. When ``weights`` is absent the matrix is the identity
    (each color maps to itself), not PyTorch's random initialisation, so the
    module is deterministic.
    """

    def __init__(self, spec: TransformSpec):
        super().__init__(spec)
        self.linear = nn.Linear(10, 10, bias=False)
        if "weights" in spec.params:
            w = torch.tensor(spec.params["weights"], dtype=torch.float32)
        else:
            w = torch.eye(10)
        self.linear.weight = nn.Parameter(w, requires_grad=False)

    def _forward_impl(self, x):
        # squeeze(1) drops the channel axis; one_hot appends the color axis last.
        onehot = F.one_hot(x.long().squeeze(1).clamp(0, 9), 10).float()
        return self.linear(onehot).argmax(dim=-1).unsqueeze(1).float()
263
+
264
+
265
+ # ─── Factory & Parser ──────────────────────────────────────────
266
+
267
# Maps the ``name`` of a TransformSpec to the module class implementing it.
TRANSFORM_REGISTRY = dict(
    identity=IdentityNN,
    color_map=ColorMapNN,
    flip=FlipNN,
    transpose=TransposeNN,
    rotate=RotateNN,
    upscale=UpscaleNN,
    kron_self_similar=KronSelfSimilarNN,
    tile_repeat=TileRepeatNN,
    concat_patterns=ConcatPatternsNN,
    pos_color_lut=PositionColorLUTNN,
    spatial_gather=SpatialGatherNN,
    onehot_conv=OneHotConvNN,
    onehot_linear=OneHotLinearNN,
)
282
+
283
+
284
def create_transform_nn(spec: TransformSpec) -> BaseTransformNN:
    """Instantiate the transform module registered under ``spec.name``.

    Raises:
        ValueError: if ``spec.name`` is not a key of ``TRANSFORM_REGISTRY``.
    """
    if spec.name not in TRANSFORM_REGISTRY:
        raise ValueError(f"Unknown transform: {spec.name}")
    transform_cls = TRANSFORM_REGISTRY[spec.name]
    return transform_cls(spec)
289
+
290
+
291
def parse_kilo_output(md: str) -> TransformSpec:
    """Parse Kilo markdown into TransformSpec.

    Recognised layout (section titles are matched case-insensitively)::

        ## Transform
        name: <transform name>

        ## Parameters
        - <key>: <Python literal or free text>

    Parameter values are decoded with ``ast.literal_eval``; a value that is
    not a valid literal is kept as the raw string.

    Raises:
        ValueError: if no non-empty transform name was found.
    """
    # Imported once per call; it used to sit inside the per-line ``try``.
    import ast

    name, params, section = None, {}, None
    for raw_line in md.strip().split('\n'):
        line = raw_line.strip()
        if line.startswith('## '):
            section = line[3:].strip().lower()
            continue
        if section == 'transform' and line.startswith('name:'):
            name = line.split(':', 1)[1].strip()
        elif section == 'parameters' and line.startswith('- '):
            key, sep, value = line[2:].partition(':')
            if not sep:
                # Bullet without a "key: value" pair; nothing to record.
                continue
            key, value = key.strip(), value.strip()
            try:
                params[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                params[key] = value
    if not name:
        raise ValueError("No transform name in Kilo output")
    return TransformSpec(name=name, params=params)
314
+
315
+
316
+ # ─── ONNX Export ───────────────────────────────────────────────
317
+
318
def export_to_onnx(model: BaseTransformNN, input_shape: Tuple[int, int],
                   output_path: str, opset: int = 17):
    """Export ``model`` to an ONNX file and print its size and parameter count.

    Args:
        model: Transform module to export; it is switched to eval mode.
        input_shape: ``(H, W)`` of the all-zero example grid used for export.
        output_path: Destination file for the ONNX model.
        opset: ONNX opset version passed to the exporter.
    Returns:
        ``output_path``, unchanged.
    """
    import os

    model.eval()
    height, width = input_shape
    example = torch.zeros(1, 1, height, width)
    # Only the batch axis is declared dynamic.
    torch.onnx.export(
        model,
        example,
        output_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=opset,
        do_constant_folding=True,
    )
    kb = os.path.getsize(output_path) / 1024
    p = model.count_params()
    print(f"Exported {output_path}: {kb:.1f} KB, {p} params")
    return output_path