"""
ARC-AGI Task Classifier — Routes tasks to NeuroGolf solvers via DeepSeek API.
Output: JSON mapping task_id -> ordered solver list to try first.
The LLM call is OFFLINE (model generation time only). Zero ONNX cost.

Usage on Kaggle:
    python -m neurogolf_solver.classify_tasks

Usage locally:
    python -m neurogolf_solver.classify_tasks --data_dir ARC-AGI/data/training/
"""
|
|
import argparse
import glob
import json
import os
import time
|
|
| |
| SOLVER_NAMES = [ |
| "identity", "constant", "color_map", "transpose", "flip", "rotate", |
| "shift", "tile", "upscale", "kronecker", "nonuniform_scale", |
| "mirror_h", "mirror_v", "quad_mirror", "concat", "concat_enhanced", |
| "diagonal_tile", "fixed_crop", "spatial_gather", |
| "varshape_spatial_gather", "gravity_unrolled", "edge_detect", |
| "mode_fill", "downsample_stride", "symmetry_complete", |
| "extract_inner", "add_border", "sparse_fill", "channel_filter", |
| ] |
|
|
| COMPOSITION_PATTERNS = [ |
| "transform_then_recolor", |
| "crop_then_transform", |
| "recolor_then_tile", |
| ] |
|
|
| SYSTEM_PROMPT = f"""You are a world-class ARC-AGI pattern classifier. Analyze grid transformations and predict which solver would produce the correct output. |
| |
| Available single solvers: |
| {', '.join(SOLVER_NAMES)} |
| |
| Available composition solvers (two transforms chained): |
| {', '.join(COMPOSITION_PATTERNS)} |
| |
| Solver descriptions: |
| - identity: output = input |
| - constant: output is a fixed grid regardless of input |
| - color_map: per-pixel color remapping |
| - transpose: matrix transpose |
| - flip: horizontal or vertical flip |
| - rotate: 90/180/270 rotation |
| - shift: translate grid by offset |
| - tile: repeat input to fill output |
| - upscale: nearest-neighbor pixel-repeat zoom |
| - kronecker: kron(mask, input) self-similar |
| - nonuniform_scale: non-integer scale |
| - mirror_h/v: mirror and tile horizontally/vertically |
| - quad_mirror: 4-way kaleidoscope |
| - concat: concatenate transformed copies |
| - concat_enhanced: concat with color-dependent selection |
| - diagonal_tile: tile along diagonal |
| - fixed_crop: crop a rectangular region |
| - spatial_gather: arbitrary pixel rearrangement |
| - varshape_spatial_gather: spatial_gather with variable shapes |
| - gravity_unrolled: directional pixel compaction |
| - mode_fill: fill grid with most common color |
| - downsample_stride: subsample at regular stride |
| - symmetry_complete: complete partial symmetry |
| - extract_inner: remove outer border/frame |
| - add_border: add constant-color border |
| - sparse_fill: expand non-zero pixels into blocks |
| - channel_filter: keep only certain color channels |
| - transform_then_recolor: any spatial transform THEN color_map |
| - crop_then_transform: crop THEN apply spatial transform |
| - recolor_then_tile: color_map THEN tile/upscale |
| |
| IMPORTANT: Look at ALL training pairs together. The pattern must be consistent across all pairs. |
| |
| Output a valid JSON object mapping each task ID to: |
| {{ |
| "TASK_ID": {{ |
| "primary_solver": "solver_name", |
| "fallback_solvers": ["solver1", "solver2"], |
| "grid_size_changed": true/false, |
| "confidence": 1-10, |
| "notes": "brief pattern description" |
| }} |
| }} |
| |
| Output ONLY JSON. No other text.""" |
|
|
|
|
| def format_grid(grid): |
| return "\n".join([f"R{i}: {row}" for i, row in enumerate(grid)]) |
|
|
|
|
| def classify_tasks(data_dir, output_file, api_key=None, base_url=None, |
| model="deepseek-chat", batch_size=5): |
| """Classify all ARC tasks using DeepSeek API.""" |
|
|
| |
| if api_key: |
| from openai import OpenAI |
| client = OpenAI(api_key=api_key, base_url=base_url or "https://api.deepseek.com") |
| else: |
| try: |
| from kaggle_secrets import UserSecretsClient |
| from openai import OpenAI |
| user_secrets = UserSecretsClient() |
| client = OpenAI( |
| api_key=user_secrets.get_secret("Deepseek_api_key"), |
| base_url="https://api.deepseek.com" |
| ) |
| except ImportError: |
| raise RuntimeError("No API key provided and not on Kaggle.") |
|
|
| |
| all_files = sorted(glob.glob(os.path.join(data_dir, "task*.json"))) |
| if not all_files: |
| all_files = sorted(glob.glob(os.path.join(data_dir, "*.json"))) |
| print(f"Found {len(all_files)} task files") |
|
|
| classifications = {} |
|
|
| |
| if os.path.exists(output_file): |
| with open(output_file) as f: |
| classifications = json.load(f) |
| print(f"Resuming: {len(classifications)} already classified") |
|
|
| |
| for i in range(0, len(all_files), batch_size): |
| batch_files = all_files[i : i + batch_size] |
| batch_ids = [os.path.basename(f).replace('.json','') for f in batch_files] |
| if all(bid in classifications for bid in batch_ids): |
| continue |
|
|
| prompt = "Classify these ARC tasks:\n" |
| for f in batch_files: |
| tid = os.path.basename(f).replace('.json','') |
| with open(f) as fh: |
| task = json.load(fh) |
| prompt += f"\n### TASK: {tid}\n" |
| for idx, pair in enumerate(task.get('train', [])): |
| prompt += f"--- Example {idx} ---\nIN:\n{format_grid(pair['input'])}\nOUT:\n{format_grid(pair['output'])}\n" |
| for idx, pair in enumerate(task.get('test', [])): |
| prompt += f"--- Test Input {idx} ---\nIN:\n{format_grid(pair['input'])}\n" |
|
|
| for attempt in range(3): |
| try: |
| response = client.chat.completions.create( |
| model=model, |
| messages=[ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": prompt} |
| ], |
| response_format={'type': 'json_object'} |
| ) |
| batch_results = json.loads(response.choices[0].message.content) |
| classifications.update(batch_results) |
| with open(output_file, 'w') as f: |
| json.dump(classifications, f, indent=2) |
| print(f" [{i+1}-{i+len(batch_files)}] Classified: {list(batch_results.keys())}") |
| break |
| except Exception as e: |
| print(f" Retry {attempt+1}: {e}") |
| time.sleep(3) |
|
|
| |
| routing = {} |
| for tid, data in classifications.items(): |
| primary = data.get('primary_solver', '') |
| fallbacks = data.get('fallback_solvers', []) |
| solvers = [primary] + [s for s in fallbacks if s != primary] |
| routing[tid] = { |
| 'solvers': solvers, |
| 'confidence': data.get('confidence', 5), |
| 'grid_changed': data.get('grid_size_changed', False), |
| 'notes': data.get('notes', '') |
| } |
|
|
| routing_file = output_file.replace('.json', '_routing.json') |
| with open(routing_file, 'w') as f: |
| json.dump(routing, f, indent=2) |
|
|
| print(f"\nDone. {len(classifications)} tasks classified.") |
| print(f"Classifications: {output_file}") |
| print(f"Routing table: {routing_file}") |
| return routing |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--data_dir', default='/kaggle/input/competitions/neurogolf-2026/') |
| parser.add_argument('--output_file', default='/kaggle/working/arc_task_routes.json') |
| parser.add_argument('--api_key', default='') |
| parser.add_argument('--base_url', default='') |
| parser.add_argument('--model', default='deepseek-chat') |
| parser.add_argument('--batch_size', type=int, default=5) |
| args = parser.parse_args() |
| classify_tasks(args.data_dir, args.output_file, args.api_key, |
| args.base_url, args.model, args.batch_size) |
|
|