ARC-AGI / trm_solver /classify_tasks.py
rogermt's picture
Update classifier to support Kilo local server
ccffc8f verified
#!/usr/bin/env python3
"""
ARC-AGI Task Classifier via Kilo Code server (local DeepSeek, free tier).
Also supports DeepSeek API as fallback.
Kilo server mode (preferred, faster on free tier):
1. Start Kilo server: kilo serve --port 8765
2. Run: python classify_tasks.py --mode kilo --kilo_url http://127.0.0.1:8765
API mode (fallback; requires a DeepSeek API key):
python classify_tasks.py --mode api --api_key $DEEPSEEK_API_KEY --data_dir /kaggle/input/competitions/neurogolf-2026/
"""
import json, os, glob, time, argparse, requests
# Single-step solver names offered to the model as "Available single solvers".
SOLVER_NAMES = [
    "identity",
    "constant",
    "color_map",
    "transpose",
    "flip",
    "rotate",
    "shift",
    "tile",
    "upscale",
    "kronecker",
    "nonuniform_scale",
    "mirror_h",
    "mirror_v",
    "quad_mirror",
    "concat",
    "concat_enhanced",
    "diagonal_tile",
    "fixed_crop",
    "spatial_gather",
    "varshape_spatial_gather",
    "gravity_unrolled",
    "edge_detect",
    "mode_fill",
    "downsample_stride",
    "symmetry_complete",
    "extract_inner",
    "add_border",
    "sparse_fill",
    "channel_filter",
]

# Two-step solver names offered to the model as "Available composition solvers".
COMPOSITION_PATTERNS = [
    "transform_then_recolor",
    "crop_then_transform",
    "recolor_then_tile",
]
# System prompt sent as the "system" message by both call_kilo and call_api.
# It is an f-string: the two solver lists are interpolated from SOLVER_NAMES and
# COMPOSITION_PATTERNS, and the doubled braces ({{ }}) in the schema example
# render as literal JSON braces.
# NOTE(review): "edge_detect" is listed in SOLVER_NAMES but has no line under
# "Solver descriptions" below — consider adding one so the model knows what it means.
SYSTEM_PROMPT = f"""You are a world-class ARC-AGI pattern classifier. Analyze grid transformations and predict which solver would produce the correct output.
Available single solvers: {', '.join(SOLVER_NAMES)}
Available composition solvers: {', '.join(COMPOSITION_PATTERNS)}
Solver descriptions:
- identity: output = input
- constant: output is a fixed grid regardless of input
- color_map: per-pixel color remapping
- transpose: matrix transpose
- flip: horizontal or vertical flip
- rotate: 90/180/270 rotation
- shift: translate grid by offset
- tile: repeat input to fill output
- upscale: nearest-neighbor pixel-repeat zoom
- kronecker: kron(mask, input) self-similar
- nonuniform_scale: non-integer scale
- mirror_h/v: mirror and tile horizontally/vertically
- quad_mirror: 4-way kaleidoscope
- concat: concatenate transformed copies
- concat_enhanced: concat with color-dependent selection
- diagonal_tile: tile along diagonal
- fixed_crop: crop a rectangular region
- spatial_gather: arbitrary pixel rearrangement
- varshape_spatial_gather: spatial_gather with variable shapes
- gravity_unrolled: directional pixel compaction
- mode_fill: fill grid with most common color
- downsample_stride: subsample at regular stride
- symmetry_complete: complete partial symmetry
- extract_inner: remove outer border/frame
- add_border: add constant-color border
- sparse_fill: expand non-zero pixels into blocks
- channel_filter: keep only certain color channels
- transform_then_recolor: any spatial transform THEN color_map
- crop_then_transform: crop THEN apply spatial transform
- recolor_then_tile: color_map THEN tile/upscale
IMPORTANT: Look at ALL training pairs together.
Output a valid JSON object mapping each task ID to:
{{"TASK_ID": {{"primary_solver": "solver_name", "fallback_solvers": ["solver1", "solver2"], "grid_size_changed": true/false, "confidence": 1-10, "notes": "brief description"}}}}
Output ONLY JSON."""
def format_grid(grid):
    """Render *grid* as text, one row per line, each prefixed with ``R<index>: ``.

    Each row is formatted with its default ``str()`` (e.g. ``[1, 2]``); an empty
    grid yields an empty string.
    """
    rendered_rows = (f"R{row_idx}: {cells}" for row_idx, cells in enumerate(grid))
    return "\n".join(rendered_rows)
def call_kilo(prompt, kilo_url, model="deepseek-ai/deepseek-chat", timeout=120):
    """Send one chat-completion request to the Kilo local server and return the reply text.

    The request goes to ``<kilo_url>/v1/chat/completions`` with SYSTEM_PROMPT as
    the system message and *prompt* as the user message.

    Args:
        prompt: user message text.
        kilo_url: base URL of the server; a trailing ``/`` is tolerated.
        model: model id placed in the request payload.
        timeout: seconds passed to ``requests.post``.

    Returns:
        The ``content`` string of the first returned choice.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        ValueError: if the reply body lacks ``choices[0].message.content``
            (``resp.json()`` itself may also raise on a non-JSON body).
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.3,
    }
    # Strip a trailing slash so "http://host:port/" does not yield "//v1/...".
    url = f"{kilo_url.rstrip('/')}/v1/chat/completions"
    resp = requests.post(url, json=payload, timeout=timeout)
    # Surface HTTP failures directly instead of as a KeyError on the error body.
    resp.raise_for_status()
    data = resp.json()
    try:
        return data['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError) as err:
        raise ValueError(f"unexpected Kilo response: {str(data)[:200]}") from err
def call_api(prompt, api_key, base_url="https://api.deepseek.com", model="deepseek-chat",
             timeout=120):
    """Send one JSON-mode chat-completion request to the DeepSeek API and return the reply text.

    Args:
        prompt: user message text (SYSTEM_PROMPT is sent as the system message).
        api_key: DeepSeek API key. When empty/None, the ``DEEPSEEK_API_KEY``
            environment variable is used; if that is unset too, the OpenAI client
            applies its own default lookup.
        base_url: OpenAI-compatible endpoint.
        model: model id understood by the endpoint (e.g. ``deepseek-chat``).
        timeout: request timeout in seconds, passed to the OpenAI client.

    Returns:
        The ``content`` of the first choice (may be None if the API returns none).
    """
    # Imported lazily so Kilo mode works without the openai package installed.
    from openai import OpenAI
    # The CLI's --api_key defaults to '', which would otherwise be sent verbatim
    # and bypass any environment lookup.
    api_key = api_key or os.environ.get("DEEPSEEK_API_KEY") or None
    client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        response_format={'type': 'json_object'},
        temperature=0.3,
    )
    return response.choices[0].message.content
def _task_id(path):
    """Return the task ID for a task file path (its basename without the extension)."""
    return os.path.splitext(os.path.basename(path))[0]


def _api_model_name(model):
    """Strip a HuggingFace-style ``deepseek-ai/`` prefix, which the DeepSeek API does not accept."""
    prefix = "deepseek-ai/"
    return model[len(prefix):] if model.startswith(prefix) else model


def _build_batch_prompt(batch_files):
    """Build one user prompt listing every train pair and test input of each task file."""
    parts = ["Classify these ARC tasks:\n"]
    for path in batch_files:
        with open(path, encoding="utf-8") as fh:
            task = json.load(fh)
        parts.append(f"\n### TASK: {_task_id(path)}\n")
        for idx, pair in enumerate(task.get('train', [])):
            parts.append(
                f"--- Example {idx} ---\nIN:\n{format_grid(pair['input'])}"
                f"\nOUT:\n{format_grid(pair['output'])}\n"
            )
        for idx, pair in enumerate(task.get('test', [])):
            parts.append(f"--- Test Input {idx} ---\nIN:\n{format_grid(pair['input'])}\n")
    return "".join(parts)


def _parse_json_response(content):
    """Decode the model reply into a dict, tolerating ```json fences or surrounding prose.

    Raises:
        ValueError: (incl. json.JSONDecodeError) if no JSON object can be decoded.
    """
    if not content:
        raise ValueError("empty response from model")
    text = content.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Without JSON mode, models often wrap output in ```json ... ``` or add
        # commentary; fall back to the outermost {...} span.
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def _save_json(obj, path):
    """Write *obj* as indented JSON via a temp file, so an interrupted write can't truncate *path*."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding="utf-8") as fh:
        json.dump(obj, fh, indent=2)
    os.replace(tmp_path, path)


def _build_routing(classifications):
    """Turn raw classifications into ``{task_id: {solvers, confidence, grid_changed, notes}}``.

    Entries that are not dicts are skipped; the solver list is primary first, then
    fallbacks, with empty and duplicate names removed.
    """
    routing = {}
    for tid, entry in classifications.items():
        if not isinstance(entry, dict):
            continue
        fallbacks = entry.get('fallback_solvers') or []
        if isinstance(fallbacks, str):
            fallbacks = [fallbacks]
        solvers = []
        for name in [entry.get('primary_solver') or '', *fallbacks]:
            if isinstance(name, str) and name and name not in solvers:
                solvers.append(name)
        routing[tid] = {
            'solvers': solvers,
            'confidence': entry.get('confidence', 5),
            'grid_changed': entry.get('grid_size_changed', False),
            'notes': entry.get('notes', ''),
        }
    return routing


def classify_tasks(data_dir, output_file, mode="kilo", kilo_url="http://127.0.0.1:8765",
                   api_key=None, model="deepseek-ai/deepseek-chat", batch_size=5,
                   max_attempts=3, retry_delay=3):
    """Classify every ARC task file in *data_dir* with an LLM and write a routing table.

    Raw classifications are checkpointed to *output_file* after each successful
    batch. A rerun resumes from that file and skips batches whose task IDs are all
    already present.

    Args:
        data_dir: directory holding ``task*.json`` files (falls back to ``*.json``).
        output_file: JSON file for raw classifications (created or resumed).
        mode: ``"kilo"`` uses the local Kilo server; any other value uses the DeepSeek API.
        kilo_url: base URL of the Kilo server.
        api_key: DeepSeek API key (API mode only).
        model: model id; in API mode a leading ``deepseek-ai/`` is stripped.
        batch_size: tasks per LLM request; must be >= 1.
        max_attempts: tries per batch before the batch is skipped.
        retry_delay: seconds to sleep between tries.

    Returns:
        dict mapping task ID -> ``{"solvers", "confidence", "grid_changed", "notes"}``,
        also written next to *output_file* with a ``_routing`` suffix.

    Raises:
        ValueError: if *batch_size* < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_files = sorted(glob.glob(os.path.join(data_dir, "task*.json")))
    if not all_files:
        all_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    print(f"Found {len(all_files)} task files. Mode: {mode}")

    classifications = {}
    if os.path.exists(output_file):
        with open(output_file, encoding="utf-8") as fh:
            classifications = json.load(fh)
        print(f"Resuming: {len(classifications)} already classified")

    for start in range(0, len(all_files), batch_size):
        batch_files = all_files[start:start + batch_size]
        batch_ids = [_task_id(path) for path in batch_files]
        if all(bid in classifications for bid in batch_ids):
            continue
        prompt = _build_batch_prompt(batch_files)
        label = f"[{start + 1}-{start + len(batch_files)}]"
        for attempt in range(1, max_attempts + 1):
            try:
                if mode == "kilo":
                    content = call_kilo(prompt, kilo_url, model)
                else:
                    content = call_api(prompt, api_key, model=_api_model_name(model))
                batch_results = _parse_json_response(content)
            except Exception as e:  # network, HTTP and malformed-JSON errors are all retried
                print(f"  Retry {attempt}: {e}")
                if attempt < max_attempts:
                    time.sleep(retry_delay)
                continue
            classifications.update(batch_results)
            _save_json(classifications, output_file)
            print(f"  {label} OK: {list(batch_results.keys())}")
            missing = [bid for bid in batch_ids if bid not in batch_results]
            if missing:
                print(f"  {label} model omitted: {missing}")
            break
        else:
            print(f"  {label} FAILED after {max_attempts} attempts; skipping batch")

    routing = _build_routing(classifications)
    # splitext (not str.replace) so the routing file can never coincide with
    # output_file, e.g. when output_file has no ".json" suffix.
    root, ext = os.path.splitext(output_file)
    routing_file = f"{root}_routing{ext or '.json'}"
    _save_json(routing, routing_file)
    print(f"\nDone. {len(classifications)} classified. Routing: {routing_file}")
    return routing
if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Classify ARC tasks into solver routes with an LLM.")
    p.add_argument('--mode', default='kilo', choices=['kilo', 'api'])
    p.add_argument('--data_dir', default='/kaggle/input/competitions/neurogolf-2026/')
    p.add_argument('--output_file', default='arc_task_routes.json')
    p.add_argument('--kilo_url', default='http://127.0.0.1:8765')
    # Default from the environment so the key need not be typed on the command line.
    p.add_argument('--api_key', default=os.environ.get('DEEPSEEK_API_KEY', ''),
                   help="DeepSeek API key (api mode); defaults to $DEEPSEEK_API_KEY")
    p.add_argument('--model', default=None,
                   help="model id; defaults to deepseek-ai/deepseek-chat (kilo) "
                        "or deepseek-chat (api)")
    p.add_argument('--batch_size', type=int, default=5)
    args = p.parse_args()
    if args.model is None:
        # call_api's default model id is "deepseek-chat"; the prefixed
        # HuggingFace-style id is only used for the Kilo server.
        args.model = 'deepseek-ai/deepseek-chat' if args.mode == 'kilo' else 'deepseek-chat'
    if args.mode == 'api' and not args.api_key:
        p.error("--mode api needs --api_key or the DEEPSEEK_API_KEY environment variable")
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    classify_tasks(args.data_dir, args.output_file, args.mode,
                   args.kilo_url, args.api_key, args.model, args.batch_size)