rogermt commited on
Commit
b9f62c4
·
verified ·
1 Parent(s): a57fa41

Add scripts/kaggle_llm_solver.py

Browse files
Files changed (1) hide show
  1. scripts/kaggle_llm_solver.py +452 -0
scripts/kaggle_llm_solver.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PEMF ARC-AGI — LLM Program Synthesis via Ollama (Kaggle Edition)
3
+ ================================================================
4
+
5
+ Self-contained script for Kaggle GPU notebooks.
6
+ Pulls a model via Ollama, runs LLM synthesis on unsolved ARC tasks.
7
+
8
+ Usage on Kaggle:
9
+ 1. Enable GPU (T4 x2 or P100)
10
+ 2. Enable internet access
11
+ 3. Upload this file + arc_data/ + already_solved.json
12
+ 4. Run all cells
13
+
14
+ The script:
15
+ - Installs Ollama
16
+ - Pulls the model (qwen2.5-coder:32b or smaller)
17
+ - Loads ARC tasks
18
+ - For each unsolved task: generates Python transform(), verifies against training pairs
19
+ - Saves results to llm_results.json
20
+ """
21
+
22
+ import subprocess
23
+ import sys
24
+ import os
25
+ import json
26
+ import time
27
+ import re
28
+ import signal
29
+ import numpy as np
30
+ from typing import Dict, List, Optional, Tuple
31
+ from collections import Counter
32
+ from pathlib import Path
33
+
34
+
35
+ # =============================================================================
36
+ # 1. OLLAMA SETUP
37
+ # =============================================================================
38
+
39
+ def install_ollama():
40
+ """Install Ollama on Kaggle/Linux."""
41
+ print("Installing Ollama...")
42
+ subprocess.run("curl -fsSL https://ollama.com/install.sh | sh",
43
+ shell=True, check=True, capture_output=True)
44
+ print("Ollama installed.")
45
+
46
+
47
+ def start_ollama():
48
+ """Start Ollama server in background."""
49
+ print("Starting Ollama server...")
50
+ proc = subprocess.Popen(
51
+ ["ollama", "serve"],
52
+ stdout=subprocess.DEVNULL,
53
+ stderr=subprocess.DEVNULL,
54
+ )
55
+ time.sleep(3) # Wait for server to start
56
+ print(f"Ollama server started (PID {proc.pid})")
57
+ return proc
58
+
59
+
60
+ def pull_model(model_name: str):
61
+ """Pull a model via Ollama."""
62
+ print(f"Pulling model {model_name}... (this may take several minutes)")
63
+ result = subprocess.run(
64
+ ["ollama", "pull", model_name],
65
+ capture_output=True, text=True, timeout=1800
66
+ )
67
+ if result.returncode != 0:
68
+ print(f"Pull failed: {result.stderr}")
69
+ raise RuntimeError(f"Failed to pull {model_name}")
70
+ print(f"Model {model_name} ready.")
71
+
72
+
73
+ def call_ollama(prompt: str, model: str = "qwen2.5-coder:32b",
74
+ temperature: float = 0.7, timeout_s: int = 120) -> str:
75
+ """Call Ollama API and return response text."""
76
+ import urllib.request
77
+
78
+ payload = {
79
+ "model": model,
80
+ "prompt": prompt,
81
+ "stream": False,
82
+ "options": {
83
+ "temperature": temperature,
84
+ "num_predict": 2048,
85
+ }
86
+ }
87
+
88
+ data = json.dumps(payload).encode('utf-8')
89
+ req = urllib.request.Request(
90
+ "http://localhost:11434/api/generate",
91
+ data=data,
92
+ headers={"Content-Type": "application/json"},
93
+ method='POST'
94
+ )
95
+
96
+ try:
97
+ with urllib.request.urlopen(req, timeout=timeout_s) as resp:
98
+ result = json.loads(resp.read().decode())
99
+ return result.get('response', '')
100
+ except Exception as e:
101
+ return f"ERROR: {e}"
102
+
103
+
104
+ # =============================================================================
105
+ # 2. PROMPT BUILDING
106
+ # =============================================================================
107
+
108
+ def build_prompt(task: Dict) -> str:
109
+ """Build prompt for ARC task."""
110
+ train_pairs = task.get('train', [])
111
+
112
+ examples = []
113
+ for i, pair in enumerate(train_pairs):
114
+ examples.append(
115
+ f"Example {i+1}:\n"
116
+ f" Input: {json.dumps(pair['input'])}\n"
117
+ f" Output: {json.dumps(pair['output'])}"
118
+ )
119
+ examples_str = "\n".join(examples)
120
+
121
+ # Basic analysis
122
+ inputs = [np.array(p['input']) for p in train_pairs]
123
+ outputs = [np.array(p['output']) for p in train_pairs]
124
+ same_shape = all(i.shape == o.shape for i, o in zip(inputs, outputs))
125
+ in_colors = sorted(set(c for i in inputs for c in np.unique(i).tolist()))
126
+ out_colors = sorted(set(c for o in outputs for c in np.unique(o).tolist()))
127
+
128
+ analysis = f" Same input/output shape: {same_shape}\n"
129
+ analysis += f" Input colors: {in_colors}\n"
130
+ analysis += f" Output colors: {out_colors}\n"
131
+ if not same_shape:
132
+ ratios = [(o.shape[0]/i.shape[0], o.shape[1]/i.shape[1])
133
+ for i, o in zip(inputs, outputs)]
134
+ analysis += f" Shape ratios (h,w): {ratios}\n"
135
+
136
+ prompt = f"""Solve this ARC-AGI puzzle. Write ONLY a Python function, no explanations.
137
+
138
+ {examples_str}
139
+
140
+ Analysis:
141
+ {analysis}
142
+ Write a complete Python function that transforms any input grid to its output.
143
+ The function MUST work correctly for ALL examples above.
144
+
145
+ ```python
146
+ import numpy as np
147
+ from collections import Counter
148
+
149
+ def transform(grid: list[list[int]]) -> list[list[int]]:
150
+ grid = np.array(grid)
151
+ """
152
+ return prompt
153
+
154
+
155
+ # =============================================================================
156
+ # 3. CODE EXTRACTION AND VERIFICATION
157
+ # =============================================================================
158
+
159
+ def extract_code(response: str) -> Optional[str]:
160
+ """Extract Python function from LLM response."""
161
+ # Try ```python blocks
162
+ for pattern in [r'```python\s*(.*?)```', r'```\s*(.*?)```']:
163
+ matches = re.findall(pattern, response, re.DOTALL)
164
+ for match in matches:
165
+ if 'def transform' in match:
166
+ return match.strip()
167
+
168
+ # Try finding def transform directly
169
+ idx = response.find('def transform')
170
+ if idx >= 0:
171
+ # Look backwards for imports
172
+ before = response[:idx]
173
+ import_start = before.rfind('import ')
174
+ if import_start >= 0:
175
+ code = response[import_start:]
176
+ else:
177
+ code = response[idx:]
178
+ # Trim at next ``` or double newline after function ends
179
+ end = code.find('```')
180
+ if end > 0:
181
+ code = code[:end]
182
+ return code.strip()
183
+
184
+ # If response itself looks like code (starts with import or def)
185
+ stripped = response.strip()
186
+ if stripped.startswith('import') or stripped.startswith('def transform'):
187
+ return stripped
188
+
189
+ return None
190
+
191
+
192
+ def verify_program(code: str, train_pairs: List[Dict]) -> bool:
193
+ """Execute program and verify against all training pairs."""
194
+ namespace = {'np': np, 'numpy': np, 'Counter': Counter,
195
+ 'collections': __import__('collections')}
196
+
197
+ try:
198
+ exec(code, namespace)
199
+ except Exception:
200
+ return False
201
+
202
+ if 'transform' not in namespace:
203
+ return False
204
+
205
+ transform_fn = namespace['transform']
206
+
207
+ for pair in train_pairs:
208
+ try:
209
+ inp = [row[:] for row in pair['input']] # deep copy
210
+ result = transform_fn(inp)
211
+ if result is None:
212
+ return False
213
+ result_arr = np.array(result, dtype=int)
214
+ expected_arr = np.array(pair['output'], dtype=int)
215
+ if result_arr.shape != expected_arr.shape:
216
+ return False
217
+ if not np.array_equal(result_arr, expected_arr):
218
+ return False
219
+ except Exception:
220
+ return False
221
+
222
+ return True
223
+
224
+
225
+ def apply_program(code: str, test_input: List[List[int]]) -> Optional[List[List[int]]]:
226
+ """Apply verified program to test input."""
227
+ namespace = {'np': np, 'numpy': np, 'Counter': Counter,
228
+ 'collections': __import__('collections')}
229
+ try:
230
+ exec(code, namespace)
231
+ result = namespace['transform']([row[:] for row in test_input])
232
+ if result is not None:
233
+ return [list(row) for row in np.array(result, dtype=int).tolist()]
234
+ except Exception:
235
+ pass
236
+ return None
237
+
238
+
239
+ # =============================================================================
240
+ # 4. SYNTHESIS ENGINE
241
+ # =============================================================================
242
+
243
+ def synthesize_task(task: Dict, model: str = "qwen2.5-coder:32b",
244
+ n_candidates: int = 8, verbose: bool = False) -> Optional[Tuple[str, str]]:
245
+ """
246
+ Try to solve a task via LLM.
247
+ Returns (rule_name, code) if successful, None otherwise.
248
+ """
249
+ train_pairs = task.get('train', [])
250
+ if not train_pairs:
251
+ return None
252
+
253
+ prompt = build_prompt(task)
254
+
255
+ for i in range(n_candidates):
256
+ temp = 0.1 if i == 0 else 0.5 + 0.1 * i # first try low temp, then increase
257
+ response = call_ollama(prompt, model=model, temperature=min(temp, 1.0))
258
+
259
+ if response.startswith("ERROR:"):
260
+ if verbose:
261
+ print(f" Candidate {i+1}: API error")
262
+ continue
263
+
264
+ code = extract_code(response)
265
+ if code is None:
266
+ if verbose:
267
+ print(f" Candidate {i+1}: No code extracted")
268
+ continue
269
+
270
+ if verbose:
271
+ print(f" Candidate {i+1}: {len(code)} chars", end="")
272
+
273
+ if verify_program(code, train_pairs):
274
+ if verbose:
275
+ print(f" ✅")
276
+ return (f"llm_c{i+1}_t{temp:.1f}", code)
277
+ else:
278
+ if verbose:
279
+ print(f" ❌")
280
+
281
+ return None
282
+
283
+
284
+ # =============================================================================
285
+ # 5. MAIN RUNNER
286
+ # =============================================================================
287
+
288
+ def main():
289
+ # --- Configuration ---
290
+ MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5-coder:32b")
291
+ # For smaller GPUs, use:
292
+ # MODEL = "qwen2.5-coder:14b" (fits T4 16GB)
293
+ # MODEL = "qwen2.5-coder:7b" (fits any GPU)
294
+
295
+ N_CANDIDATES = int(os.environ.get("N_CANDIDATES", "8"))
296
+ ARC_DIR = os.environ.get("ARC_DIR", "arc_data/training")
297
+ ALREADY_SOLVED_FILE = os.environ.get("ALREADY_SOLVED", "already_solved.json")
298
+ OUTPUT_FILE = os.environ.get("OUTPUT_FILE", "llm_results.json")
299
+
300
+ print("=" * 60)
301
+ print("PEMF ARC-AGI — LLM Program Synthesis (Kaggle/Ollama)")
302
+ print("=" * 60)
303
+ print(f"Model: {MODEL}")
304
+ print(f"Candidates per task: {N_CANDIDATES}")
305
+ print(f"ARC data: {ARC_DIR}")
306
+ print()
307
+
308
+ # --- Install & start Ollama ---
309
+ try:
310
+ subprocess.run(["ollama", "--version"], capture_output=True, check=True)
311
+ print("Ollama already installed.")
312
+ except (FileNotFoundError, subprocess.CalledProcessError):
313
+ install_ollama()
314
+
315
+ server = start_ollama()
316
+
317
+ try:
318
+ pull_model(MODEL)
319
+ except Exception as e:
320
+ print(f"Failed to pull {MODEL}: {e}")
321
+ print("Trying smaller model...")
322
+ MODEL = "qwen2.5-coder:7b"
323
+ pull_model(MODEL)
324
+
325
+ # --- Load already solved tasks ---
326
+ already_solved = set()
327
+ if os.path.exists(ALREADY_SOLVED_FILE):
328
+ with open(ALREADY_SOLVED_FILE) as f:
329
+ already_solved = set(json.load(f))
330
+ print(f"Already solved (symbolic): {len(already_solved)} tasks")
331
+
332
+ # --- Load ARC tasks ---
333
+ import glob
334
+ task_files = sorted(glob.glob(os.path.join(ARC_DIR, "*.json")))
335
+ print(f"Total ARC tasks: {len(task_files)}")
336
+
337
+ unsolved_files = []
338
+ for tf in task_files:
339
+ tid = os.path.basename(tf).replace('.json', '')
340
+ if tid not in already_solved:
341
+ unsolved_files.append((tid, tf))
342
+ print(f"Unsolved tasks to try: {len(unsolved_files)}")
343
+ print()
344
+
345
+ # --- Run synthesis ---
346
+ results = {}
347
+ solved = 0
348
+ total_time = 0
349
+
350
+ for idx, (tid, tf) in enumerate(unsolved_files):
351
+ with open(tf) as f:
352
+ task = json.load(f)
353
+
354
+ print(f"[{idx+1:3d}/{len(unsolved_files)}] {tid}:", end=" ", flush=True)
355
+ start = time.time()
356
+
357
+ result = synthesize_task(task, model=MODEL, n_candidates=N_CANDIDATES, verbose=False)
358
+ elapsed = time.time() - start
359
+ total_time += elapsed
360
+
361
+ if result:
362
+ rule_name, code = result
363
+ solved += 1
364
+
365
+ # Apply to test pairs
366
+ test_outputs = []
367
+ for test in task.get('test', []):
368
+ out = apply_program(code, test['input'])
369
+ test_outputs.append(out)
370
+
371
+ results[tid] = {
372
+ 'status': 'solved',
373
+ 'rule': rule_name,
374
+ 'code': code,
375
+ 'test_outputs': test_outputs,
376
+ 'time_s': round(elapsed, 2),
377
+ }
378
+ print(f"✅ {rule_name} ({elapsed:.1f}s)")
379
+ else:
380
+ results[tid] = {
381
+ 'status': 'failed',
382
+ 'time_s': round(elapsed, 2),
383
+ }
384
+ print(f"❌ ({elapsed:.1f}s)")
385
+
386
+ # Save progress periodically
387
+ if (idx + 1) % 10 == 0:
388
+ with open(OUTPUT_FILE, 'w') as f:
389
+ json.dump({
390
+ 'model': MODEL,
391
+ 'n_candidates': N_CANDIDATES,
392
+ 'solved': solved,
393
+ 'attempted': idx + 1,
394
+ 'total_time_s': round(total_time, 1),
395
+ 'results': results,
396
+ }, f, indent=2)
397
+ print(f" [Progress saved: {solved}/{idx+1} solved]")
398
+
399
+ # --- Final save ---
400
+ with open(OUTPUT_FILE, 'w') as f:
401
+ json.dump({
402
+ 'model': MODEL,
403
+ 'n_candidates': N_CANDIDATES,
404
+ 'solved': solved,
405
+ 'attempted': len(unsolved_files),
406
+ 'total_time_s': round(total_time, 1),
407
+ 'already_solved_symbolic': len(already_solved),
408
+ 'total_solved': len(already_solved) + solved,
409
+ 'total_tasks': len(task_files),
410
+ 'solve_rate': round(100 * (len(already_solved) + solved) / len(task_files), 2),
411
+ 'results': results,
412
+ }, f, indent=2)
413
+
414
+ # --- Summary ---
415
+ print()
416
+ print("=" * 60)
417
+ print("FINAL RESULTS")
418
+ print("=" * 60)
419
+ print(f"LLM solved: {solved}/{len(unsolved_files)} unsolved tasks")
420
+ print(f"Symbolic solved: {len(already_solved)}")
421
+ print(f"TOTAL SOLVED: {len(already_solved) + solved}/{len(task_files)} ({100*(len(already_solved)+solved)/len(task_files):.1f}%)")
422
+ print(f"Total LLM time: {total_time:.0f}s ({total_time/max(1,len(unsolved_files)):.1f}s/task)")
423
+ print(f"Results saved to: {OUTPUT_FILE}")
424
+
425
+ # Cleanup
426
+ server.terminate()
427
+
428
+
429
+ # =============================================================================
430
+ # 6. GENERATE already_solved.json FROM SYMBOLIC RESULTS
431
+ # =============================================================================
432
+
433
+ def generate_already_solved(summary_file: str, output_file: str = "already_solved.json"):
434
+ """
435
+ Generate already_solved.json from a v4 summary file.
436
+ Run this BEFORE running on Kaggle.
437
+ """
438
+ with open(summary_file) as f:
439
+ data = json.load(f)
440
+ solved = [r['task_id'] for r in data['results'] if r.get('all_train_solved')]
441
+ with open(output_file, 'w') as f:
442
+ json.dump(solved, f)
443
+ print(f"Wrote {len(solved)} solved task IDs to {output_file}")
444
+
445
+
446
+ if __name__ == "__main__":
447
+ # If run with --generate-solved, create the already_solved.json
448
+ if len(sys.argv) > 1 and sys.argv[1] == "--generate-solved":
449
+ summary = sys.argv[2] if len(sys.argv) > 2 else "arc_results/summary_v4.json"
450
+ generate_already_solved(summary)
451
+ else:
452
+ main()