rogermt commited on
Commit
833dcfe
·
verified ·
1 Parent(s): 6c42dc4

Add multi-provider LLM solver: Gemini, DeepSeek, GLM, Ollama

Browse files
Files changed (1) hide show
  1. scripts/llm_solver_cloud.py +441 -0
scripts/llm_solver_cloud.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PEMF ARC-AGI — LLM Program Synthesis (Multi-Provider)
3
+ =====================================================
4
+
5
+ Supports:
6
+ - Google Gemini (free tier: 15 RPM, generous limits)
7
+ - DeepSeek V4 (very cheap: $0.07/M input tokens)
8
+ - GLM-4 / ChatGLM (free tier available)
9
+ - Ollama local (any model)
10
+ - Any OpenAI-compatible API
11
+
12
+ Usage:
13
+ # Gemini (free, recommended starting point)
14
+ export LLM_PROVIDER=gemini
15
+ export GEMINI_API_KEY=your_key_here
16
+ python llm_solver_cloud.py
17
+
18
+ # DeepSeek (cheapest cloud option)
19
+ export LLM_PROVIDER=deepseek
20
+ export DEEPSEEK_API_KEY=your_key_here
21
+ python llm_solver_cloud.py
22
+
23
+ # GLM
24
+ export LLM_PROVIDER=glm
25
+ export GLM_API_KEY=your_key_here
26
+ python llm_solver_cloud.py
27
+
28
+ # Ollama local
29
+ export LLM_PROVIDER=ollama
30
+ export OLLAMA_MODEL=qwen2.5-coder:32b
31
+ python llm_solver_cloud.py
32
+ """
33
+
34
+ import os
35
+ import sys
36
+ import json
37
+ import time
38
+ import re
39
+ import glob
40
+ import numpy as np
41
+ from typing import Dict, List, Optional, Tuple
42
+ from collections import Counter
43
+ import urllib.request
44
+
45
+
46
+ # =============================================================================
47
+ # PROVIDER CONFIGS
48
+ # =============================================================================
49
+
50
+ PROVIDERS = {
51
+ "gemini": {
52
+ "name": "Google Gemini",
53
+ "base_url": "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
54
+ "default_model": "gemini-2.0-flash",
55
+ "env_key": "GEMINI_API_KEY",
56
+ "free_tier": "15 RPM, 1M tokens/day",
57
+ "get_key_url": "https://aistudio.google.com/apikey",
58
+ },
59
+ "deepseek": {
60
+ "name": "DeepSeek",
61
+ "base_url": "https://api.deepseek.com/v1/chat/completions",
62
+ "default_model": "deepseek-chat",
63
+ "env_key": "DEEPSEEK_API_KEY",
64
+ "free_tier": "$0.07/M input, $0.27/M output",
65
+ "get_key_url": "https://platform.deepseek.com/api_keys",
66
+ },
67
+ "glm": {
68
+ "name": "GLM (Zhipu AI)",
69
+ "base_url": "https://open.bigmodel.cn/api/paas/v4/chat/completions",
70
+ "default_model": "glm-4-flash",
71
+ "env_key": "GLM_API_KEY",
72
+ "free_tier": "glm-4-flash is free",
73
+ "get_key_url": "https://open.bigmodel.cn/usercenter/apikeys",
74
+ },
75
+ "ollama": {
76
+ "name": "Ollama (local)",
77
+ "base_url": "http://localhost:11434/api/generate",
78
+ "default_model": "qwen2.5-coder:32b",
79
+ "env_key": None,
80
+ },
81
+ }
82
+
83
+
84
+ # =============================================================================
85
+ # API CALLERS
86
+ # =============================================================================
87
+
88
+ def call_gemini(prompt: str, api_key: str, model: str = "gemini-2.0-flash",
89
+ temperature: float = 0.7) -> str:
90
+ """Call Google Gemini API."""
91
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
92
+ payload = {
93
+ "contents": [{"parts": [{"text": prompt}]}],
94
+ "generationConfig": {
95
+ "temperature": temperature,
96
+ "maxOutputTokens": 2048,
97
+ }
98
+ }
99
+ data = json.dumps(payload).encode('utf-8')
100
+ req = urllib.request.Request(url, data=data,
101
+ headers={"Content-Type": "application/json"},
102
+ method='POST')
103
+ try:
104
+ with urllib.request.urlopen(req, timeout=120) as resp:
105
+ result = json.loads(resp.read().decode())
106
+ candidates = result.get('candidates', [])
107
+ if candidates:
108
+ parts = candidates[0].get('content', {}).get('parts', [])
109
+ if parts:
110
+ return parts[0].get('text', '')
111
+ return "ERROR: No response content"
112
+ except Exception as e:
113
+ return f"ERROR: {e}"
114
+
115
+
116
+ def call_deepseek(prompt: str, api_key: str, model: str = "deepseek-chat",
117
+ temperature: float = 0.7) -> str:
118
+ """Call DeepSeek API (OpenAI-compatible)."""
119
+ url = "https://api.deepseek.com/v1/chat/completions"
120
+ payload = {
121
+ "model": model,
122
+ "messages": [{"role": "user", "content": prompt}],
123
+ "max_tokens": 2048,
124
+ "temperature": temperature,
125
+ }
126
+ data = json.dumps(payload).encode('utf-8')
127
+ req = urllib.request.Request(url, data=data,
128
+ headers={"Content-Type": "application/json",
129
+ "Authorization": f"Bearer {api_key}"},
130
+ method='POST')
131
+ try:
132
+ with urllib.request.urlopen(req, timeout=120) as resp:
133
+ result = json.loads(resp.read().decode())
134
+ return result['choices'][0]['message']['content']
135
+ except Exception as e:
136
+ return f"ERROR: {e}"
137
+
138
+
139
+ def call_glm(prompt: str, api_key: str, model: str = "glm-4-flash",
140
+ temperature: float = 0.7) -> str:
141
+ """Call GLM/Zhipu API (OpenAI-compatible)."""
142
+ url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
143
+ payload = {
144
+ "model": model,
145
+ "messages": [{"role": "user", "content": prompt}],
146
+ "max_tokens": 2048,
147
+ "temperature": temperature,
148
+ }
149
+ data = json.dumps(payload).encode('utf-8')
150
+ req = urllib.request.Request(url, data=data,
151
+ headers={"Content-Type": "application/json",
152
+ "Authorization": f"Bearer {api_key}"},
153
+ method='POST')
154
+ try:
155
+ with urllib.request.urlopen(req, timeout=120) as resp:
156
+ result = json.loads(resp.read().decode())
157
+ return result['choices'][0]['message']['content']
158
+ except Exception as e:
159
+ return f"ERROR: {e}"
160
+
161
+
162
+ def call_ollama(prompt: str, model: str = "qwen2.5-coder:32b",
163
+ temperature: float = 0.7) -> str:
164
+ """Call local Ollama."""
165
+ url = "http://localhost:11434/api/generate"
166
+ payload = {
167
+ "model": model,
168
+ "prompt": prompt,
169
+ "stream": False,
170
+ "options": {"temperature": temperature, "num_predict": 2048},
171
+ }
172
+ data = json.dumps(payload).encode('utf-8')
173
+ req = urllib.request.Request(url, data=data,
174
+ headers={"Content-Type": "application/json"},
175
+ method='POST')
176
+ try:
177
+ with urllib.request.urlopen(req, timeout=180) as resp:
178
+ result = json.loads(resp.read().decode())
179
+ return result.get('response', '')
180
+ except Exception as e:
181
+ return f"ERROR: {e}"
182
+
183
+
184
+ def call_llm(prompt: str, provider: str, api_key: str = "",
185
+ model: str = "", temperature: float = 0.7) -> str:
186
+ """Unified LLM caller."""
187
+ if provider == "gemini":
188
+ return call_gemini(prompt, api_key, model or "gemini-2.0-flash", temperature)
189
+ elif provider == "deepseek":
190
+ return call_deepseek(prompt, api_key, model or "deepseek-chat", temperature)
191
+ elif provider == "glm":
192
+ return call_glm(prompt, api_key, model or "glm-4-flash", temperature)
193
+ elif provider == "ollama":
194
+ return call_ollama(prompt, model or "qwen2.5-coder:32b", temperature)
195
+ else:
196
+ return f"ERROR: Unknown provider {provider}"
197
+
198
+
199
+ # =============================================================================
200
+ # PROMPT, EXTRACTION, VERIFICATION (same as before)
201
+ # =============================================================================
202
+
203
+ def build_prompt(task: Dict) -> str:
204
+ train_pairs = task.get('train', [])
205
+ examples = []
206
+ for i, pair in enumerate(train_pairs):
207
+ examples.append(
208
+ f"Example {i+1}:\n"
209
+ f" Input: {json.dumps(pair['input'])}\n"
210
+ f" Output: {json.dumps(pair['output'])}"
211
+ )
212
+ examples_str = "\n".join(examples)
213
+
214
+ inputs = [np.array(p['input']) for p in train_pairs]
215
+ outputs = [np.array(p['output']) for p in train_pairs]
216
+ same_shape = all(i.shape == o.shape for i, o in zip(inputs, outputs))
217
+ in_colors = sorted(set(c for i in inputs for c in np.unique(i).tolist()))
218
+ out_colors = sorted(set(c for o in outputs for c in np.unique(o).tolist()))
219
+
220
+ analysis = f" Same input/output shape: {same_shape}\n"
221
+ analysis += f" Input colors: {in_colors}, Output colors: {out_colors}\n"
222
+ if not same_shape:
223
+ for i, o in zip(inputs[:1], outputs[:1]):
224
+ analysis += f" Shape: {i.shape} -> {o.shape}\n"
225
+
226
+ return f"""Solve this ARC-AGI puzzle. Write ONLY a Python function, no explanations.
227
+
228
+ {examples_str}
229
+
230
+ Analysis:
231
+ {analysis}
232
+ ```python
233
+ import numpy as np
234
+ from collections import Counter, deque
235
+ from scipy.ndimage import label
236
+
237
+ def transform(grid: list[list[int]]) -> list[list[int]]:
238
+ grid = np.array(grid)
239
+ """
240
+
241
+
242
+ def extract_code(response: str) -> Optional[str]:
243
+ for pattern in [r'```python\s*(.*?)```', r'```\s*(.*?)```']:
244
+ matches = re.findall(pattern, response, re.DOTALL)
245
+ for match in matches:
246
+ if 'def transform' in match:
247
+ return match.strip()
248
+ idx = response.find('def transform')
249
+ if idx >= 0:
250
+ before = response[:idx]
251
+ import_start = max(before.rfind('import '), before.rfind('from '))
252
+ start = import_start if import_start >= 0 else idx
253
+ code = response[start:]
254
+ end = code.find('```')
255
+ if end > 0:
256
+ code = code[:end]
257
+ return code.strip()
258
+ stripped = response.strip()
259
+ if stripped.startswith(('import', 'def transform', 'from')):
260
+ return stripped
261
+ return None
262
+
263
+
264
+ def verify_program(code: str, train_pairs: List[Dict]) -> bool:
265
+ namespace = {'np': np, 'numpy': np, 'Counter': Counter,
266
+ 'deque': __import__('collections').deque}
267
+ try:
268
+ # Allow scipy import in generated code
269
+ try:
270
+ import scipy.ndimage
271
+ namespace['scipy'] = __import__('scipy')
272
+ except ImportError:
273
+ pass
274
+ exec(code, namespace)
275
+ except Exception:
276
+ return False
277
+ if 'transform' not in namespace:
278
+ return False
279
+ fn = namespace['transform']
280
+ for pair in train_pairs:
281
+ try:
282
+ result = fn([row[:] for row in pair['input']])
283
+ if result is None:
284
+ return False
285
+ r = np.array(result, dtype=int)
286
+ e = np.array(pair['output'], dtype=int)
287
+ if r.shape != e.shape or not np.array_equal(r, e):
288
+ return False
289
+ except Exception:
290
+ return False
291
+ return True
292
+
293
+
294
+ def apply_program(code: str, test_input):
295
+ namespace = {'np': np, 'numpy': np, 'Counter': Counter,
296
+ 'deque': __import__('collections').deque}
297
+ try:
298
+ import scipy.ndimage
299
+ namespace['scipy'] = __import__('scipy')
300
+ except ImportError:
301
+ pass
302
+ try:
303
+ exec(code, namespace)
304
+ result = namespace['transform']([row[:] for row in test_input])
305
+ if result is not None:
306
+ return np.array(result, dtype=int).tolist()
307
+ except Exception:
308
+ pass
309
+ return None
310
+
311
+
312
+ # =============================================================================
313
+ # SYNTHESIS + MAIN
314
+ # =============================================================================
315
+
316
+ def synthesize_task(task, provider, api_key, model, n_candidates=8, verbose=False):
317
+ prompt = build_prompt(task)
318
+ for i in range(n_candidates):
319
+ temp = 0.1 if i == 0 else min(0.4 + 0.15 * i, 1.2)
320
+ response = call_llm(prompt, provider, api_key, model, temp)
321
+ if response.startswith("ERROR:"):
322
+ if verbose: print(f" C{i+1}: {response[:60]}")
323
+ # Rate limit — wait and retry
324
+ if "429" in response or "rate" in response.lower():
325
+ time.sleep(5)
326
+ continue
327
+ code = extract_code(response)
328
+ if code is None:
329
+ if verbose: print(f" C{i+1}: no code")
330
+ continue
331
+ if verbose: print(f" C{i+1}: {len(code)}ch", end="")
332
+ if verify_program(code, task['train']):
333
+ if verbose: print(" ✅")
334
+ return (f"llm_c{i+1}", code)
335
+ else:
336
+ if verbose: print(" ❌")
337
+ return None
338
+
339
+
340
+ def main():
341
+ PROVIDER = os.environ.get("LLM_PROVIDER", "gemini")
342
+ config = PROVIDERS.get(PROVIDER, {})
343
+ API_KEY = os.environ.get(config.get("env_key", ""), "") if config.get("env_key") else ""
344
+ MODEL = os.environ.get("LLM_MODEL", config.get("default_model", ""))
345
+ N_CANDIDATES = int(os.environ.get("N_CANDIDATES", "8"))
346
+ ARC_DIR = os.environ.get("ARC_DIR", "arc_data/training")
347
+ ALREADY_SOLVED = os.environ.get("ALREADY_SOLVED", "already_solved.json")
348
+ OUTPUT = os.environ.get("OUTPUT_FILE", "llm_results.json")
349
+
350
+ print("=" * 60)
351
+ print(f"PEMF ARC-AGI — LLM Synthesis ({config.get('name', PROVIDER)})")
352
+ print("=" * 60)
353
+ print(f"Provider: {PROVIDER}")
354
+ print(f"Model: {MODEL}")
355
+ print(f"Candidates/task: {N_CANDIDATES}")
356
+ if not API_KEY and PROVIDER != "ollama":
357
+ print(f"\n⚠️ No API key! Set {config.get('env_key', '???')}")
358
+ print(f" Get key: {config.get('get_key_url', '?')}")
359
+ return
360
+ print()
361
+
362
+ # Load already solved
363
+ already_solved = set()
364
+ if os.path.exists(ALREADY_SOLVED):
365
+ with open(ALREADY_SOLVED) as f:
366
+ already_solved = set(json.load(f))
367
+ print(f"Symbolic solved: {len(already_solved)}")
368
+
369
+ # Load tasks
370
+ task_files = sorted(glob.glob(os.path.join(ARC_DIR, "*.json")))
371
+ unsolved = [(os.path.basename(tf).replace('.json',''), tf)
372
+ for tf in task_files
373
+ if os.path.basename(tf).replace('.json','') not in already_solved]
374
+ print(f"Total tasks: {len(task_files)}, unsolved: {len(unsolved)}")
375
+ print()
376
+
377
+ # Run
378
+ results = {}
379
+ solved = 0
380
+ total_time = 0
381
+
382
+ for idx, (tid, tf) in enumerate(unsolved):
383
+ with open(tf) as f:
384
+ task = json.load(f)
385
+ print(f"[{idx+1:3d}/{len(unsolved)}] {tid}:", end=" ", flush=True)
386
+ start = time.time()
387
+ result = synthesize_task(task, PROVIDER, API_KEY, MODEL, N_CANDIDATES, verbose=False)
388
+ elapsed = time.time() - start
389
+ total_time += elapsed
390
+
391
+ if result:
392
+ rule, code = result
393
+ solved += 1
394
+ test_outputs = [apply_program(code, t['input']) for t in task.get('test', [])]
395
+ results[tid] = {'status': 'solved', 'rule': rule, 'code': code,
396
+ 'test_outputs': test_outputs, 'time_s': round(elapsed, 2)}
397
+ print(f"✅ ({elapsed:.1f}s)")
398
+ else:
399
+ results[tid] = {'status': 'failed', 'time_s': round(elapsed, 2)}
400
+ print(f"❌ ({elapsed:.1f}s)")
401
+
402
+ # Rate limit respect
403
+ if PROVIDER == "gemini":
404
+ time.sleep(4) # 15 RPM = 1 every 4s
405
+ elif PROVIDER in ("deepseek", "glm"):
406
+ time.sleep(1)
407
+
408
+ # Save every 10
409
+ if (idx + 1) % 10 == 0:
410
+ _save(OUTPUT, PROVIDER, MODEL, N_CANDIDATES, solved, idx+1,
411
+ total_time, already_solved, len(task_files), results)
412
+ print(f" [Saved: {solved}/{idx+1}, total {len(already_solved)+solved}/{len(task_files)}]")
413
+
414
+ # Final save
415
+ _save(OUTPUT, PROVIDER, MODEL, N_CANDIDATES, solved, len(unsolved),
416
+ total_time, already_solved, len(task_files), results)
417
+
418
+ print(f"\n{'='*60}")
419
+ print(f"LLM solved: {solved}/{len(unsolved)}")
420
+ print(f"Symbolic: {len(already_solved)}")
421
+ print(f"TOTAL: {len(already_solved)+solved}/{len(task_files)} ({100*(len(already_solved)+solved)/len(task_files):.1f}%)")
422
+ print(f"Saved: {OUTPUT}")
423
+
424
+
425
+ def _save(path, provider, model, n_cand, solved, attempted, total_time,
426
+ already_solved, total_tasks, results):
427
+ with open(path, 'w') as f:
428
+ json.dump({
429
+ 'provider': provider, 'model': model, 'n_candidates': n_cand,
430
+ 'llm_solved': solved, 'attempted': attempted,
431
+ 'total_time_s': round(total_time, 1),
432
+ 'symbolic_solved': len(already_solved),
433
+ 'total_solved': len(already_solved) + solved,
434
+ 'total_tasks': total_tasks,
435
+ 'solve_rate': round(100*(len(already_solved)+solved)/total_tasks, 2),
436
+ 'results': results,
437
+ }, f, indent=2)
438
+
439
+
440
+ if __name__ == "__main__":
441
+ main()