rogermt commited on
Commit
deda756
·
verified ·
1 Parent(s): e54021c

Add DeepSeek task classifier for LLM-guided solver routing

Browse files
own-solver/neurogolf_solver/classify_tasks.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ARC-AGI Task Classifier — Routes tasks to NeuroGolf solvers via DeepSeek API.
4
+ Output: JSON mapping task_id -> ordered solver list to try first.
5
+ The LLM call is OFFLINE (model generation time only). Zero ONNX cost.
6
+
7
+ Usage on Kaggle:
8
+ python -m neurogolf_solver.classify_tasks
9
+
10
+ Usage locally:
11
+ python -m neurogolf_solver.classify_tasks --data_dir ARC-AGI/data/training/
12
+ """
13
+
14
+ import json, os, glob, time, argparse
15
+
16
+ # --- Solver names matching solver_registry.py ---
17
+ SOLVER_NAMES = [
18
+ "identity", "constant", "color_map", "transpose", "flip", "rotate",
19
+ "shift", "tile", "upscale", "kronecker", "nonuniform_scale",
20
+ "mirror_h", "mirror_v", "quad_mirror", "concat", "concat_enhanced",
21
+ "diagonal_tile", "fixed_crop", "spatial_gather",
22
+ "varshape_spatial_gather", "gravity_unrolled", "edge_detect",
23
+ "mode_fill", "downsample_stride", "symmetry_complete",
24
+ "extract_inner", "add_border", "sparse_fill", "channel_filter",
25
+ ]
26
+
27
+ COMPOSITION_PATTERNS = [
28
+ "transform_then_recolor",
29
+ "crop_then_transform",
30
+ "recolor_then_tile",
31
+ ]
32
+
33
+ SYSTEM_PROMPT = f"""You are a world-class ARC-AGI pattern classifier. Analyze grid transformations and predict which solver would produce the correct output.
34
+
35
+ Available single solvers:
36
+ {', '.join(SOLVER_NAMES)}
37
+
38
+ Available composition solvers (two transforms chained):
39
+ {', '.join(COMPOSITION_PATTERNS)}
40
+
41
+ Solver descriptions:
42
+ - identity: output = input
43
+ - constant: output is a fixed grid regardless of input
44
+ - color_map: per-pixel color remapping
45
+ - transpose: matrix transpose
46
+ - flip: horizontal or vertical flip
47
+ - rotate: 90/180/270 rotation
48
+ - shift: translate grid by offset
49
+ - tile: repeat input to fill output
50
+ - upscale: nearest-neighbor pixel-repeat zoom
51
+ - kronecker: kron(mask, input) self-similar
52
+ - nonuniform_scale: non-integer scale
53
+ - mirror_h/v: mirror and tile horizontally/vertically
54
+ - quad_mirror: 4-way kaleidoscope
55
+ - concat: concatenate transformed copies
56
+ - concat_enhanced: concat with color-dependent selection
57
+ - diagonal_tile: tile along diagonal
58
+ - fixed_crop: crop a rectangular region
59
+ - spatial_gather: arbitrary pixel rearrangement
60
+ - varshape_spatial_gather: spatial_gather with variable shapes
61
+ - gravity_unrolled: directional pixel compaction
62
+ - mode_fill: fill grid with most common color
63
+ - downsample_stride: subsample at regular stride
64
+ - symmetry_complete: complete partial symmetry
65
+ - extract_inner: remove outer border/frame
66
+ - add_border: add constant-color border
67
+ - sparse_fill: expand non-zero pixels into blocks
68
+ - channel_filter: keep only certain color channels
69
+ - transform_then_recolor: any spatial transform THEN color_map
70
+ - crop_then_transform: crop THEN apply spatial transform
71
+ - recolor_then_tile: color_map THEN tile/upscale
72
+
73
+ IMPORTANT: Look at ALL training pairs together. The pattern must be consistent across all pairs.
74
+
75
+ Output a valid JSON object mapping each task ID to:
76
+ {{
77
+ "TASK_ID": {{
78
+ "primary_solver": "solver_name",
79
+ "fallback_solvers": ["solver1", "solver2"],
80
+ "grid_size_changed": true/false,
81
+ "confidence": 1-10,
82
+ "notes": "brief pattern description"
83
+ }}
84
+ }}
85
+
86
+ Output ONLY JSON. No other text."""
87
+
88
+
89
+ def format_grid(grid):
90
+ return "\n".join([f"R{i}: {row}" for i, row in enumerate(grid)])
91
+
92
+
93
+ def classify_tasks(data_dir, output_file, api_key=None, base_url=None,
94
+ model="deepseek-chat", batch_size=5):
95
+ """Classify all ARC tasks using DeepSeek API."""
96
+
97
+ # --- API Setup ---
98
+ if api_key:
99
+ from openai import OpenAI
100
+ client = OpenAI(api_key=api_key, base_url=base_url or "https://api.deepseek.com")
101
+ else:
102
+ try:
103
+ from kaggle_secrets import UserSecretsClient
104
+ from openai import OpenAI
105
+ user_secrets = UserSecretsClient()
106
+ client = OpenAI(
107
+ api_key=user_secrets.get_secret("Deepseek_api_key"),
108
+ base_url="https://api.deepseek.com"
109
+ )
110
+ except ImportError:
111
+ raise RuntimeError("No API key provided and not on Kaggle.")
112
+
113
+ # --- Load tasks ---
114
+ all_files = sorted(glob.glob(os.path.join(data_dir, "task*.json")))
115
+ if not all_files:
116
+ all_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
117
+ print(f"Found {len(all_files)} task files")
118
+
119
+ classifications = {}
120
+
121
+ # Resume from previous run
122
+ if os.path.exists(output_file):
123
+ with open(output_file) as f:
124
+ classifications = json.load(f)
125
+ print(f"Resuming: {len(classifications)} already classified")
126
+
127
+ # --- Process in batches ---
128
+ for i in range(0, len(all_files), batch_size):
129
+ batch_files = all_files[i : i + batch_size]
130
+ batch_ids = [os.path.basename(f).replace('.json','') for f in batch_files]
131
+ if all(bid in classifications for bid in batch_ids):
132
+ continue
133
+
134
+ prompt = "Classify these ARC tasks:\n"
135
+ for f in batch_files:
136
+ tid = os.path.basename(f).replace('.json','')
137
+ with open(f) as fh:
138
+ task = json.load(fh)
139
+ prompt += f"\n### TASK: {tid}\n"
140
+ for idx, pair in enumerate(task.get('train', [])):
141
+ prompt += f"--- Example {idx} ---\nIN:\n{format_grid(pair['input'])}\nOUT:\n{format_grid(pair['output'])}\n"
142
+ for idx, pair in enumerate(task.get('test', [])):
143
+ prompt += f"--- Test Input {idx} ---\nIN:\n{format_grid(pair['input'])}\n"
144
+
145
+ for attempt in range(3):
146
+ try:
147
+ response = client.chat.completions.create(
148
+ model=model,
149
+ messages=[
150
+ {"role": "system", "content": SYSTEM_PROMPT},
151
+ {"role": "user", "content": prompt}
152
+ ],
153
+ response_format={'type': 'json_object'}
154
+ )
155
+ batch_results = json.loads(response.choices[0].message.content)
156
+ classifications.update(batch_results)
157
+ with open(output_file, 'w') as f:
158
+ json.dump(classifications, f, indent=2)
159
+ print(f" [{i+1}-{i+len(batch_files)}] Classified: {list(batch_results.keys())}")
160
+ break
161
+ except Exception as e:
162
+ print(f" Retry {attempt+1}: {e}")
163
+ time.sleep(3)
164
+
165
+ # --- Generate routing table ---
166
+ routing = {}
167
+ for tid, data in classifications.items():
168
+ primary = data.get('primary_solver', '')
169
+ fallbacks = data.get('fallback_solvers', [])
170
+ solvers = [primary] + [s for s in fallbacks if s != primary]
171
+ routing[tid] = {
172
+ 'solvers': solvers,
173
+ 'confidence': data.get('confidence', 5),
174
+ 'grid_changed': data.get('grid_size_changed', False),
175
+ 'notes': data.get('notes', '')
176
+ }
177
+
178
+ routing_file = output_file.replace('.json', '_routing.json')
179
+ with open(routing_file, 'w') as f:
180
+ json.dump(routing, f, indent=2)
181
+
182
+ print(f"\nDone. {len(classifications)} tasks classified.")
183
+ print(f"Classifications: {output_file}")
184
+ print(f"Routing table: {routing_file}")
185
+ return routing
186
+
187
+
188
+ if __name__ == "__main__":
189
+ parser = argparse.ArgumentParser()
190
+ parser.add_argument('--data_dir', default='/kaggle/input/competitions/neurogolf-2026/')
191
+ parser.add_argument('--output_file', default='/kaggle/working/arc_task_routes.json')
192
+ parser.add_argument('--api_key', default='')
193
+ parser.add_argument('--base_url', default='')
194
+ parser.add_argument('--model', default='deepseek-chat')
195
+ parser.add_argument('--batch_size', type=int, default=5)
196
+ args = parser.parse_args()
197
+ classify_tasks(args.data_dir, args.output_file, args.api_key,
198
+ args.base_url, args.model, args.batch_size)