rogermt committed on
Commit
d565d28
·
verified ·
1 Parent(s): eea0011

Add Kilo Bridge for LLM-driven ARC task analysis and ONNX export

Browse files
Files changed (1) hide show
  1. trm_solver/kilo_bridge.py +363 -0
trm_solver/kilo_bridge.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kilo Bridge — call DeepSeek headless to analyze ARC tasks,
3
+ then drive the NN executor to produce ONNX models.
4
+
5
+ Usage:
6
+ python trm_solver/kilo_bridge.py --task path/to/007bbfb7.json
7
+
8
+ Pipeline:
9
+ 1. Render ARC task as image (or pass raw grid)
10
+ 2. Call `kilo run` with the image + prompt
11
+ 3. Parse markdown output → TransformSpec
12
+ 4. Create NN executor → export ONNX
13
+ """
14
+
15
+ import subprocess
16
+ import json
17
+ import os
18
+ import sys
19
+ import argparse
20
+ import tempfile
21
+ from typing import Optional, Dict, List, Tuple
22
+ import numpy as np
23
+
24
+ # Add parent to path
25
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
26
+
27
+ from trm_solver.executor import (
28
+ TransformSpec, create_transform_nn, export_to_onnx, parse_kilo_output
29
+ )
30
+
31
+
32
+ # ─── ARC Task Loading ──────────────────────────────────────────
33
+
34
def load_arc_task(task_file: str) -> Dict:
    """Load an ARC task from a JSON file.

    Args:
        task_file: Path to the task JSON file.

    Returns:
        The parsed JSON document. Other functions in this module read it
        through its "train" and "test" keys.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding (e.g. cp1252 on Windows).
    with open(task_file, encoding="utf-8") as f:
        return json.load(f)
38
+
39
+
40
def render_grid(grid: List[List[int]], cell_size: int = 20) -> np.ndarray:
    """Render an ARC grid as an RGB image.

    Every grid cell becomes a ``cell_size`` x ``cell_size`` square filled with
    the palette color of its value. Values outside 0-9 are drawn black.
    Dark-gray grid lines are then drawn on every ``cell_size``-th pixel row
    and column, i.e. over the top/left edge of each cell.

    Args:
        grid: Rows of integer color codes. The width is taken from the
            first row.
        cell_size: Edge length of one cell in pixels; must be >= 1.

    Returns:
        A ``uint8`` array of shape ``(h * cell_size, w * cell_size, 3)``.
        An empty grid yields an empty ``(0, 0, 3)`` image instead of raising.

    Raises:
        ValueError: If ``cell_size`` is smaller than 1.
    """
    if cell_size < 1:
        # Fail early with a clear message instead of an obscure
        # "slice step cannot be zero" / negative-dimension error below.
        raise ValueError(f"cell_size must be >= 1, got {cell_size}")

    # ARC color palette (0-9)
    palette = {
        0: [0, 0, 0],        # black
        1: [0, 116, 217],    # blue
        2: [255, 65, 54],    # red
        3: [46, 204, 64],    # green
        4: [255, 220, 0],    # yellow
        5: [170, 170, 170],  # gray
        6: [240, 18, 190],   # magenta
        7: [255, 133, 27],   # orange
        8: [127, 219, 255],  # light blue
        9: [135, 86, 52],    # brown
    }
    fallback = [0, 0, 0]

    h = len(grid)
    w = len(grid[0]) if h else 0  # an empty grid has no first row to measure
    img = np.zeros((h * cell_size, w * cell_size, 3), dtype=np.uint8)
    if img.size == 0:
        return img

    for r in range(h):
        top = r * cell_size
        for c in range(w):
            left = c * cell_size
            img[top:top + cell_size, left:left + cell_size] = palette.get(
                grid[r][c], fallback
            )

    # Add grid lines
    img[::cell_size, :] = [64, 64, 64]
    img[:, ::cell_size] = [64, 64, 64]

    return img
69
+
70
+
71
def _pad_canvas(img: np.ndarray, height: int, width: int) -> np.ndarray:
    """Return *img* anchored top-left on a black canvas of the given size."""
    if img.shape[0] == height and img.shape[1] == width:
        return img
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[:img.shape[0], :img.shape[1]] = img
    return canvas


def render_task_grids(task: Dict, cell_size: int = 20) -> Optional[np.ndarray]:
    """Render all train pairs of an ARC task as a single image.

    Layout: train pairs side by side, each pair with its input above its
    output. A black gap of ``cell_size // 2`` pixels separates input from
    output and one pair from the next.

    Input and output grids of a pair may have different sizes; the narrower
    one is padded with black on the right so the two can be stacked. Pairs
    shorter than the tallest pair are padded with black at the bottom.

    Args:
        task: Parsed ARC task; only ``task["train"]`` is read, and each pair
            must provide "input" and "output" grids.
        cell_size: Pixel size of one grid cell, forwarded to ``render_grid``.

    Returns:
        The combined ``uint8`` RGB image, or ``None`` when the task has no
        train pairs.
    """
    train_pairs = task.get("train", [])
    if not train_pairs:
        return None

    gap_px = cell_size // 2

    pair_imgs = []
    for pair in train_pairs:
        inp = render_grid(pair["input"], cell_size)
        out = render_grid(pair["output"], cell_size)
        # np.vstack needs equal widths, but ARC outputs are often a
        # different size than their inputs.
        pair_w = max(inp.shape[1], out.shape[1])
        gap = np.zeros((gap_px, pair_w, 3), dtype=np.uint8)
        pair_imgs.append(np.vstack([
            _pad_canvas(inp, inp.shape[0], pair_w),
            gap,
            _pad_canvas(out, out.shape[0], pair_w),
        ]))

    # np.hstack needs equal heights: pad every pair to the tallest one and
    # put a separator between (not after) pairs.
    max_pair_h = max(img.shape[0] for img in pair_imgs)
    sep = np.zeros((max_pair_h, gap_px, 3), dtype=np.uint8)
    columns = []
    for img in pair_imgs:
        if columns:
            columns.append(sep)
        columns.append(_pad_canvas(img, max_pair_h, img.shape[1]))

    return np.hstack(columns)
113
+
114
+
115
+ # ─── Kilo Interface ────────────────────────────────────────────
116
+
117
# Default prompt sent to Kilo. call_kilo_headless() and call_kilo_via_sdk()
# fall back to it when their `prompt` argument is None. The text is passed
# verbatim (it is never run through str.format), so the literal braces in the
# parameter examples are safe.
# NOTE(review): the transform names listed here are presumably the ones
# trm_solver.executor (parse_kilo_output / create_transform_nn) understands —
# confirm the two stay in sync.
KILO_PROMPT_TEMPLATE = """Analyze this ARC-AGI task. The image shows training examples: input grids (above) and their output grids (below) for each example pair.

Identify the transformation rule that maps each input to its output.

Output your analysis in this EXACT format:

## Transform
name: <transform_name>

## Parameters
- param1: value1
- param2: value2

Available transform names:
- identity: output equals input
- color_map: per-pixel color remapping (params: color_map=[0,2,1,3,...])
- flip: horizontal or vertical flip (params: direction="horizontal"|"vertical")
- transpose: matrix transpose
- rotate: 90/180/270 rotation (params: k=1|2|3)
- upscale: nearest-neighbor upscale (params: scale=2|3, output_shape=[H,W])
- kron_self_similar: Kronecker product with own mask
- tile_repeat: tile input (params: h_repeat=N, w_repeat=N)
- concat_patterns: concatenate transformed copies (params: axis="horizontal"|"vertical", operations=["identity","flip_h"])
- pos_color_lut: position-based color lookup (params: lut={"0,0":4,"1,2":3})
- spatial_gather: pixel rearrangement (params: gather_map={"0,0":"1,2"})
- onehot_conv: one-hot convolution (params: kernel_h=3, kernel_w=3)
- onehot_linear: one-hot linear transform (params: weights=[[...]])

Be precise. Output ONLY the structured format above, no extra text."""
146
+
147
+
148
def call_kilo_headless(image_path: str, prompt: Optional[str] = None) -> str:
    """Call DeepSeek via the Kilo headless CLI.

    Runs ``kilo run <prompt> --image <image_path> --format default`` as a
    subprocess with a 120-second timeout and returns its stdout.
    NOTE(review): the flag names are taken as-is from this module; confirm
    them against the installed ``kilo`` CLI.

    Args:
        image_path: Path to the rendered ARC task image.
        prompt: Prompt text; ``None`` selects ``KILO_PROMPT_TEMPLATE``.

    Returns:
        The CLI's stdout with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the CLI exits with a non-zero status. The message
            carries the exit code and the first 500 characters of stderr.
        subprocess.TimeoutExpired: If the CLI runs longer than 120 seconds.
        FileNotFoundError: If the ``kilo`` executable is not on PATH.
    """
    if prompt is None:
        prompt = KILO_PROMPT_TEMPLATE

    # Argument list (no shell): the multi-line prompt needs no quoting.
    cmd = [
        "kilo", "run",
        prompt,
        "--image", image_path,
        "--format", "default",
    ]

    # Only echo the executable and sub-command; the prompt is long.
    print(f"Running: {' '.join(cmd[:2])} ... [prompt + image]")

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=120,  # 2 min timeout per task
    )

    if result.returncode != 0:
        stderr_excerpt = result.stderr[:500]
        print(f"Kilo error (stderr): {stderr_excerpt}")
        # Put the diagnostics in the exception too: callers that only report
        # str(exc) would otherwise see nothing but the exit code.
        raise RuntimeError(
            f"Kilo failed with code {result.returncode}: {stderr_excerpt.strip()}"
        )

    return result.stdout.strip()
183
+
184
+
185
def call_kilo_via_sdk(image_path: str, prompt: Optional[str] = None,
                      server_url: str = "http://localhost:8765") -> str:
    """Call DeepSeek via the Kilo SDK (tunnel to a local server).

    Use this from Kaggle when connected via tunnel.

    Args:
        image_path: Path to the rendered ARC task image.
        prompt: Prompt text; ``None`` selects ``KILO_PROMPT_TEMPLATE``.
        server_url: Base URL handed to ``KiloClient``.

    Returns:
        Whatever ``KiloClient.run`` returns.
        NOTE(review): annotated as ``str``, but the SDK's return type is not
        visible from this module — confirm.

    Raises:
        ImportError: If ``kilo_sdk`` is not installed. The original import
            failure is attached as ``__cause__``.
    """
    try:
        # Imported lazily so the CLI path works without the SDK installed.
        from kilo_sdk import KiloClient
    except ImportError as exc:
        raise ImportError("Install kilo_sdk: pip install kilo-sdk") from exc

    client = KiloClient(server_url=server_url)

    if prompt is None:
        prompt = KILO_PROMPT_TEMPLATE

    return client.run(
        prompt=prompt,
        image=image_path,
        format="default",
    )
208
+
209
+
210
+ # ─── Main Pipeline ─────────────────────────────────────────────
211
+
212
def process_task(task_file: str, output_dir: str,
                 use_sdk: bool = False,
                 server_url: str = "http://localhost:8765",
                 render: bool = True) -> Tuple[str, TransformSpec, str]:
    """Run the full pipeline for one ARC task.

    Steps: 1. Render → 2. Kilo → 3. Parse → 4. NN → 5. ONNX

    Files written to ``output_dir`` (created if missing):
    ``<task_id>_render.png`` (only when ``render`` is true),
    ``<task_id>.onnx`` and ``<task_id>_spec.json``.

    Args:
        task_file: Path to the task JSON; the task ID is its base name
            without extension.
        output_dir: Directory for all generated files.
        use_sdk: Call Kilo through the SDK instead of the CLI.
        server_url: SDK server URL (used only when ``use_sdk`` is true).
        render: Render the train pairs to a PNG and send that image to Kilo.

    Returns:
        ``(task_id, TransformSpec, onnx_path)``

    Raises:
        ValueError: If ``render`` is true and the task has no train pairs.
    """
    task = load_arc_task(task_file)
    task_id = os.path.splitext(os.path.basename(task_file))[0]

    # Create the output directory up front: the rendered PNG is the first
    # file written into it.
    os.makedirs(output_dir, exist_ok=True)

    # Render task as image
    if render:
        img = render_task_grids(task)
        if img is None:
            raise ValueError(f"No train pairs in {task_file}")
        from PIL import Image  # lazy: Pillow is only needed when rendering
        img_path = os.path.join(output_dir, f"{task_id}_render.png")
        Image.fromarray(img).save(img_path)
    else:
        # NOTE(review): this passes the task JSON itself as the "image",
        # although it was just parsed as JSON above — confirm the intended
        # input for the no-render mode.
        img_path = task_file  # Assume already rendered

    # Call Kilo
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Task: {task_id}")
    print(banner)

    if use_sdk:
        md_output = call_kilo_via_sdk(img_path, server_url=server_url)
    else:
        md_output = call_kilo_headless(img_path)

    print(f"\nKilo output:\n{md_output[:500]}...\n")

    # Parse
    spec = parse_kilo_output(md_output)
    print(f"Parsed: transform={spec.name}, params={spec.params}")

    # Create NN
    model = create_transform_nn(spec)

    # Get test input shape; fall back to a 1x1 grid when the task has no
    # test entry (missing key, empty list, or empty input grid).
    test_cases = task.get("test") or [{}]
    test_input = test_cases[0].get("input") or [[0]]
    test_h, test_w = len(test_input), len(test_input[0])

    # Export ONNX
    onnx_path = os.path.join(output_dir, f"{task_id}.onnx")
    export_to_onnx(model, (test_h, test_w), onnx_path)

    # Save spec for reference
    spec_path = os.path.join(output_dir, f"{task_id}_spec.json")
    with open(spec_path, 'w') as f:
        json.dump({"name": spec.name, "params": spec.params}, f, indent=2)

    return task_id, spec, onnx_path
270
+
271
+
272
def batch_process(data_dir: str, output_dir: str,
                  task_ids: Optional[List[str]] = None,
                  use_sdk: bool = False,
                  server_url: str = "http://localhost:8765",
                  max_tasks: Optional[int] = None) -> List[Tuple[str, TransformSpec, str]]:
    """Process multiple ARC tasks and write a ``manifest.json`` summary.

    Tasks are the ``*.json`` files of ``data_dir`` in sorted order. A failure
    in one task is reported and skipped so the rest of the batch still runs.

    Args:
        data_dir: Path to directory containing task JSON files
        output_dir: Where to save ONNX files (created if missing)
        task_ids: Specific task IDs to process (None = all)
        use_sdk: Use Kilo SDK instead of CLI
        server_url: SDK server URL
        max_tasks: Limit number of tasks (None or 0 = no limit)

    Returns:
        One ``(task_id, spec, onnx_path)`` tuple per successfully processed
        task.
    """
    suffix = ".json"
    wanted = None if task_ids is None else set(task_ids)

    task_files = []
    for fname in sorted(os.listdir(data_dir)):
        if not fname.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # ".json" occurring in the middle of a file name.
        tid = fname[:-len(suffix)]
        if wanted is None or tid in wanted:
            task_files.append(os.path.join(data_dir, fname))

    if max_tasks:
        task_files = task_files[:max_tasks]

    total = len(task_files)
    print(f"Processing {total} tasks...")

    results = []
    for i, tf in enumerate(task_files, start=1):
        try:
            tid, spec, onnx = process_task(tf, output_dir, use_sdk, server_url)
        except Exception as e:
            # Deliberate best-effort: report and continue with the next task.
            print(f"  [{i}/{total}] ✗ {os.path.basename(tf)}: {e}")
        else:
            results.append((tid, spec, onnx))
            print(f"  [{i}/{total}] ✓ {tid} → {spec.name}")

    # Summary
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"SUMMARY: {len(results)}/{total} tasks processed")
    print(banner)

    # Save manifest
    manifest = {
        "total": len(results),
        "tasks": {tid: {"transform": spec.name, "params": spec.params}
                  for tid, spec, _ in results},
    }
    # The directory may not exist yet when there were no tasks or every task
    # failed before writing anything.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "manifest.json"), 'w') as f:
        json.dump(manifest, f, indent=2)

    return results
325
+
326
+
327
+ # ─── CLI ───────────────────────────────────────────────────────
328
+
329
if __name__ == "__main__":
    # Command-line entry point: run a single task (--task), a whole
    # directory (--data-dir), or an offline smoke test (--dry-run).
    parser = argparse.ArgumentParser(description="Kilo Bridge for ARC-AGI")
    parser.add_argument("--task", help="Task ID or path to task JSON")
    parser.add_argument("--data-dir", help="Directory of ARC task JSONs")
    parser.add_argument("--output-dir", default="onnx_models", help="Output directory")
    parser.add_argument("--use-sdk", action="store_true", help="Use Kilo SDK (tunnel)")
    parser.add_argument("--server-url", default="http://localhost:8765")
    parser.add_argument("--max-tasks", type=int, help="Max tasks to process")
    parser.add_argument("--task-ids", nargs="*", help="Specific task IDs")
    parser.add_argument("--no-render", action="store_true", help="Skip rendering")
    parser.add_argument("--dry-run", action="store_true",
                        help="Test without calling Kilo (use dummy output)")

    args = parser.parse_args()

    if args.dry_run:
        # Test with dummy Kilo output: exercises NN creation and ONNX export
        # only, without loading a task or calling Kilo.
        spec = TransformSpec(name="kron_self_similar", params={"scale": 3})
        model = create_transform_nn(spec)
        os.makedirs(args.output_dir, exist_ok=True)
        onnx_path = os.path.join(args.output_dir, "test_dry_run.onnx")
        export_to_onnx(model, (3, 3), onnx_path)
        print(f"Dry run complete: {onnx_path}")
        sys.exit(0)

    if args.task:
        task_path = args.task
        # --task is documented as "Task ID or path": when the value is not
        # an existing file, try it as an ID inside --data-dir.
        if not os.path.isfile(task_path) and args.data_dir:
            candidate = os.path.join(args.data_dir, f"{args.task}.json")
            if os.path.isfile(candidate):
                task_path = candidate
        if not os.path.isfile(task_path):
            parser.error(
                f"task file not found: {args.task!r} "
                "(pass a path to a task JSON, or a task ID together with --data-dir)"
            )
        process_task(task_path, args.output_dir,
                     use_sdk=args.use_sdk, server_url=args.server_url,
                     render=not args.no_render)
    elif args.data_dir:
        batch_process(args.data_dir, args.output_dir,
                      task_ids=args.task_ids, use_sdk=args.use_sdk,
                      server_url=args.server_url, max_tasks=args.max_tasks)
    else:
        parser.print_help()