rogermt commited on
Commit
ccffc8f
·
verified ·
1 Parent(s): 826a998

Update classifier to support Kilo local server

Browse files
Files changed (1) hide show
  1. trm_solver/classify_tasks.py +68 -68
trm_solver/classify_tasks.py CHANGED
@@ -1,17 +1,17 @@
1
  #!/usr/bin/env python3
2
  """
3
- ARC-AGI Task Classifier Routes tasks to NeuroGolf solvers via DeepSeek API.
4
- Output: JSON mapping task_id -> ordered solver list to try first.
5
- The LLM call is OFFLINE (model generation time only). Zero ONNX cost.
6
 
7
- Usage on Kaggle:
8
- python classify_tasks.py
 
9
 
10
- Usage locally:
11
- python classify_tasks.py --data_dir ARC-AGI/data/training/
12
  """
13
 
14
- import json, os, glob, time, argparse
15
 
16
  SOLVER_NAMES = [
17
  "identity", "constant", "color_map", "transpose", "flip", "rotate",
@@ -31,11 +31,9 @@ COMPOSITION_PATTERNS = [
31
 
32
  SYSTEM_PROMPT = f"""You are a world-class ARC-AGI pattern classifier. Analyze grid transformations and predict which solver would produce the correct output.
33
 
34
- Available single solvers:
35
- {', '.join(SOLVER_NAMES)}
36
 
37
- Available composition solvers (two transforms chained):
38
- {', '.join(COMPOSITION_PATTERNS)}
39
 
40
  Solver descriptions:
41
  - identity: output = input
@@ -69,49 +67,55 @@ Solver descriptions:
69
  - crop_then_transform: crop THEN apply spatial transform
70
  - recolor_then_tile: color_map THEN tile/upscale
71
 
72
- IMPORTANT: Look at ALL training pairs together. The pattern must be consistent across all pairs.
73
 
74
  Output a valid JSON object mapping each task ID to:
75
- {{
76
- "TASK_ID": {{
77
- "primary_solver": "solver_name",
78
- "fallback_solvers": ["solver1", "solver2"],
79
- "grid_size_changed": true/false,
80
- "confidence": 1-10,
81
- "notes": "brief pattern description"
82
- }}
83
- }}
84
 
85
- Output ONLY JSON. No other text."""
86
 
87
 
88
  def format_grid(grid):
89
  return "\n".join([f"R{i}: {row}" for i, row in enumerate(grid)])
90
 
91
 
92
- def classify_tasks(data_dir, output_file, api_key=None, base_url=None,
93
- model="deepseek-chat", batch_size=5):
94
- """Classify all ARC tasks using DeepSeek API."""
95
-
96
- if api_key:
97
- from openai import OpenAI
98
- client = OpenAI(api_key=api_key, base_url=base_url or "https://api.deepseek.com")
99
- else:
100
- try:
101
- from kaggle_secrets import UserSecretsClient
102
- from openai import OpenAI
103
- user_secrets = UserSecretsClient()
104
- client = OpenAI(
105
- api_key=user_secrets.get_secret("Deepseek_api_key"),
106
- base_url="https://api.deepseek.com"
107
- )
108
- except ImportError:
109
- raise RuntimeError("No API key provided and not on Kaggle.")
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  all_files = sorted(glob.glob(os.path.join(data_dir, "task*.json")))
112
  if not all_files:
113
  all_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
114
- print(f"Found {len(all_files)} task files")
115
 
116
  classifications = {}
117
  if os.path.exists(output_file):
@@ -138,24 +142,23 @@ def classify_tasks(data_dir, output_file, api_key=None, base_url=None,
138
 
139
  for attempt in range(3):
140
  try:
141
- response = client.chat.completions.create(
142
- model=model,
143
- messages=[
144
- {"role": "system", "content": SYSTEM_PROMPT},
145
- {"role": "user", "content": prompt}
146
- ],
147
- response_format={'type': 'json_object'}
148
- )
149
- batch_results = json.loads(response.choices[0].message.content)
150
  classifications.update(batch_results)
151
  with open(output_file, 'w') as f:
152
  json.dump(classifications, f, indent=2)
153
- print(f" [{i+1}-{i+len(batch_files)}] Classified: {list(batch_results.keys())}")
154
  break
155
  except Exception as e:
156
  print(f" Retry {attempt+1}: {e}")
157
  time.sleep(3)
158
 
 
159
  routing = {}
160
  for tid, data in classifications.items():
161
  primary = data.get('primary_solver', '')
@@ -167,25 +170,22 @@ def classify_tasks(data_dir, output_file, api_key=None, base_url=None,
167
  'grid_changed': data.get('grid_size_changed', False),
168
  'notes': data.get('notes', '')
169
  }
170
-
171
  routing_file = output_file.replace('.json', '_routing.json')
172
  with open(routing_file, 'w') as f:
173
  json.dump(routing, f, indent=2)
174
-
175
- print(f"\nDone. {len(classifications)} tasks classified.")
176
- print(f"Classifications: {output_file}")
177
- print(f"Routing table: {routing_file}")
178
  return routing
179
 
180
 
181
  if __name__ == "__main__":
182
- parser = argparse.ArgumentParser()
183
- parser.add_argument('--data_dir', default='/kaggle/input/competitions/neurogolf-2026/')
184
- parser.add_argument('--output_file', default='/kaggle/working/arc_task_routes.json')
185
- parser.add_argument('--api_key', default='')
186
- parser.add_argument('--base_url', default='')
187
- parser.add_argument('--model', default='deepseek-chat')
188
- parser.add_argument('--batch_size', type=int, default=5)
189
- args = parser.parse_args()
190
- classify_tasks(args.data_dir, args.output_file, args.api_key,
191
- args.base_url, args.model, args.batch_size)
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ ARC-AGI Task Classifier via Kilo Code server (local DeepSeek, free tier).
4
+ Also supports DeepSeek API as fallback.
 
5
 
6
+ Kilo server mode (preferred, faster on free tier):
7
+ 1. Start Kilo server: kilo serve --port 8765
8
+ 2. Run: python classify_tasks.py --mode kilo --kilo_url http://127.0.0.1:8765
9
 
10
+ API mode (fallback):
11
+ python classify_tasks.py --mode api --data_dir /kaggle/input/competitions/neurogolf-2026/
12
  """
13
 
14
+ import json, os, glob, time, argparse, requests
15
 
16
  SOLVER_NAMES = [
17
  "identity", "constant", "color_map", "transpose", "flip", "rotate",
 
31
 
32
  SYSTEM_PROMPT = f"""You are a world-class ARC-AGI pattern classifier. Analyze grid transformations and predict which solver would produce the correct output.
33
 
34
+ Available single solvers: {', '.join(SOLVER_NAMES)}
 
35
 
36
+ Available composition solvers: {', '.join(COMPOSITION_PATTERNS)}
 
37
 
38
  Solver descriptions:
39
  - identity: output = input
 
67
  - crop_then_transform: crop THEN apply spatial transform
68
  - recolor_then_tile: color_map THEN tile/upscale
69
 
70
+ IMPORTANT: Look at ALL training pairs together.
71
 
72
  Output a valid JSON object mapping each task ID to:
73
+ {{"TASK_ID": {{"primary_solver": "solver_name", "fallback_solvers": ["solver1", "solver2"], "grid_size_changed": true/false, "confidence": 1-10, "notes": "brief description"}}}}
 
 
 
 
 
 
 
 
74
 
75
+ Output ONLY JSON."""
76
 
77
 
78
  def format_grid(grid):
79
  return "\n".join([f"R{i}: {row}" for i, row in enumerate(grid)])
80
 
81
 
82
+ def call_kilo(prompt, kilo_url, model="deepseek-ai/deepseek-chat", timeout=120):
83
+ """Call Kilo local server (OpenAI-compatible API)."""
84
+ payload = {
85
+ "model": model,
86
+ "messages": [
87
+ {"role": "system", "content": SYSTEM_PROMPT},
88
+ {"role": "user", "content": prompt}
89
+ ],
90
+ "temperature": 0.3,
91
+ }
92
+ resp = requests.post(f"{kilo_url}/v1/chat/completions", json=payload, timeout=timeout)
93
+ data = resp.json()
94
+ return data['choices'][0]['message']['content']
95
+
96
+
97
+ def call_api(prompt, api_key, base_url="https://api.deepseek.com", model="deepseek-chat"):
98
+ """Call DeepSeek API (fallback, slower)."""
99
+ from openai import OpenAI
100
+ client = OpenAI(api_key=api_key, base_url=base_url)
101
+ response = client.chat.completions.create(
102
+ model=model,
103
+ messages=[
104
+ {"role": "system", "content": SYSTEM_PROMPT},
105
+ {"role": "user", "content": prompt}
106
+ ],
107
+ response_format={'type': 'json_object'},
108
+ temperature=0.3,
109
+ )
110
+ return response.choices[0].message.content
111
+
112
+
113
+ def classify_tasks(data_dir, output_file, mode="kilo", kilo_url="http://127.0.0.1:8765",
114
+ api_key=None, model="deepseek-ai/deepseek-chat", batch_size=5):
115
  all_files = sorted(glob.glob(os.path.join(data_dir, "task*.json")))
116
  if not all_files:
117
  all_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
118
+ print(f"Found {len(all_files)} task files. Mode: {mode}")
119
 
120
  classifications = {}
121
  if os.path.exists(output_file):
 
142
 
143
  for attempt in range(3):
144
  try:
145
+ if mode == "kilo":
146
+ content = call_kilo(prompt, kilo_url, model)
147
+ else:
148
+ content = call_api(prompt, api_key, model=model)
149
+
150
+ # Parse JSON from response
151
+ batch_results = json.loads(content)
 
 
152
  classifications.update(batch_results)
153
  with open(output_file, 'w') as f:
154
  json.dump(classifications, f, indent=2)
155
+ print(f" [{i+1}-{i+len(batch_files)}] OK: {list(batch_results.keys())}")
156
  break
157
  except Exception as e:
158
  print(f" Retry {attempt+1}: {e}")
159
  time.sleep(3)
160
 
161
+ # Generate routing table
162
  routing = {}
163
  for tid, data in classifications.items():
164
  primary = data.get('primary_solver', '')
 
170
  'grid_changed': data.get('grid_size_changed', False),
171
  'notes': data.get('notes', '')
172
  }
 
173
  routing_file = output_file.replace('.json', '_routing.json')
174
  with open(routing_file, 'w') as f:
175
  json.dump(routing, f, indent=2)
176
+ print(f"\nDone. {len(classifications)} classified. Routing: {routing_file}")
 
 
 
177
  return routing
178
 
179
 
180
  if __name__ == "__main__":
181
+ p = argparse.ArgumentParser()
182
+ p.add_argument('--mode', default='kilo', choices=['kilo', 'api'])
183
+ p.add_argument('--data_dir', default='/kaggle/input/competitions/neurogolf-2026/')
184
+ p.add_argument('--output_file', default='arc_task_routes.json')
185
+ p.add_argument('--kilo_url', default='http://127.0.0.1:8765')
186
+ p.add_argument('--api_key', default='')
187
+ p.add_argument('--model', default='deepseek-ai/deepseek-chat')
188
+ p.add_argument('--batch_size', type=int, default=5)
189
+ args = p.parse_args()
190
+ classify_tasks(args.data_dir, args.output_file, args.mode,
191
+ args.kilo_url, args.api_key, args.model, args.batch_size)