hellinferno committed on
Commit
2a92b3a
·
1 Parent(s): c98afe9

fix: bulletproof inference.py with Docker + fallback connection methods

Browse files
Files changed (2) hide show
  1. inference.py +268 -91
  2. sql_query_reviewer/client.py +121 -6
inference.py CHANGED
@@ -2,11 +2,12 @@
2
  Inference Script — SQL Query Reviewer
3
  ======================================
4
  MANDATORY environment variables:
5
- API_BASE_URL The API endpoint for the LLM.
6
- MODEL_NAME The model identifier to use for inference.
7
- HF_TOKEN Your Hugging Face / API key.
 
8
 
9
- STDOUT FORMAT:
10
  [START] task=<task_name> env=<benchmark> model=<model_name>
11
  [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
12
  [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
@@ -14,44 +15,45 @@ STDOUT FORMAT:
14
 
15
  from __future__ import annotations
16
 
 
17
  import json
18
  import os
 
 
19
  from typing import Any, List, Optional
20
 
21
  from openai import OpenAI
22
 
23
- from sql_query_reviewer.client import SyncSQLReviewEnv
24
- from sql_query_reviewer.models import SQLReviewAction, SQLReviewObservation
25
-
26
  # ---------------------------------------------------------------------------
27
- # Configuration
28
  # ---------------------------------------------------------------------------
29
 
30
- DEFAULT_TASK_IDS = ("easy_001", "medium_001", "hard_001")
31
- BENCHMARK = "sql-query-reviewer"
32
- SUCCESS_SCORE_THRESHOLD = 0.1
33
-
34
- ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
35
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
36
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
37
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
 
 
 
 
38
 
39
  SYSTEM_PROMPT = """You are reviewing a SQL query for correctness, performance, and security.
40
  Return exactly one JSON object with these keys:
41
  - action_type: identify_issue, suggest_fix, approve, or request_more_context
42
- - issue_category: syntax, performance, security, logic, or style when relevant
43
- - issue_description: concise issue statement when relevant
44
- - suggested_fix: corrected SQL or corrected fragment when relevant
45
  - confidence: float between 0.0 and 1.0
46
 
47
  Guidelines:
48
- - Prefer identify_issue until you have high confidence all important issues are covered.
49
- - Use approve only when the query looks acceptable or all issues have already been identified.
50
- - Keep the JSON valid and do not wrap it in prose.
51
  """
52
 
53
  # ---------------------------------------------------------------------------
54
- # Structured stdout logging — MUST match the hackathon spec exactly
55
  # ---------------------------------------------------------------------------
56
 
57
 
@@ -81,146 +83,321 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
81
 
82
 
83
  # ---------------------------------------------------------------------------
84
- # LLM interaction
85
  # ---------------------------------------------------------------------------
86
 
87
 
88
- def build_user_prompt(observation: SQLReviewObservation) -> str:
89
- payload = {
90
- "query": observation.query,
91
- "schema_info": observation.schema_info,
92
- "context": observation.context,
93
- "issues_found_so_far": [
94
- issue.model_dump() for issue in observation.issues_found_so_far
95
- ],
96
- "remaining_actions": observation.remaining_actions,
97
- "difficulty": observation.difficulty,
98
- "feedback": observation.feedback,
99
- }
100
- return json.dumps(payload, indent=2)
101
-
102
-
103
  def extract_json(content: str) -> dict[str, Any]:
104
  stripped = content.strip()
105
  if stripped.startswith("```"):
106
- lines = [line for line in stripped.splitlines() if not line.startswith("```")]
107
  stripped = "\n".join(lines).strip()
108
  start = stripped.find("{")
109
  end = stripped.rfind("}")
110
  if start == -1 or end == -1 or end <= start:
111
- raise ValueError(f"Could not find JSON object in model response: {content!r}")
112
  return json.loads(stripped[start : end + 1])
113
 
114
 
115
- def choose_action(
116
- llm_client: OpenAI, model_name: str, observation: SQLReviewObservation
117
- ) -> SQLReviewAction:
118
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  response = llm_client.chat.completions.create(
120
  model=model_name,
121
  temperature=0,
122
  max_tokens=300,
123
  messages=[
124
  {"role": "system", "content": SYSTEM_PROMPT},
125
- {"role": "user", "content": build_user_prompt(observation)},
126
  ],
127
  )
128
  content = response.choices[0].message.content or ""
129
- return SQLReviewAction.model_validate(extract_json(content))
 
 
 
 
130
  except Exception as exc:
131
- print(f"[DEBUG] Model request failed: {exc}", flush=True)
132
- # Fallback: approve to end the episode gracefully
133
- return SQLReviewAction(action_type="approve", confidence=0.1)
134
 
135
 
136
  # ---------------------------------------------------------------------------
137
- # Episode runner
138
  # ---------------------------------------------------------------------------
139
 
140
 
141
- def run_episode(
142
- env: SyncSQLReviewEnv, llm_client: OpenAI, model_name: str, task_id: str
143
  ) -> None:
144
  rewards: List[float] = []
145
  steps_taken = 0
146
  score = 0.0
147
  success = False
148
- last_error: Optional[str] = None
149
 
150
  log_start(task=task_id, env=BENCHMARK, model=model_name)
151
 
152
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  result = env.reset(task_id=task_id)
154
 
155
- step = 0
156
- while not result.done:
157
- step += 1
158
- action = choose_action(
159
- llm_client=llm_client,
160
- model_name=model_name,
161
- observation=result.observation,
162
- )
 
 
163
 
164
  action_str = action.action_type
165
  if action.issue_description:
166
- # Keep action string short and readable
167
  action_str = f"{action.action_type}({action.issue_category})"
168
 
169
- result = env.step(action)
 
 
 
 
 
 
 
170
 
171
- reward = result.reward
172
  rewards.append(reward)
173
- steps_taken = step
174
- last_error = result.info.get("error") if result.info else None
175
-
176
- log_step(
177
- step=step,
178
- action=action_str,
179
- reward=reward,
180
- done=result.done,
181
- error=last_error,
182
- )
183
-
184
- # Get final score from state
185
- state = env.state()
186
- score = state.final_score if state.final_score is not None else 0.0
 
 
 
 
 
 
187
  success = score >= SUCCESS_SCORE_THRESHOLD
188
 
189
  except Exception as exc:
190
  print(f"[DEBUG] Episode error: {exc}", flush=True)
191
- last_error = str(exc)
192
 
193
  finally:
194
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
195
 
196
 
197
  # ---------------------------------------------------------------------------
198
- # Main
199
  # ---------------------------------------------------------------------------
200
 
201
 
202
- def main() -> int:
 
203
  if not API_KEY:
204
- raise SystemExit("Set HF_TOKEN or OPENAI_API_KEY before running inference.py")
 
 
 
 
 
 
205
 
206
  task_ids = tuple(
207
  tid.strip()
208
- for tid in os.getenv("TASK_IDS", ",".join(DEFAULT_TASK_IDS)).split(",")
209
  if tid.strip()
210
  )
211
 
212
- llm_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
213
 
214
- with SyncSQLReviewEnv(base_url=ENV_BASE_URL) as env:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  for task_id in task_ids:
216
- run_episode(
217
- env=env,
218
- llm_client=llm_client,
219
- model_name=MODEL_NAME,
220
- task_id=task_id,
221
- )
222
-
223
- return 0
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
  if __name__ == "__main__":
 
2
  Inference Script — SQL Query Reviewer
3
  ======================================
4
  MANDATORY environment variables:
5
+ API_BASE_URL The API endpoint for the LLM.
6
+ MODEL_NAME The model identifier to use for inference.
7
+ HF_TOKEN Your Hugging Face / API key.
8
+ LOCAL_IMAGE_NAME The name of the local Docker image for the environment.
9
 
10
+ STDOUT FORMAT (must match exactly):
11
  [START] task=<task_name> env=<benchmark> model=<model_name>
12
  [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
13
  [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
 
15
 
16
  from __future__ import annotations
17
 
18
+ import asyncio
19
  import json
20
  import os
21
+ import sys
22
+ import traceback
23
  from typing import Any, List, Optional
24
 
25
  from openai import OpenAI
26
 
 
 
 
27
  # ---------------------------------------------------------------------------
28
+ # Environment variables — read ALL names the validator might set
29
  # ---------------------------------------------------------------------------
30
 
31
+ IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
32
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
 
 
 
33
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
34
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
35
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
36
+
37
+ BENCHMARK = "sql-query-reviewer"
38
+ MAX_STEPS = 10
39
+ SUCCESS_SCORE_THRESHOLD = 0.1
40
 
41
  SYSTEM_PROMPT = """You are reviewing a SQL query for correctness, performance, and security.
42
  Return exactly one JSON object with these keys:
43
  - action_type: identify_issue, suggest_fix, approve, or request_more_context
44
+ - issue_category: syntax, performance, security, logic, or style (when relevant)
45
+ - issue_description: concise issue statement (when relevant)
46
+ - suggested_fix: corrected SQL or corrected fragment (when relevant)
47
  - confidence: float between 0.0 and 1.0
48
 
49
  Guidelines:
50
+ - Prefer identify_issue until you believe all important issues are covered.
51
+ - Use approve only when the query looks acceptable or all issues have been identified.
52
+ - Return ONLY valid JSON, no prose, no markdown fences.
53
  """
54
 
55
  # ---------------------------------------------------------------------------
56
+ # Structured stdout logging — matches hackathon spec EXACTLY
57
  # ---------------------------------------------------------------------------
58
 
59
 
 
83
 
84
 
85
  # ---------------------------------------------------------------------------
86
+ # LLM interaction — fully wrapped in try/except
87
  # ---------------------------------------------------------------------------
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
def extract_json(content: str) -> dict[str, Any]:
    """Locate and parse the first JSON object embedded in *content*.

    Markdown code fences (lines starting with ```) are stripped first, then
    everything between the first '{' and the last '}' is handed to
    ``json.loads``. Raises ValueError when no brace-delimited object exists.
    """
    text = content.strip()
    if text.startswith("```"):
        kept = (line for line in text.splitlines() if not line.startswith("```"))
        text = "\n".join(kept).strip()
    open_idx = text.find("{")
    close_idx = text.rfind("}")
    if open_idx == -1 or close_idx == -1 or close_idx <= open_idx:
        raise ValueError(f"No JSON object found in: {content[:200]!r}")
    return json.loads(text[open_idx : close_idx + 1])
100
 
101
 
102
def choose_action(llm_client: Any, model_name: str, observation: Any) -> dict[str, Any]:
    """Ask LLM for an action. Returns a raw dict. NEVER raises."""
    try:
        # Pull fields defensively — observation may be any client's model type.
        payload = {
            "query": getattr(observation, "query", ""),
            "schema_info": getattr(observation, "schema_info", {}),
            "context": getattr(observation, "context", ""),
            "issues_found_so_far": [],
            "remaining_actions": getattr(observation, "remaining_actions", 0),
            "difficulty": getattr(observation, "difficulty", "unknown"),
            "feedback": getattr(observation, "feedback", ""),
        }

        # Serialize already-found issues one at a time; a bad item is dropped,
        # never allowed to sink the whole request.
        for issue in getattr(observation, "issues_found_so_far", []):
            try:
                dumped = issue.model_dump() if hasattr(issue, "model_dump") else str(issue)
                payload["issues_found_so_far"].append(dumped)
            except Exception:
                pass

        completion = llm_client.chat.completions.create(
            model=model_name,
            temperature=0,
            max_tokens=300,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(payload, indent=2)},
            ],
        )
        raw = completion.choices[0].message.content or ""
        action = extract_json(raw)
        # Guarantee the one key downstream validation requires.
        action.setdefault("action_type", "approve")
        return action

    except Exception as exc:
        print(f"[DEBUG] choose_action error: {exc}", flush=True)
        return {"action_type": "approve", "confidence": 0.1}
 
142
 
143
 
144
  # ---------------------------------------------------------------------------
145
+ # Episode runner (async) — for from_docker_image connection
146
  # ---------------------------------------------------------------------------
147
 
148
 
149
async def run_episode_async(
    env: Any, llm_client: Any, model_name: str, task_id: str
) -> None:
    """Run one episode against an async environment client.

    Emits the validator's [START]/[STEP]/[END] stdout lines; swallows all
    errors so a single broken episode never kills the whole run.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env=BENCHMARK, model=model_name)

    try:
        from sql_query_reviewer.models import SQLReviewAction

        # BUGFIX: forward task_id so each episode runs its own task.
        # Previously `await env.reset()` always loaded the default task even
        # though [START] logged task=<task_id> (the sync runner passes it).
        # Fall back for clients whose reset() takes no arguments.
        try:
            result = await env.reset(task_id=task_id)
        except TypeError:
            result = await env.reset()

        for step_num in range(1, MAX_STEPS + 1):
            if result.done:
                break

            action_dict = choose_action(llm_client, model_name, result.observation)

            try:
                action = SQLReviewAction.model_validate(action_dict)
            except Exception:
                # Invalid LLM output — end the episode gracefully.
                action = SQLReviewAction(action_type="approve", confidence=0.1)

            action_str = action.action_type
            if action.issue_description:
                action_str = f"{action.action_type}({action.issue_category})"

            try:
                result = await env.step(action)
            except Exception as step_err:
                print(f"[DEBUG] env.step error: {step_err}", flush=True)
                log_step(step=step_num, action=action_str, reward=0.0, done=True, error=str(step_err))
                rewards.append(0.0)
                steps_taken = step_num
                break

            reward = result.reward if result.reward is not None else 0.0
            rewards.append(reward)
            steps_taken = step_num
            error_msg = None
            if hasattr(result, "info") and result.info:
                error_msg = result.info.get("error")

            log_step(step=step_num, action=action_str, reward=reward, done=result.done, error=error_msg)

        # Final score: prefer the env's authoritative score, otherwise fall
        # back to the mean of the observed rewards.
        try:
            state = await env.state()
            if hasattr(state, "final_score") and state.final_score is not None:
                score = state.final_score
            else:
                score = sum(rewards) / max(len(rewards), 1)
        except Exception:
            score = sum(rewards) / max(len(rewards), 1) if rewards else 0.0

        score = max(0.0, min(1.0, score))
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as exc:
        print(f"[DEBUG] Episode error: {exc}", flush=True)
        traceback.print_exc(file=sys.stdout)

    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Episode runner (sync) — for direct HTTP connection
223
+ # ---------------------------------------------------------------------------
224
+
225
+
226
def run_episode_sync(
    env: Any, llm_client: Any, model_name: str, task_id: str
) -> None:
    """Synchronous twin of run_episode_async, for the plain-HTTP client.

    Produces the same [START]/[STEP]/[END] stdout protocol and never lets
    an exception escape past log_end().
    """
    reward_log: List[float] = []
    step_count = 0
    final = 0.0
    passed = False

    log_start(task=task_id, env=BENCHMARK, model=model_name)

    try:
        from sql_query_reviewer.models import SQLReviewAction

        result = env.reset(task_id=task_id)

        for n in range(1, MAX_STEPS + 1):
            if result.done:
                break

            raw_action = choose_action(llm_client, model_name, result.observation)
            try:
                action = SQLReviewAction.model_validate(raw_action)
            except Exception:
                # Unusable LLM output — approve to wind the episode down.
                action = SQLReviewAction(action_type="approve", confidence=0.1)

            label = (
                f"{action.action_type}({action.issue_category})"
                if action.issue_description
                else action.action_type
            )

            try:
                result = env.step(action)
            except Exception as step_err:
                print(f"[DEBUG] env.step error: {step_err}", flush=True)
                log_step(step=n, action=label, reward=0.0, done=True, error=str(step_err))
                reward_log.append(0.0)
                step_count = n
                break

            step_reward = result.reward if result.reward is not None else 0.0
            reward_log.append(step_reward)
            step_count = n
            err = result.info.get("error") if hasattr(result, "info") and result.info else None

            log_step(step=n, action=label, reward=step_reward, done=result.done, error=err)

        # Authoritative score from the env when available, else mean reward.
        try:
            state = env.state()
            has_score = hasattr(state, "final_score") and state.final_score is not None
            final = state.final_score if has_score else sum(reward_log) / max(len(reward_log), 1)
        except Exception:
            final = sum(reward_log) / max(len(reward_log), 1) if reward_log else 0.0

        final = max(0.0, min(1.0, final))
        passed = final >= SUCCESS_SCORE_THRESHOLD

    except Exception as exc:
        print(f"[DEBUG] Episode error: {exc}", flush=True)
        traceback.print_exc(file=sys.stdout)

    finally:
        log_end(success=passed, steps=step_count, score=final, rewards=reward_log)
295
 
296
 
297
  # ---------------------------------------------------------------------------
298
+ # Main — tries Docker image first (validator), then HTTP (local dev)
299
  # ---------------------------------------------------------------------------
300
 
301
 
302
async def async_main() -> int:
    """Entry point: connect to the environment and run every task.

    Connection strategy, in order:
      1. Docker image via the custom client (hackathon validator path)
      2. Docker image via openenv-core's GenericEnvClient
      3. Async HTTP against ENV_BASE_URL
      4. Sync HTTP against ENV_BASE_URL (last resort)
    Always emits valid [START]/[END] lines, even when nothing connects.
    Returns 0 on any successful run, 1 otherwise.
    """
    task_ids = tuple(
        tid.strip()
        for tid in os.getenv("TASK_IDS", "easy_001,medium_001,hard_001").split(",")
        if tid.strip()
    )

    # No API key: don't crash — emit the mandatory log lines and exit non-zero.
    # IMPROVEMENT: honor TASK_IDS here too instead of a hard-coded task list.
    if not API_KEY:
        print("[DEBUG] WARNING: No API key found (HF_TOKEN / API_KEY / OPENAI_API_KEY)", flush=True)
        for tid in task_ids:
            log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
            log_end(success=False, steps=0, score=0.0, rewards=[])
        return 1

    llm_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)

    env = None

    try:
        # ------------------------------------------------------------------
        # Method 1: from_docker_image (what the hackathon validator uses)
        # ------------------------------------------------------------------
        if IMAGE_NAME:
            print(f"[DEBUG] Connecting via Docker image: {IMAGE_NAME}", flush=True)
            try:
                from sql_query_reviewer.client import SQLReviewEnv
                env = await SQLReviewEnv.from_docker_image(IMAGE_NAME)
                print("[DEBUG] Docker connection OK", flush=True)
                for task_id in task_ids:
                    await run_episode_async(env, llm_client, MODEL_NAME, task_id)
                return 0
            except AttributeError:
                # from_docker_image not implemented in our client — try openenv-core
                print("[DEBUG] from_docker_image not in custom client, trying openenv generic", flush=True)
                try:
                    from openenv.core.env_client import GenericEnvClient
                    env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
                    print("[DEBUG] GenericEnvClient Docker connection OK", flush=True)
                    for task_id in task_ids:
                        await run_episode_async(env, llm_client, MODEL_NAME, task_id)
                    return 0
                except Exception as exc2:
                    print(f"[DEBUG] GenericEnvClient also failed: {exc2}", flush=True)
            except Exception as exc:
                print(f"[DEBUG] Docker connection failed: {exc}", flush=True)

        # ------------------------------------------------------------------
        # Method 2: Async HTTP (fallback for local/URL-based testing)
        # ------------------------------------------------------------------
        print(f"[DEBUG] Connecting via URL: {ENV_BASE_URL}", flush=True)
        try:
            from sql_query_reviewer.client import SQLReviewEnv
            env = SQLReviewEnv(base_url=ENV_BASE_URL)
            await env.__aenter__()
            print("[DEBUG] Async URL connection OK", flush=True)
            for task_id in task_ids:
                await run_episode_async(env, llm_client, MODEL_NAME, task_id)
            return 0
        except Exception as exc:
            print(f"[DEBUG] Async URL failed: {exc}", flush=True)

        # ------------------------------------------------------------------
        # Method 3: Sync HTTP (last resort)
        # ------------------------------------------------------------------
        try:
            from sql_query_reviewer.client import SyncSQLReviewEnv
            sync_env = SyncSQLReviewEnv(base_url=ENV_BASE_URL)
            sync_env.__enter__()
            print("[DEBUG] Sync HTTP connection OK", flush=True)
            try:
                for task_id in task_ids:
                    run_episode_sync(sync_env, llm_client, MODEL_NAME, task_id)
            finally:
                # ROBUSTNESS: close even if the episode loop is interrupted
                # (previously close() was skipped on any mid-loop exception).
                sync_env.close()
            return 0
        except Exception as exc:
            print(f"[DEBUG] Sync HTTP also failed: {exc}", flush=True)

        # All methods failed — still emit valid log lines
        print("[DEBUG] All connection methods exhausted", flush=True)
        for task_id in task_ids:
            log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
            log_end(success=False, steps=0, score=0.0, rewards=[])
        return 1

    except Exception as exc:
        print(f"[DEBUG] Fatal: {exc}", flush=True)
        traceback.print_exc(file=sys.stdout)
        return 1

    finally:
        if env is not None:
            try:
                # NOTE(review): assumes the async client exposes close() —
                # confirm for SQLReviewEnv; failures are only logged.
                await env.close()
            except Exception as close_err:
                print(f"[DEBUG] env.close() error: {close_err}", flush=True)


def main() -> int:
    """Synchronous wrapper so callers need no asyncio knowledge."""
    return asyncio.run(async_main())
401
 
402
 
403
  if __name__ == "__main__":
sql_query_reviewer/client.py CHANGED
@@ -1,18 +1,92 @@
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from typing import Any
4
 
5
  import httpx
6
 
7
- from sql_query_reviewer.models import ResetRequest, SQLReviewAction, SQLReviewState, StepResult
 
 
 
 
 
8
 
9
 
10
  class SQLReviewEnv:
 
 
11
  def __init__(self, base_url: str, timeout: float = 30.0) -> None:
12
  self.base_url = base_url.rstrip("/")
13
  self.timeout = timeout
14
  self._client: httpx.AsyncClient | None = None
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  async def __aenter__(self) -> "SQLReviewEnv":
17
  self._client = httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout)
18
  return self
@@ -24,13 +98,25 @@ class SQLReviewEnv:
24
  if self._client is not None:
25
  await self._client.aclose()
26
  self._client = None
 
 
 
 
 
 
 
 
 
27
 
28
  def sync(self) -> "SyncSQLReviewEnv":
29
  return SyncSQLReviewEnv(base_url=self.base_url, timeout=self.timeout)
30
 
 
 
31
  async def reset(self, task_id: str | None = None) -> StepResult:
32
  client = self._require_client()
33
- response = await client.post("/reset", json=ResetRequest(task_id=task_id).model_dump(exclude_none=True))
 
34
  response.raise_for_status()
35
  return StepResult.model_validate(response.json())
36
 
@@ -48,11 +134,40 @@ class SQLReviewEnv:
48
 
49
  def _require_client(self) -> httpx.AsyncClient:
50
  if self._client is None:
51
- raise RuntimeError("Use SQLReviewEnv as an async context manager before calling it.")
52
  return self._client
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  class SyncSQLReviewEnv:
 
 
56
  def __init__(self, base_url: str, timeout: float = 30.0) -> None:
57
  self.base_url = base_url.rstrip("/")
58
  self.timeout = timeout
@@ -72,7 +187,8 @@ class SyncSQLReviewEnv:
72
 
73
  def reset(self, task_id: str | None = None) -> StepResult:
74
  client = self._require_client()
75
- response = client.post("/reset", json=ResetRequest(task_id=task_id).model_dump(exclude_none=True))
 
76
  response.raise_for_status()
77
  return StepResult.model_validate(response.json())
78
 
@@ -90,6 +206,5 @@ class SyncSQLReviewEnv:
90
 
91
  def _require_client(self) -> httpx.Client:
92
  if self._client is None:
93
- raise RuntimeError("Use SyncSQLReviewEnv as a context manager before calling it.")
94
  return self._client
95
-
 
1
+ """
2
+ SQL Query Reviewer — Client
3
+ ============================
4
+ Supports three connection modes:
5
+ 1. from_docker_image() — used by hackathon validator
6
+ 2. Async via SQLReviewEnv(base_url=...)
7
+ 3. Sync via SyncSQLReviewEnv(base_url=...)
8
+ """
9
+
10
  from __future__ import annotations
11
 
12
  from typing import Any
13
 
14
  import httpx
15
 
16
+ from sql_query_reviewer.models import (
17
+ ResetRequest,
18
+ SQLReviewAction,
19
+ SQLReviewState,
20
+ StepResult,
21
+ )
22
 
23
 
24
  class SQLReviewEnv:
25
+ """Async client for the SQL Query Reviewer environment."""
26
+
27
  def __init__(self, base_url: str, timeout: float = 30.0) -> None:
28
  self.base_url = base_url.rstrip("/")
29
  self.timeout = timeout
30
  self._client: httpx.AsyncClient | None = None
31
 
32
+ # --- Docker image support (hackathon validator) -----------------------
33
+
34
+ @classmethod
35
+ async def from_docker_image(cls, image_name: str) -> "SQLReviewEnv":
36
+ """
37
+ Connect to the environment via a Docker image.
38
+ Tries openenv-core's provider first, then falls back to localhost.
39
+ """
40
+ try:
41
+ # Try using openenv-core's built-in Docker provider
42
+ from openenv.core.env_client import EnvClient
43
+
44
+ class _Wrapper(EnvClient):
45
+ pass
46
+
47
+ env = await _Wrapper.from_docker_image(image_name)
48
+ # Wrap the openenv client so our typed models work
49
+ return _DockerEnvWrapper(env)
50
+ except ImportError:
51
+ pass
52
+ except Exception:
53
+ pass
54
+
55
+ # Fallback: assume the Docker container is already running on port 8000
56
+ import subprocess
57
+ import time
58
+
59
+ container_id = None
60
+ try:
61
+ result = subprocess.run(
62
+ ["docker", "run", "-d", "-p", "8000:8000", image_name],
63
+ capture_output=True,
64
+ text=True,
65
+ timeout=120,
66
+ )
67
+ container_id = result.stdout.strip()
68
+ except Exception:
69
+ pass
70
+
71
+ # Wait for container to be ready
72
+ base_url = "http://localhost:8000"
73
+ for _ in range(30):
74
+ try:
75
+ async with httpx.AsyncClient() as c:
76
+ r = await c.get(f"{base_url}/health", timeout=2.0)
77
+ if r.status_code == 200:
78
+ break
79
+ except Exception:
80
+ pass
81
+ time.sleep(1)
82
+
83
+ instance = cls(base_url=base_url)
84
+ instance._container_id = container_id # type: ignore[attr-defined]
85
+ await instance.__aenter__()
86
+ return instance
87
+
88
+ # --- Async context manager --------------------------------------------
89
+
90
  async def __aenter__(self) -> "SQLReviewEnv":
91
  self._client = httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout)
92
  return self
 
98
  if self._client is not None:
99
  await self._client.aclose()
100
  self._client = None
101
+ # Clean up Docker container if we started one
102
+ container_id = getattr(self, "_container_id", None)
103
+ if container_id:
104
+ try:
105
+ import subprocess
106
+ subprocess.run(["docker", "stop", container_id], capture_output=True, timeout=10)
107
+ subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=10)
108
+ except Exception:
109
+ pass
110
 
111
  def sync(self) -> "SyncSQLReviewEnv":
112
  return SyncSQLReviewEnv(base_url=self.base_url, timeout=self.timeout)
113
 
114
+ # --- API methods ------------------------------------------------------
115
+
116
  async def reset(self, task_id: str | None = None) -> StepResult:
117
  client = self._require_client()
118
+ body = ResetRequest(task_id=task_id).model_dump(exclude_none=True)
119
+ response = await client.post("/reset", json=body)
120
  response.raise_for_status()
121
  return StepResult.model_validate(response.json())
122
 
 
134
 
135
  def _require_client(self) -> httpx.AsyncClient:
136
  if self._client is None:
137
+ raise RuntimeError("Use SQLReviewEnv as an async context manager or call from_docker_image().")
138
  return self._client
139
 
140
 
141
class _DockerEnvWrapper(SQLReviewEnv):
    """Wraps an openenv-core EnvClient to present our typed interface."""

    def __init__(self, inner: Any) -> None:
        # BUGFIX: initialize via the base class so inherited helpers that read
        # base_url/timeout (e.g. sync()) don't hit AttributeError — previously
        # timeout was never set. _client stays None: calls delegate to inner.
        super().__init__(base_url="")
        self._inner = inner
        self._client = None  # not used — we delegate to inner

    async def reset(self, task_id: str | None = None) -> StepResult:
        # NOTE(review): task_id is accepted for interface parity but not
        # forwarded — the generic openenv reset() takes no arguments; confirm.
        result = await self._inner.reset()
        return StepResult.model_validate(result.model_dump() if hasattr(result, "model_dump") else result)

    async def step(self, action: SQLReviewAction) -> StepResult:
        result = await self._inner.step(action)
        return StepResult.model_validate(result.model_dump() if hasattr(result, "model_dump") else result)

    async def state(self) -> SQLReviewState:
        result = await self._inner.state()
        return SQLReviewState.model_validate(result.model_dump() if hasattr(result, "model_dump") else result)

    async def close(self) -> None:
        # Best effort: the inner client may not define close().
        try:
            await self._inner.close()
        except Exception:
            pass
166
+
167
+
168
  class SyncSQLReviewEnv:
169
+ """Synchronous client for local dev and testing."""
170
+
171
  def __init__(self, base_url: str, timeout: float = 30.0) -> None:
172
  self.base_url = base_url.rstrip("/")
173
  self.timeout = timeout
 
187
 
188
  def reset(self, task_id: str | None = None) -> StepResult:
189
  client = self._require_client()
190
+ body = ResetRequest(task_id=task_id).model_dump(exclude_none=True)
191
+ response = client.post("/reset", json=body)
192
  response.raise_for_status()
193
  return StepResult.model_validate(response.json())
194
 
 
206
 
207
  def _require_client(self) -> httpx.Client:
208
  if self._client is None:
209
+ raise RuntimeError("Use SyncSQLReviewEnv as a context manager.")
210
  return self._client