Upload folder using huggingface_hub
Files changed:
- README.md +4 -2
- client.py +23 -35
- graders/__init__.py +0 -0
- graders/grader_easy.py +110 -0
- graders/grader_hard.py +140 -0
- graders/grader_medium.py +135 -0
- inference.py +157 -92
- models.py +2 -19
- openenv_sql_debug.egg-info/PKG-INFO +1 -0
- openenv_sql_debug.egg-info/SOURCES.txt +6 -1
- openenv_sql_debug.egg-info/requires.txt +1 -0
- pyproject.toml +1 -0
- server/app.py +1 -1
- server/requirements.txt +2 -1
- server/sql_debug_environment.py +52 -52
- tasks/task_hard.py +2 -2
- uv.lock +2 -0
README.md
CHANGED

@@ -195,7 +195,6 @@ The inference script defaults to `syntax_fix_001`, logs each step, and stops whe
 ```text
 sql_exp/
 ├── client.py            # OpenEnv client wrapper
-├── grader.py            # Reward computation
 ├── inference.py         # LLM-driven inference loop
 ├── models.py            # Action and observation models
 ├── openenv.yaml         # OpenEnv manifest
@@ -209,7 +208,10 @@ sql_exp/
 │   ├── task_easy.py      # Syntax-fix task
 │   ├── task_medium.py    # Join logic task
 │   └── task_hard.py      # Query optimization task
-├──
+├── graders/
+│   ├── grader_easy.py    # Syntax-fix task
+│   ├── grader_medium.py  # Join logic task
+│   └── grader_hard.py    # Query optimization task
 └── README.md             # Project overview
 ```
client.py
CHANGED

@@ -1,16 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
 # client.py
-
-"""
-This is what inference.py uses to talk to the running server.
-"""
-
-from typing import Dict
+from typing import Dict, Optional
+import httpx
 
 from openenv.core import EnvClient
 from openenv.core.client_types import StepResult
@@ -20,35 +10,33 @@ from models import SQLDebugAction, SQLDebugObservation
 
 
 class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
-    """
-    Maintains a persistent WebSocket connection to the server.
-    Each instance gets its own dedicated environment session.
-
-        print(result.reward)
-    """
+    def __init__(self, base_url: str = "http://localhost:8000", **kwargs):
+        super().__init__(base_url=base_url, **kwargs)
+        self._base_url = base_url.rstrip("/")
 
+    # ── Override reset to send task_id in body ────────────────────────────────
+    async def reset(self, task_id: Optional[str] = None, **kwargs) -> StepResult:
+        payload = {}
+        if task_id:
+            payload["task_id"] = task_id
 
+        async with httpx.AsyncClient(timeout=30) as http:
+            response = await http.post(
+                f"{self._base_url}/reset",
+                json=payload,
+            )
+            response.raise_for_status()
+            return self._parse_result(response.json())
 
+    # ── step payload ──────────────────────────────────────────────────────────
     def _step_payload(self, action: SQLDebugAction) -> Dict:
-        """Convert SQLDebugAction to JSON payload."""
         return {"query": action.query}
 
+    # ← update _parse_result only
+
     def _parse_result(self, payload: Dict) -> StepResult[SQLDebugObservation]:
-        """Parse server JSON response into a typed StepResult."""
         obs_data = payload.get("observation", {})
+        meta = obs_data.get("metadata", {})  # ← feedback lives here now
 
         observation = SQLDebugObservation(
             task_id=obs_data.get("task_id", ""),
@@ -63,6 +51,7 @@ class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
             available_tasks=obs_data.get("available_tasks", []),
             done=payload.get("done", False),
             reward=payload.get("reward", 0.0),
+            metadata=meta,
         )
 
         return StepResult(
@@ -72,8 +61,7 @@ class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
         )
 
     def _parse_state(self, payload: Dict) -> State:
-        """Parse server JSON response into a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
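Not part of this commit, but for orientation: a minimal sketch of how inference.py drives the updated client (it assumes the environment server from server/app.py is already running on localhost:8000 and that the default task id from inference.py is used).

# client_usage_sketch.py - illustrative only, not included in this commit
import asyncio

from client import SQLDebugEnv, SQLDebugAction


async def run_once() -> None:
    env = SQLDebugEnv(base_url="http://localhost:8000")

    # reset() now POSTs {"task_id": ...} to /reset via httpx
    result = await env.reset(task_id="syntax_fix_001")
    print("starting query:", result.observation.current_query)

    # step() goes through EnvClient as before; reward is the grader's delta
    result = await env.step(SQLDebugAction(query="SELECT 1"))
    meta = result.observation.metadata or {}
    print("delta:", result.reward, "| feedback:", meta.get("feedback"))


asyncio.run(run_once())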
graders/__init__.py
ADDED
File without changes
graders/grader_easy.py
ADDED

@@ -0,0 +1,110 @@
# tasks/grader_easy.py
"""
Grader for syntax_fix_001 – fix typos in SQL keywords.
Reward is shaped on: syntax correctness + F1 row match + step efficiency.
"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 5,
) -> dict:
    """
    Easy task grader. Pure row-match scoring – no plan check needed.

    Reward components:
        syntax_score    : 0.0 or 1.0 – did the query run at all?
        result_score    : 0.0–1.0    – F1 of returned vs expected rows
        efficiency_bonus: 0.0–0.05   – small bonus for solving early
        delta           : absolute_score - prev_absolute_score
    """

    syntax_ok = run_result["error"] is None

    # ── Syntax ────────────────────────────────────────────────────────────────
    if not syntax_ok:
        absolute_score = 0.05  # tiny gradient so agent knows to fix syntax first
        delta = absolute_score - prev_absolute_score
        delta = max(-0.3, min(0.5, delta))
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    # ── Row matching (F1) ─────────────────────────────────────────────────────
    expected = task["expected_rows"]
    got = run_result["rows"]

    if not got:
        result_score = 0.0
    else:
        correct_returned = sum(1 for row in got if row in expected)
        correct_expected = sum(1 for row in expected if row in got)

        precision = correct_returned / max(len(got), 1)
        recall = correct_expected / max(len(expected), 1)

        if precision + recall > 0:
            result_score = 2 * precision * recall / (precision + recall)
        else:
            result_score = 0.0

    # ── Efficiency bonus ──────────────────────────────────────────────────────
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if result_score >= 0.99:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score – easy: syntax 15% + correctness 80% + bonus 5% ────────
    absolute_score = round(
        min(0.99, 0.15 * 1.0 + 0.80 * result_score + efficiency_bonus), 4
    )

    # ── Delta reward – the RL signal ──────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02  # stall penalty – discourages repeating same query
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback for agent ────────────────────────────────────────────────────
    issues = []
    if result_score < 0.5:
        issues.append("result_rows: returned rows do not match expected – check your WHERE clause")
    elif result_score < 0.99:
        issues.append(f"result_rows: partial match ({result_score:.0%}) – some rows still wrong")
    if len(got) > len(expected):
        issues.append(f"extra_rows: returned {len(got)} rows but expected {len(expected)}")
    feedback = "; ".join(issues) if issues else "rows match – looking good"

    status = (
        "solved" if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": result_score,
        "plan_score": 0.0,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"result={result_score:.0%}"
        ),
    }
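Not part of the commit, but to make the shaping concrete, here is how the easy grader behaves on two hand-made run results (the task dict, query strings, and rows are invented purely for illustration).

# grade_easy_walkthrough.py - illustrative inputs only
from graders.grader_easy import grade

task = {"expected_rows": [(1, "alice"), (2, "bob")]}             # hypothetical expected rows
bad_run = {"error": 'near "SELEC": syntax error', "rows": []}    # attempt 1: still broken
good_run = {"error": None, "rows": [(1, "alice"), (2, "bob")]}   # attempt 2: fixed

r1 = grade(task, "SELEC * FORM users", bad_run, prev_absolute_score=0.0, step_count=1)
print(r1["status"], r1["absolute_score"], r1["delta"])   # syntax_error 0.05 0.05

r2 = grade(task, "SELECT * FROM users", good_run,
           prev_absolute_score=r1["absolute_score"], step_count=2)
print(r2["status"], r2["absolute_score"], r2["delta"])   # improving 0.98 0.5 (delta clipped at +0.5)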
graders/grader_hard.py
ADDED

@@ -0,0 +1,140 @@
# tasks/grader_hard.py
"""
Grader for optimize_001 – replace correlated subquery with CTE.

Unlike easy/medium, there are no fixed expected_rows.
Score is entirely driven by query plan quality:
  - uses WITH (CTE)
  - uses GROUP BY
  - uses AVG(
  - does NOT use correlated subquery pattern
  - executes without error
"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 10,
) -> dict:

    syntax_ok = run_result["error"] is None

    # ── Syntax ────────────────────────────────────────────────────────────────
    if not syntax_ok:
        absolute_score = 0.05
        delta = round(
            max(-0.3, min(0.5, absolute_score - prev_absolute_score)), 4
        )
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    query_upper = agent_query.upper()
    # good_patterns = task.get("good_patterns", ["WITH", "GROUP BY", "AVG("])

    # ── Plan component scores ─────────────────────────────────────────────────

    # 1. Uses CTE (WITH keyword) – most important signal
    has_cte = "WITH" in query_upper
    cte_score = 1.0 if has_cte else 0.0

    # 2. Uses GROUP BY – required for computing per-user average
    has_group_by = "GROUP BY" in query_upper
    group_score = 1.0 if has_group_by else 0.0

    # 3. Uses AVG – must be aggregating correctly
    has_avg = "AVG(" in query_upper
    avg_score = 1.0 if has_avg else 0.0

    # 4. Correlated subquery penalty – still using the slow pattern
    still_correlated = (
        "SELECT AVG" in query_upper
        and "WHERE" in query_upper
        and not has_cte  # WITH overrides this penalty
    )
    correlation_penalty = 0.4 if still_correlated else 0.0

    # 5. Execution quality – did the query actually return rows?
    rows_returned = len(run_result["rows"])
    execution_score = 1.0 if rows_returned > 0 else 0.3
    # 0.3 credit for running without error even if empty result

    # ── Plan score weighted combination ───────────────────────────────────────
    # CTE 40% + GROUP BY 25% + AVG 20% + execution 15%
    plan_score = round(
        max(
            0.0,
            0.40 * cte_score
            + 0.25 * group_score
            + 0.20 * avg_score
            + 0.15 * execution_score
            - correlation_penalty,
        ),
        4,
    )

    # ── Efficiency bonus ──────────────────────────────────────────────────────
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if plan_score >= 0.85:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score – hard: syntax 10% + plan 85% + bonus 5% ───────────────
    absolute_score = round(
        min(0.99, 0.10 * 1.0 + 0.85 * plan_score + efficiency_bonus), 4
    )
    absolute_score = max(0.05, absolute_score)

    # ── Delta ─────────────────────────────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback ──────────────────────────────────────────────────────────────
    issues = []
    if not has_cte:
        issues.append("missing_cte: query needs WITH clause to precompute averages")
    if not has_group_by:
        issues.append("missing_group_by: need GROUP BY user_id to compute per-user avg")
    if not has_avg:
        issues.append("missing_avg: need AVG(amount) in the CTE")
    if still_correlated:
        issues.append("still_correlated: subquery in WHERE runs per-row – move to CTE")
    if rows_returned == 0 and syntax_ok:
        issues.append("empty_result: query runs but returns no rows – check JOIN and WHERE")
    feedback = "; ".join(issues) if issues else "plan looks optimized"

    status = (
        "solved" if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": execution_score,
        "plan_score": plan_score,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"plan={plan_score:.0%} | cte={has_cte} | group={has_group_by}"
        ),
    }
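Again not part of the commit: a quick check of how the hard grader's keyword heuristics score a CTE rewrite (the query text and the returned row are invented for the example; the empty task dict works because this grader only reads run_result and the query string).

# grade_hard_walkthrough.py - illustrative inputs only
from graders.grader_hard import grade

cte_query = """
WITH user_avg AS (
    SELECT user_id, AVG(amount) AS avg_amt
    FROM transactions
    GROUP BY user_id
)
SELECT t.* FROM transactions t
JOIN user_avg u ON t.user_id = u.user_id
WHERE t.amount > u.avg_amt
"""
run_result = {"error": None, "rows": [(1, 7, 812.5, None, "completed")]}  # made-up row

out = grade({}, cte_query, run_result, prev_absolute_score=0.05, step_count=2)
# plan_score = 0.40 (CTE) + 0.25 (GROUP BY) + 0.20 (AVG) + 0.15 (rows returned) = 1.0
# absolute   = min(0.99, 0.10 + 0.85 * 1.0 + 0.04 bonus) = 0.99, so status is "solved"
print(out["status"], out["plan_score"], out["absolute_score"])  # solved 1.0 0.99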
graders/grader_medium.py
ADDED

@@ -0,0 +1,135 @@
# tasks/grader_medium.py
"""
Grader for logic_fix_001 – fix wrong JOIN type / WHERE logic.

Harder than easy: agent must get BOTH precision and recall right.
Extra penalty for wrong row count (catches SELECT * with no WHERE).
"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 8,
) -> dict:

    syntax_ok = run_result["error"] is None

    # ── Syntax ────────────────────────────────────────────────────────────────
    if not syntax_ok:
        absolute_score = 0.05
        delta = round(
            max(-0.3, min(0.5, absolute_score - prev_absolute_score)), 4
        )
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    expected = task["expected_rows"]
    got = run_result["rows"]

    # ── F1 row score ──────────────────────────────────────────────────────────
    if not got:
        result_score = 0.0
    else:
        correct_returned = sum(1 for row in got if row in expected)
        correct_expected = sum(1 for row in expected if row in got)

        precision = correct_returned / max(len(got), 1)
        recall = correct_expected / max(len(expected), 1)

        if precision + recall > 0:
            result_score = 2 * precision * recall / (precision + recall)
        else:
            result_score = 0.0

    # ── Extra penalty for wrong row count ─────────────────────────────────────
    # Logic bugs typically show up as too many rows (LEFT JOIN returns NULLs)
    # Penalize harder than easy task to encourage precise reasoning
    row_count_penalty = 0.0
    if len(got) > len(expected):
        extra = len(got) - len(expected)
        row_count_penalty = min(0.25, extra * 0.08)

    # ── JOIN type hint score ──────────────────────────────────────────────────
    # Gives partial credit for using the right JOIN type even if rows are off
    # Avoids zero-reward cliff for agents that fix JOIN but have minor issues
    query_upper = agent_query.upper()
    join_score = 0.0
    if "INNER JOIN" in query_upper:
        join_score = 0.15  # using INNER JOIN is the right direction
    elif "LEFT JOIN" in query_upper:
        join_score = 0.0   # LEFT JOIN is the bug – no credit
    elif "JOIN" in query_upper:
        join_score = 0.05  # some join exists – small credit

    # ── Efficiency bonus ──────────────────────────────────────────────────────
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if result_score >= 0.99:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score – medium: syntax 10% + correctness 70% + join 15% + bonus 5% ──
    absolute_score = round(
        min(
            0.99,
            0.10 * 1.0
            + 0.70 * result_score
            + 0.15 * join_score
            + efficiency_bonus
            - row_count_penalty,
        ),
        4,
    )
    absolute_score = max(0.05, absolute_score)  # floor at 0.05

    # ── Delta ─────────────────────────────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback ──────────────────────────────────────────────────────────────
    issues = []
    if "LEFT JOIN" in query_upper:
        issues.append("join_type: using LEFT JOIN includes rows with no matching department")
    if len(got) > len(expected):
        issues.append(f"extra_rows: got {len(got)} rows, expected {len(expected)} – filter too loose")
    if len(got) < len(expected) and len(got) > 0:
        issues.append(f"missing_rows: got {len(got)} rows, expected {len(expected)} – filter too strict")
    if result_score < 0.5:
        issues.append("result_rows: output does not match expected – check JOIN and WHERE")
    feedback = "; ".join(issues) if issues else "rows and join look correct"

    status = (
        "solved" if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": result_score,
        "plan_score": join_score,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"result={result_score:.0%} | join={join_score:.2f}"
        ),
    }
inference.py
CHANGED

@@ -1,50 +1,38 @@
 # inference.py
-"""
-SQL Debug & Optimizer – OpenEnv Inference Script
-
-Mandatory stdout format:
-  [START] task=<task_name> env=<benchmark> model=<model_name>
-  [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
-  [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
-"""
-
 import asyncio
 import os
 import textwrap
-from typing import List, Optional
+from typing import List, Optional, Dict
 
 from openai import OpenAI
 from client import SQLDebugEnv, SQLDebugAction
 
-# ──
+# ── Env vars ──────────────────────────────────────────────────────────────────
 IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
 API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
-MODEL_NAME = os.getenv("MODEL_NAME",
-
-
-
-
-
-
-SUCCESS_THRESHOLD = 0.5  # reward >= 0.5 = success
+MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
+SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
+
+TASK_NAME = os.getenv("SQL_ENV_TASK", "syntax_fix_001")
+BENCHMARK = "sql-debug-optimizer"
+MAX_STEPS = 8
+TEMPERATURE = 0.3
+MAX_TOKENS = 400
+SUCCESS_THRESHOLD = 0.5
 
 
-# ──
+# ── Stdout loggers ────────────────────────────────────────────────────────────
 
 def log_start(task: str, env: str, model: str) -> None:
     print(f"[START] task={task} env={env} model={model}", flush=True)
 
 
 def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
-    # action must be single-line – newlines break log parsing
     action_clean = action.replace("\n", " ").replace("\r", "").strip()
-    error_val = error if error else "null"
-    done_val = str(done).lower()
     print(
         f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
-        f"done={
+        f"done={str(done).lower()} error={error or 'null'}",
         flush=True,
     )
 
@@ -58,124 +46,201 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
     )
 
 
-# ──
-
-SYSTEM_PROMPT =
-
-"""
-
+# ── Prompts ───────────────────────────────────────────────────────────────────
+
+SYSTEM_PROMPT = """You are an expert SQL engineer fixing and optimizing SQL queries.
+
+STRICT OUTPUT RULES:
+- Output ONLY raw SQL. No markdown. No backticks. No explanation. No comments.
+- Your output is executed directly against a SQLite database.
+- If your previous attempt got negative reward, you made things worse – try differently.
+- If reward is stalled (same score 2+ steps), change strategy significantly."""
+
+TASK_CONTEXT = {
+    "syntax_fix_001": "The query has typographical errors in SQL keywords.",
+    "logic_fix_001": "The query runs but returns incorrect rows due to a logic error.",
+    "optimize_001": "The query is correct but slow. Rewrite it to be faster.",
+}
+
+GRADUATED_HINTS = {
+    "syntax_fix_001": [
+        "",
+        "Check the spelling of SQL keywords like SELECT, FROM, WHERE.",
+        "Compare each word: SELECT FROM WHERE ORDER BY GROUP BY – fix any typos.",
+        "The typos are: SELEC → SELECT, FORM → FROM, WERE → WHERE.",
+    ],
+    "logic_fix_001": [
+        "",
+        "The query returns more rows than expected. Check your JOIN type.",
+        "LEFT JOIN includes rows even when no match exists. Consider INNER JOIN.",
+        "Change LEFT JOIN to INNER JOIN to exclude employees with no matching department.",
+    ],
+    "optimize_001": [
+        "",
+        "The query uses a subquery that runs once per row – this is slow.",
+        "Compute the per-user average once using GROUP BY, then JOIN the result.",
+        "Use: WITH user_avg AS (SELECT user_id, AVG(amount) AS avg FROM transactions GROUP BY user_id) SELECT t.* FROM transactions t JOIN user_avg u ON t.user_id = u.user_id WHERE t.amount > u.avg AND t.status = 'completed'",
+    ],
+}
+
+
+def get_hint_level(step: int, stall_count: int) -> int:
+    if step <= 2 and stall_count < 2:
+        return 0
+    if step <= 4 and stall_count < 4:
+        return 1
+    if step <= 6:
+        return 2
+    return 3
+
 
-def build_prompt(obs) -> str:
-    """Build the user prompt from the current observation."""
-    result_preview = str(obs.query_result[:3]) if obs.query_result else "empty / error"
+def build_prompt(obs, step: int, stall_count: int, prev_delta: float) -> str:
+    context = TASK_CONTEXT.get(obs.task_id, "Fix the SQL query.")
+    hint_level = get_hint_level(step, stall_count)
+    hint = GRADUATED_HINTS.get(obs.task_id, [""] * 4)[hint_level]
+    result_preview = str(obs.query_result[:3]) if obs.query_result else "none"
+
+    # ← read feedback from metadata dict, not obs.feedback
+    meta = obs.metadata or {}
+    feedback = meta.get("feedback", "analyse the result yourself")
+
+    reward_context = ""
+    if step > 1:
+        if prev_delta > 0.01:
+            reward_context = f"Last change IMPROVED score (+{prev_delta:.2f}). Keep going."
+        elif prev_delta < -0.01:
+            reward_context = f"Last change WORSENED score ({prev_delta:.2f}). Revert and try differently."
+        else:
+            reward_context = f"Last change had NO EFFECT (delta={prev_delta:.2f}). Try a completely different approach."
+
+    hint_block = f"\nHINT: {hint}" if hint else ""
 
     return textwrap.dedent(f"""
-    TASK: {
+    TASK: {context}
+    {reward_context}{hint_block}
 
-    {obs.schema_sql.strip()[:
+    SCHEMA:
+    {obs.schema_sql.strip()[:600]}
 
-    CURRENT QUERY
+    CURRENT QUERY:
     {obs.current_query.strip()}
 
     ERROR: {obs.error_message or "none"}
+    RESULT (first 3 rows): {result_preview}
+    FEEDBACK: {feedback}
+    BEST SCORE SO FAR: {obs.reward_so_far:.3f}
+    STEP: {step} of {MAX_STEPS}
 
-    Write the corrected SQL
+    Write the corrected SQL:
     """).strip()
 
 
-def call_llm(
-
+def call_llm(
+    client: OpenAI,
+    obs,
+    history: List[Dict],
+    step: int,
+    stall_count: int,
+    prev_delta: float,
+) -> str:
+    user_content = build_prompt(obs, step, stall_count, prev_delta)
+    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+    messages.extend(history[-6:])
+    messages.append({"role": "user", "content": user_content})
+
     try:
         completion = client.chat.completions.create(
             model=MODEL_NAME,
-            messages=
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": build_prompt(obs)},
-            ],
+            messages=messages,
             temperature=TEMPERATURE,
             max_tokens=MAX_TOKENS,
             stream=False,
         )
        raw = (completion.choices[0].message.content or "").strip()
-
-        # Strip markdown code fences if model adds them despite instructions
        if "```" in raw:
-            lines = raw.split("\n")
            raw = "\n".join(
+                l for l in raw.split("\n")
+                if not l.strip().startswith("```")
            ).strip()
 
+        result = raw if raw else "SELECT 1"
+        history.append({"role": "user", "content": user_content})
+        history.append({"role": "assistant", "content": result})
+        return result
 
     except Exception as exc:
         print(f"[DEBUG] LLM call failed: {exc}", flush=True)
         return "SELECT 1"
 
 
-# ── Main
+# ── Main ──────────────────────────────────────────────────────────────────────
 
 async def main() -> None:
     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+    env = SQLDebugEnv(base_url=SERVER_URL)
 
-    steps_taken
-    score
-    success
+    delta_rewards: List[float] = []  # per-step delta – logged in [STEP]
+    abs_scores: List[float] = []     # per-step absolute – used for final score
+    history: List[Dict] = []
+    stall_count = 0
+    prev_delta = 0.0
+    steps_taken = 0
+    score = 0.0
+    success = False
 
     log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
 
     try:
-        # Reset
+        # ── Reset ─────────────────────────────────────────────────────────────
+        try:
+            result = await env.reset(task_id=TASK_NAME)
+        except Exception as e:
+            print(f"[DEBUG] reset() failed: {e}", flush=True)
+            raise
+
         obs = result.observation
 
+        # ── Episode loop ──────────────────────────────────────────────────────
         for step in range(1, MAX_STEPS + 1):
             if result.done:
                 break
 
+            sql_query = call_llm(
+                client, obs, history,
+                step=step,
+                stall_count=stall_count,
+                prev_delta=prev_delta,
+            )
 
-            # Submit to environment
             result = await env.step(SQLDebugAction(query=sql_query))
-            obs
+            obs = result.observation
+
+            # delta reward from grader (can be negative)
+            delta = result.reward or 0.0
+            # absolute score tracked via reward_so_far on observation
+            abs_s = obs.reward_so_far
+            done = result.done
+            error = obs.error_message if obs.error_message else None
+
+            # Stall detection – reset on any meaningful change
+            if abs(delta) < 0.01:
+                stall_count += 1
+            else:
+                stall_count = 0
+
+            prev_delta = delta
+            delta_rewards.append(delta)
+            abs_scores.append(abs_s)
             steps_taken = step
 
-            log_step(
-                step=step,
-                action=sql_query,
-                reward=reward,
-                done=done,
-                error=error,
-            )
+            log_step(step=step, action=sql_query, reward=delta, done=done, error=error)
 
             if done:
                 break
 
-        #
-        score
-        score
+        # Final score = best absolute score reached this episode
+        score = max(abs_scores) if abs_scores else 0.0
+        score = min(max(score, 0.0), 1.0)
         success = score >= SUCCESS_THRESHOLD
 
     except Exception as exc:
@@ -187,7 +252,7 @@ async def main() -> None:
     except Exception as e:
         print(f"[DEBUG] env.close() error: {e}", flush=True)
 
-    log_end(success=success, steps=steps_taken, score=score, rewards=
+    log_end(success=success, steps=steps_taken, score=score, rewards=delta_rewards)
 
 
 if __name__ == "__main__":
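A small aside on the new graduated-hint schedule: the level depends only on the step number and the stall counter, so it can be sanity-checked in isolation (importing inference also pulls in openai and client, so this sketch assumes the project dependencies are installed).

# hint_schedule_check.py - exercises get_hint_level from inference.py
from inference import get_hint_level

print(get_hint_level(step=1, stall_count=0))  # 0 - no hint on early, non-stalled steps
print(get_hint_level(step=2, stall_count=2))  # 1 - an early stall already earns a nudge
print(get_hint_level(step=5, stall_count=0))  # 2 - concrete hint by mid-episode
print(get_hint_level(step=7, stall_count=0))  # 3 - near-solution hint on the final steps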
models.py
CHANGED

@@ -1,31 +1,14 @@
-#
-
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Data models for the SQL Debug & Optimizer Environment.
-"""
-
-from typing import Any, Dict, List
+# models.py
+from typing import Any, Dict, List, Optional
 from pydantic import Field
 from openenv.core.env_server.types import Action, Observation
 
 
 class SQLDebugAction(Action):
-    """
-    What the agent submits each step – just a SQL query string.
-    The environment will run it, grade it, and return a new observation.
-    """
     query: str = Field(..., description="The SQL query the agent wants to try")
 
 
 class SQLDebugObservation(Observation):
-    """
-    What the agent sees after each step.
-    Contains everything it needs to improve its next query.
-    """
     task_id: str = Field(default="", description="Which task is active")
     schema_sql: str = Field(default="", description="CREATE TABLE statements for this task")
     current_query: str = Field(default="", description="Last query that was run")
openenv_sql_debug.egg-info/PKG-INFO
CHANGED

@@ -6,6 +6,7 @@ Requires-Python: >=3.10
 Requires-Dist: openenv-core[core]>=0.2.2
 Requires-Dist: openai>=2.30.0
 Requires-Dist: uvicorn>=0.43.0
+Requires-Dist: httpx>=0.28.1
 Provides-Extra: dev
 Requires-Dist: pytest>=8.0.0; extra == "dev"
 Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_sql_debug.egg-info/SOURCES.txt
CHANGED

@@ -1,8 +1,13 @@
 README.md
+__init__.py
+client.py
+inference.py
+models.py
 pyproject.toml
+runner.py
+test.py
 ./__init__.py
 ./client.py
-./grader.py
 ./inference.py
 ./models.py
 ./runner.py
openenv_sql_debug.egg-info/requires.txt
CHANGED

@@ -1,6 +1,7 @@
 openenv-core[core]>=0.2.2
 openai>=2.30.0
 uvicorn>=0.43.0
+httpx>=0.28.1
 
 [dev]
 pytest>=8.0.0
pyproject.toml
CHANGED

@@ -20,6 +20,7 @@ dependencies = [
     "openenv-core[core]>=0.2.2",
     "openai>=2.30.0",
     "uvicorn>=0.43.0",
+    "httpx>=0.28.1",
 ]
 
 [project.optional-dependencies]
server/app.py
CHANGED

@@ -37,7 +37,7 @@ try:
     from .sql_debug_environment import SQLDebugEnvironment
 except ModuleNotFoundError:
     from models import SQLDebugAction, SQLDebugObservation
-    from
+    from server.sql_debug_environment import SQLDebugEnvironment
 
 
 app = create_app(
server/requirements.txt
CHANGED

@@ -1,3 +1,4 @@
 openenv[core]>=0.2.0
 fastapi>=0.115.0
-uvicorn>=0.24.0
+uvicorn>=0.24.0
+httpx>=0.28.1
server/sql_debug_environment.py
CHANGED

@@ -10,7 +10,6 @@ SQL Debug & Optimizer Environment – server-side implementation.
 The server runs this. The agent never touches this file directly.
 It loads tasks, runs queries in SQLite, grades them, and returns observations.
 """
-
 from uuid import uuid4
 from openenv.core.env_server.interfaces import Environment
 from openenv.core.env_server.types import State
@@ -21,35 +20,34 @@ except ImportError:
     from models import SQLDebugAction, SQLDebugObservation
 
 from runner import run_query
+
+# Import each task's dedicated grader
+from graders.grader_easy import grade as grade_easy
+from graders.grader_medium import grade as grade_medium
+from graders.grader_hard import grade as grade_hard
 
 
 def _load_all_tasks() -> dict:
-    """Load every task from the tasks/ folder into a dict keyed by task_id."""
     from tasks.task_easy import TASK as EASY
     from tasks.task_medium import TASK as MEDIUM
     from tasks.task_hard import TASK as HARD
+
     return {
-        EASY["task_id"]:
+        EASY["task_id"]: EASY,
         MEDIUM["task_id"]: MEDIUM,
-        HARD["task_id"]:
+        HARD["task_id"]: HARD,
     }
 
 
-
-    SQLite, grades it (0.0–1.0), and returns the result as an observation.
-
-    Three tasks:
-      syntax_fix_001 (easy)   – fix typos in SQL keywords
-      logic_fix_001  (medium) – fix wrong JOIN type causing bad results
-    # optimize_001   (hard)   – rewrite correlated subquery as a CTE
-    """
+# Maps each task_id to its dedicated grader function
+TASK_GRADERS = {
+    "syntax_fix_001": grade_easy,
+    "logic_fix_001": grade_medium,
+    "optimize_001": grade_hard,
+}
 
+
+class SQLDebugEnvironment(Environment):
     SUPPORTS_CONCURRENT_SESSIONS: bool = True
 
     def __init__(self):
@@ -57,32 +55,29 @@ class SQLDebugEnvironment(Environment):
         self._current_task = None
         self._state = State(episode_id=str(uuid4()), step_count=0)
         self._best_reward = 0.0
+        self._prev_absolute_score = 0.0  # used for delta computation
         self._current_query = ""
 
-    #
+    # sql_debug_environment.py – replace reset() return and step() return only
 
-    def reset(self, task_id: str = None) -> SQLDebugObservation:
-        """
-        Start a new episode.
-        Pass task_id to pick a specific task, or leave None for the default (easy).
-        """
+    def reset(self, task_id: str = None, **kwargs) -> SQLDebugObservation:
         if task_id is None:
             task_id = list(self._all_tasks.keys())[0]
 
         if task_id not in self._all_tasks:
-            # Unknown task – return error observation instead of crashing
             return SQLDebugObservation(
                 task_id=task_id,
-                error_message=f"Unknown
+                error_message=f"Unknown task '{task_id}'. Available: {list(self._all_tasks.keys())}",
                 available_tasks=list(self._all_tasks.keys()),
+                metadata={},
             )
 
         self._current_task = self._all_tasks[task_id]
         self._state = State(episode_id=str(uuid4()), step_count=0)
         self._best_reward = 0.0
+        self._prev_absolute_score = 0.0
         self._current_query = self._current_task["broken_query"]
 
-        # Run the broken query so the agent sees the starting error
         run_result = run_query(
             self._current_task["schema_sql"],
             self._current_query,
@@ -101,45 +96,44 @@ class SQLDebugEnvironment(Environment):
             available_tasks=list(self._all_tasks.keys()),
             done=False,
             reward=0.0,
+            metadata={"feedback": "", "status": "ready"},  # ← feedback in metadata
         )
 
-    # ── step ──────────────────────────────────────────────────────────────────
-
     def step(self, action: SQLDebugAction) -> SQLDebugObservation:
-        """
-        Agent submits a query.
-        We run it, grade it, and return the new observation + reward.
-        """
+        # Auto-reset if not already initialized (handles session management issues)
         if self._current_task is None:
-            error_message="Call reset() before step()",
-            available_tasks=list(self._all_tasks.keys()),
-            done=True,
-            reward=0.0,
-            )
+            self.reset()
 
         self._state.step_count += 1
         self._current_query = action.query
 
-        # Run the query in SQLite
         run_result = run_query(
             self._current_task["schema_sql"],
             action.query,
         )
 
-        self._best_reward = max(self._best_reward,
+        task_id = self._current_task["task_id"]
+        grader_fn = TASK_GRADERS.get(task_id, grade_easy)
+
+        reward_dict = grader_fn(
+            task=self._current_task,
+            agent_query=action.query,
+            run_result=run_result,
+            prev_absolute_score=self._prev_absolute_score,
+            step_count=self._state.step_count,
+            max_steps=self._current_task.get("max_steps", 8),
+        )
+
+        self._prev_absolute_score = reward_dict["absolute_score"]
+        self._best_reward = max(self._best_reward, reward_dict["absolute_score"])
 
-        # Episode ends on perfect score or max steps
         max_steps = self._current_task.get("max_steps", 8)
-        done = (
+        done = (
+            reward_dict["absolute_score"] >= 0.99 or self._state.step_count >= max_steps
+        )
 
         return SQLDebugObservation(
-            task_id=
+            task_id=task_id,
             schema_sql=self._current_task["schema_sql"],
             current_query=action.query,
             error_message=run_result["error"] or "",
@@ -150,11 +144,17 @@ class SQLDebugEnvironment(Environment):
             reward_so_far=self._best_reward,
             available_tasks=list(self._all_tasks.keys()),
             done=done,
-            reward=
+            reward=reward_dict["value"],
+            metadata={  # ← all extra data here
+                "feedback": reward_dict["feedback"],
+                "status": reward_dict["status"],
+                "absolute_score": reward_dict["absolute_score"],
+                "delta": reward_dict["delta"],
+                "result_score": reward_dict["result_score"],
+                "plan_score": reward_dict["plan_score"],
+            },
         )
 
-    # ── state ─────────────────────────────────────────────────────────────────
-
     @property
     def state(self) -> State:
         return self._state
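For completeness, the reworked environment can also be poked at in-process, without the FastAPI layer; this is a sketch, not part of the commit, and it assumes you run it from the repo root so models, runner, tasks/ and graders/ resolve on the import path.

# local_env_check.py - drives SQLDebugEnvironment directly, no HTTP layer
from server.sql_debug_environment import SQLDebugEnvironment
from models import SQLDebugAction

env = SQLDebugEnvironment()
obs = env.reset(task_id="logic_fix_001")
print(obs.task_id, "|", obs.error_message or "broken query ran without error")

# One attempt; the grader's breakdown now travels in observation.metadata
obs = env.step(SQLDebugAction(query="SELECT 1"))
print(obs.reward, obs.metadata["status"], obs.metadata["feedback"])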
tasks/task_hard.py
CHANGED

@@ -4,12 +4,12 @@ import random
 def generate_schema(n_rows=5000, seed=42):
     """Generates schema + INSERT statements for n_rows transactions."""
     rng = random.Random(seed)
-    statuses = ['completed', 'pending', 'failed']
+    # statuses = ['completed', 'pending', 'failed']
     inserts = []
     for i in range(1, n_rows + 1):
         user_id = rng.randint(1, 100)
         amount = round(rng.uniform(10, 1000), 2)
-        status = rng.choice(statuses)
+        # status = rng.choice(statuses)
         inserts.append(f"INSERT INTO transactions VALUES ({i}, {user_id}, {amount}, 'completed');")
     return (
         "CREATE TABLE transactions (id INTEGER, user_id INTEGER, amount REAL, ts TEXT, status TEXT);\n"
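One consequence of commenting out the status sampling is that every generated row now carries status 'completed'. A quick peek at the string generate_schema returns (values depend on the seed, so only the shape is shown here; sketch, not part of the commit).

# schema_preview.py - prints the head of the SQL generated by task_hard
from tasks.task_hard import generate_schema

sql = generate_schema(n_rows=3, seed=42)
print(sql[:200])
# First line: the CREATE TABLE transactions (...) statement.
# Following lines: INSERT INTO transactions VALUES (<id>, <user_id>, <amount>, 'completed');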
uv.lock
CHANGED

@@ -1603,6 +1603,7 @@ name = "openenv-sql-debug"
 version = "0.1.0"
 source = { editable = "." }
 dependencies = [
+    { name = "httpx" },
     { name = "openai" },
     { name = "openenv-core", extra = ["core"] },
     { name = "uvicorn" },
@@ -1616,6 +1617,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
+    { name = "httpx", specifier = ">=0.28.1" },
     { name = "openai", specifier = ">=2.30.0" },
     { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },