abhinavthedev committed on
Commit
5db060f
·
verified ·
1 Parent(s): aa3a171

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -195,7 +195,6 @@ The inference script defaults to `syntax_fix_001`, logs each step, and stops whe
195
  ```text
196
  sql_exp/
197
  β”œβ”€β”€ client.py # OpenEnv client wrapper
198
- β”œβ”€β”€ grader.py # Reward computation
199
  β”œβ”€β”€ inference.py # LLM-driven inference loop
200
  β”œβ”€β”€ models.py # Action and observation models
201
  β”œβ”€β”€ openenv.yaml # OpenEnv manifest
@@ -209,7 +208,10 @@ sql_exp/
209
  β”‚ β”œβ”€β”€ task_easy.py # Syntax-fix task
210
  β”‚ β”œβ”€β”€ task_medium.py # Join logic task
211
  β”‚ └── task_hard.py # Query optimization task
212
- β”œβ”€β”€ test.py # Manual websocket smoke test
 
 
 
213
  └── README.md # Project overview
214
  ```
215
 
 
195
  ```text
196
  sql_exp/
197
  β”œβ”€β”€ client.py # OpenEnv client wrapper
 
198
  β”œβ”€β”€ inference.py # LLM-driven inference loop
199
  β”œβ”€β”€ models.py # Action and observation models
200
  β”œβ”€β”€ openenv.yaml # OpenEnv manifest
 
208
  β”‚ β”œβ”€β”€ task_easy.py # Syntax-fix task
209
  β”‚ β”œβ”€β”€ task_medium.py # Join logic task
210
  β”‚ └── task_hard.py # Query optimization task
211
+ β”œβ”€β”€ graders/
212
+ β”‚ β”œβ”€β”€ grader_easy.py # Syntax-fix task
213
+ β”‚ β”œβ”€β”€ grader_medium.py # Join logic task
214
+ β”‚ └── grader_hard.py # Query optimization task
215
  └── README.md # Project overview
216
  ```
217
 
client.py CHANGED
@@ -1,16 +1,6 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
  # client.py
8
- """
9
- SQL Debug Environment client.
10
- This is what inference.py uses to talk to the running server.
11
- """
12
-
13
- from typing import Dict
14
 
15
  from openenv.core import EnvClient
16
  from openenv.core.client_types import StepResult
@@ -20,35 +10,33 @@ from models import SQLDebugAction, SQLDebugObservation
20
 
21
 
22
  class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
23
- """
24
- Client for the SQL Debug & Optimizer environment.
25
-
26
- Maintains a persistent WebSocket connection to the server.
27
- Each instance gets its own dedicated environment session.
28
 
29
- Usage (direct server):
30
- with SQLDebugEnv(base_url="http://localhost:8000") as env:
31
- result = env.reset()
32
- print(result.observation.target_description)
33
- result = env.step(SQLDebugAction(query="SELECT * FROM orders"))
34
- print(result.reward)
35
 
36
- Usage (Docker):
37
- env = SQLDebugEnv.from_docker_image("sql-debug-env:latest")
38
- try:
39
- result = env.reset()
40
- result = env.step(SQLDebugAction(query="SELECT * FROM orders WHERE amount > 500"))
41
- finally:
42
- env.close()
43
- """
44
 
 
45
  def _step_payload(self, action: SQLDebugAction) -> Dict:
46
- """Convert SQLDebugAction to JSON payload."""
47
  return {"query": action.query}
48
 
 
 
49
  def _parse_result(self, payload: Dict) -> StepResult[SQLDebugObservation]:
50
- """Parse server JSON response into a typed StepResult."""
51
  obs_data = payload.get("observation", {})
 
52
 
53
  observation = SQLDebugObservation(
54
  task_id=obs_data.get("task_id", ""),
@@ -63,6 +51,7 @@ class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
63
  available_tasks=obs_data.get("available_tasks", []),
64
  done=payload.get("done", False),
65
  reward=payload.get("reward", 0.0),
 
66
  )
67
 
68
  return StepResult(
@@ -72,8 +61,7 @@ class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
72
  )
73
 
74
  def _parse_state(self, payload: Dict) -> State:
75
- """Parse server JSON response into a State object."""
76
  return State(
77
  episode_id=payload.get("episode_id"),
78
  step_count=payload.get("step_count", 0),
79
- )
 
 
 
 
 
 
 
1
  # client.py
2
+ from typing import Dict, Optional
3
+ import httpx
 
 
 
 
4
 
5
  from openenv.core import EnvClient
6
  from openenv.core.client_types import StepResult
 
10
 
11
 
12
  class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
13
+ def __init__(self, base_url: str = "http://localhost:8000", **kwargs):
14
+ super().__init__(base_url=base_url, **kwargs)
15
+ self._base_url = base_url.rstrip("/")
 
 
16
 
17
+ # ── Override reset to send task_id in body ────────────────────────────────
18
+ async def reset(self, task_id: Optional[str] = None, **kwargs) -> StepResult:
19
+ payload = {}
20
+ if task_id:
21
+ payload["task_id"] = task_id
 
22
 
23
+ async with httpx.AsyncClient(timeout=30) as http:
24
+ response = await http.post(
25
+ f"{self._base_url}/reset",
26
+ json=payload,
27
+ )
28
+ response.raise_for_status()
29
+ return self._parse_result(response.json())
 
30
 
31
+ # ── step payload ──────────────────────────────────────────────────────────
32
  def _step_payload(self, action: SQLDebugAction) -> Dict:
 
33
  return {"query": action.query}
34
 
35
+ # β€” update _parse_result only
36
+
37
  def _parse_result(self, payload: Dict) -> StepResult[SQLDebugObservation]:
 
38
  obs_data = payload.get("observation", {})
39
+ meta = obs_data.get("metadata", {}) # ← feedback lives here now
40
 
41
  observation = SQLDebugObservation(
42
  task_id=obs_data.get("task_id", ""),
 
51
  available_tasks=obs_data.get("available_tasks", []),
52
  done=payload.get("done", False),
53
  reward=payload.get("reward", 0.0),
54
+ metadata=meta,
55
  )
56
 
57
  return StepResult(
 
61
  )
62
 
63
  def _parse_state(self, payload: Dict) -> State:
 
64
  return State(
65
  episode_id=payload.get("episode_id"),
66
  step_count=payload.get("step_count", 0),
67
+ )
graders/__init__.py ADDED
File without changes
graders/grader_easy.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# tasks/grader_easy.py
"""
Grader for syntax_fix_001 — fix typos in SQL keywords.
Reward is shaped on: syntax correctness + F1 row match + step efficiency.
"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 5,
) -> dict:
    """
    Easy task grader. Pure row-match scoring — no plan check needed.

    Reward components:
        syntax_score : 0.0 or 1.0 — did the query run at all?
        result_score : 0.0–1.0 — F1 of returned vs expected rows
        efficiency_bonus: 0.0–0.05 — small bonus for solving early
        delta : absolute_score - prev_absolute_score

    Args:
        task: task spec; must contain "expected_rows" (list of result rows).
        agent_query: SQL submitted by the agent (unused here; kept so all
            graders share one signature).
        run_result: execution outcome — "error" (str | None) and "rows" (list).
        prev_absolute_score: absolute score of the previous step; the delta
            (RL signal) is computed against it.
        step_count: 1-based step index inside the episode.
        max_steps: episode budget; drives the efficiency bonus.

    Returns:
        dict with keys "value" (clamped delta), "absolute_score", "syntax_ok",
        "result_score", "plan_score", "delta", "status", "feedback", "message".
    """

    syntax_ok = run_result["error"] is None

    # ── Syntax ────────────────────────────────────────────────────────────────
    if not syntax_ok:
        absolute_score = 0.05  # tiny gradient so agent knows to fix syntax first
        # Fix: round to 4 decimals like the success branch below (and like the
        # medium/hard graders) so the logged delta is consistent everywhere.
        delta = round(max(-0.3, min(0.5, absolute_score - prev_absolute_score)), 4)
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    # ── Row matching (F1) ─────────────────────────────────────────────────────
    expected = task["expected_rows"]
    got = run_result["rows"]

    if not got:
        result_score = 0.0
    else:
        # Membership counting: each row on one side scores if it appears at
        # least once on the other side — adequate for the small fixtures used.
        correct_returned = sum(1 for row in got if row in expected)
        correct_expected = sum(1 for row in expected if row in got)

        precision = correct_returned / max(len(got), 1)
        recall = correct_expected / max(len(expected), 1)

        if precision + recall > 0:
            result_score = 2 * precision * recall / (precision + recall)
        else:
            result_score = 0.0

    # ── Efficiency bonus ──────────────────────────────────────────────────────
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if result_score >= 0.99:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score — easy: syntax 15% + correctness 80% + bonus 5% ───────
    absolute_score = round(
        min(0.99, 0.15 * 1.0 + 0.80 * result_score + efficiency_bonus), 4
    )

    # ── Delta reward — the RL signal ──────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02  # stall penalty — discourages repeating same query
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback for agent ────────────────────────────────────────────────────
    issues = []
    if result_score < 0.5:
        issues.append("result_rows: returned rows do not match expected — check your WHERE clause")
    elif result_score < 0.99:
        issues.append(f"result_rows: partial match ({result_score:.0%}) — some rows still wrong")
    if len(got) > len(expected):
        issues.append(f"extra_rows: returned {len(got)} rows but expected {len(expected)}")
    feedback = "; ".join(issues) if issues else "rows match — looking good"

    status = (
        "solved" if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": result_score,
        "plan_score": 0.0,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"result={result_score:.0%}"
        ),
    }
graders/grader_hard.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# tasks/grader_hard.py
"""
Grader for optimize_001 — replace correlated subquery with CTE.

Unlike easy/medium, there are no fixed expected_rows.
Score is entirely driven by query plan quality:
    - uses WITH (CTE)
    - uses GROUP BY
    - uses AVG(
    - does NOT use correlated subquery pattern
    - executes without error
"""

import re


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 10,
) -> dict:
    """Score one attempt at the hard (optimization) task.

    Args:
        task: task spec (unused here — hard grading is pattern-based).
        agent_query: SQL submitted by the agent; inspected for plan patterns.
        run_result: execution outcome — "error" (str | None) and "rows" (list).
        prev_absolute_score: previous step's absolute score, for the delta.
        step_count: 1-based step index inside the episode.
        max_steps: episode budget; drives the efficiency bonus.

    Returns:
        The standard grader dict: "value" is the clamped delta reward,
        "absolute_score" the shaped 0.05–0.99 score for this attempt.
    """

    syntax_ok = run_result["error"] is None

    # ── Syntax ────────────────────────────────────────────────────────────────
    if not syntax_ok:
        absolute_score = 0.05
        delta = round(
            max(-0.3, min(0.5, absolute_score - prev_absolute_score)), 4
        )
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    query_upper = agent_query.upper()

    # ── Plan component scores ─────────────────────────────────────────────────

    # 1. Uses CTE (WITH keyword) — most important signal.
    # Fix: match WITH as a whole word. The previous plain substring test also
    # matched identifiers such as WITHDRAWAL or WITHOUT, granting undeserved
    # CTE credit and suppressing the correlated-subquery penalty below.
    has_cte = re.search(r"\bWITH\b", query_upper) is not None
    cte_score = 1.0 if has_cte else 0.0

    # 2. Uses GROUP BY — required for computing per-user average
    has_group_by = "GROUP BY" in query_upper
    group_score = 1.0 if has_group_by else 0.0

    # 3. Uses AVG — must be aggregating correctly
    has_avg = "AVG(" in query_upper
    avg_score = 1.0 if has_avg else 0.0

    # 4. Correlated subquery penalty — still using the slow pattern
    still_correlated = (
        "SELECT AVG" in query_upper
        and "WHERE" in query_upper
        and not has_cte  # WITH overrides this penalty
    )
    correlation_penalty = 0.4 if still_correlated else 0.0

    # 5. Execution quality — did the query actually return rows?
    rows_returned = len(run_result["rows"])
    execution_score = 1.0 if rows_returned > 0 else 0.3
    # 0.3 credit for running without error even if empty result

    # ── Plan score weighted combination ───────────────────────────────────────
    # CTE 40% + GROUP BY 25% + AVG 20% + execution 15%
    plan_score = round(
        max(
            0.0,
            0.40 * cte_score
            + 0.25 * group_score
            + 0.20 * avg_score
            + 0.15 * execution_score
            - correlation_penalty,
        ),
        4,
    )

    # ── Efficiency bonus ──────────────────────────────────────────────────────
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if plan_score >= 0.85:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score — hard: syntax 10% + plan 85% + bonus 5% ─────────────
    absolute_score = round(
        min(0.99, 0.10 * 1.0 + 0.85 * plan_score + efficiency_bonus), 4
    )
    absolute_score = max(0.05, absolute_score)

    # ── Delta ─────────────────────────────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02  # stall penalty — discourages repeating same query
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback ─────────────────────────────────────────────────────────────
    issues = []
    if not has_cte:
        issues.append("missing_cte: query needs WITH clause to precompute averages")
    if not has_group_by:
        issues.append("missing_group_by: need GROUP BY user_id to compute per-user avg")
    if not has_avg:
        issues.append("missing_avg: need AVG(amount) in the CTE")
    if still_correlated:
        issues.append("still_correlated: subquery in WHERE runs per-row — move to CTE")
    if rows_returned == 0 and syntax_ok:
        issues.append("empty_result: query runs but returns no rows — check JOIN and WHERE")
    feedback = "; ".join(issues) if issues else "plan looks optimized"

    status = (
        "solved" if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": execution_score,
        "plan_score": plan_score,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"plan={plan_score:.0%} | cte={has_cte} | group={has_group_by}"
        ),
    }
graders/grader_medium.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# tasks/grader_medium.py
"""
Grader for logic_fix_001 — fix wrong JOIN type / WHERE logic.

Harder than easy: agent must get BOTH precision and recall right.
Extra penalty for wrong row count (catches SELECT * with no WHERE).
"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 8,
) -> dict:
    """Score one attempt at the medium (join-logic) task.

    Returns the standard grader dict: "value" carries the clamped delta
    reward (the RL signal) and "absolute_score" the shaped score for this
    attempt; "feedback" holds textual hints for the agent's next try.
    """
    exec_error = run_result["error"]

    # A query that fails to parse/run gets a flat floor score.
    if exec_error is not None:
        floor = 0.05
        clamped = round(max(-0.3, min(0.5, floor - prev_absolute_score)), 4)
        return {
            "value": clamped,
            "absolute_score": floor,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": clamped,
            "status": "syntax_error",
            "feedback": f"syntax_error: {exec_error[:100]}",
            "message": f"syntax_error | abs=0.050 | delta={clamped:+.3f}",
        }

    want = task["expected_rows"]
    have = run_result["rows"]

    # F1 over rows: precision (no junk rows) AND recall (no missing rows).
    f1 = 0.0
    if have:
        hits_in_have = sum(1 for r in have if r in want)
        hits_in_want = sum(1 for r in want if r in have)
        precision = hits_in_have / max(len(have), 1)
        recall = hits_in_want / max(len(want), 1)
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)

    # Surplus rows (the classic LEFT JOIN symptom) draw an extra penalty,
    # harder than in the easy task, to push toward precise reasoning.
    surplus = len(have) - len(want)
    over_penalty = min(0.25, surplus * 0.08) if surplus > 0 else 0.0

    # Partial credit for join choice so agents that fix the JOIN but still
    # have minor issues are not stuck at a zero-reward cliff.
    sql_upper = agent_query.upper()
    if "INNER JOIN" in sql_upper:
        join_credit = 0.15  # using INNER JOIN is the right direction
    elif "LEFT JOIN" in sql_upper:
        join_credit = 0.0  # LEFT JOIN is the bug — no credit
    elif "JOIN" in sql_upper:
        join_credit = 0.05  # some join exists — small credit
    else:
        join_credit = 0.0

    # Early-finish bonus, only once the rows are fully correct.
    early_bonus = 0.0
    if f1 >= 0.99:
        early_bonus = round(0.05 * ((max_steps - step_count) / max_steps), 4)

    # Weighted blend — syntax 10% + correctness 70% + join 15% + bonus 5%.
    blended = (
        0.10 * 1.0
        + 0.70 * f1
        + 0.15 * join_credit
        + early_bonus
        - over_penalty
    )
    absolute = max(0.05, round(min(0.99, blended), 4))  # floor at 0.05

    # Delta reward; stalling after the first step costs a little extra.
    step_delta = absolute - prev_absolute_score
    if abs(step_delta) < 0.001 and step_count > 1:
        step_delta -= 0.02
    step_delta = round(max(-0.3, min(0.5, step_delta)), 4)

    # Textual hints for the agent's next attempt.
    notes = []
    if "LEFT JOIN" in sql_upper:
        notes.append("join_type: using LEFT JOIN includes rows with no matching department")
    if len(have) > len(want):
        notes.append(f"extra_rows: got {len(have)} rows, expected {len(want)} — filter too loose")
    if len(have) < len(want) and len(have) > 0:
        notes.append(f"missing_rows: got {len(have)} rows, expected {len(want)} — filter too strict")
    if f1 < 0.5:
        notes.append("result_rows: output does not match expected — check JOIN and WHERE")
    feedback = "; ".join(notes) if notes else "rows and join look correct"

    if absolute >= 0.99:
        status = "solved"
    elif step_delta > 0.01:
        status = "improving"
    elif step_delta < -0.01:
        status = "regression"
    else:
        status = "stalled"

    return {
        "value": step_delta,
        "absolute_score": absolute,
        "syntax_ok": True,
        "result_score": f1,
        "plan_score": join_credit,
        "delta": step_delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute:.3f} | delta={step_delta:+.3f} | "
            f"result={f1:.0%} | join={join_credit:.2f}"
        ),
    }
inference.py CHANGED
@@ -1,50 +1,38 @@
1
  ο»Ώ# inference.py
2
- """
3
- SQL Debug & Optimizer β€” OpenEnv Inference Script
4
-
5
- Mandatory stdout format:
6
- [START] task=<task_name> env=<benchmark> model=<model_name>
7
- [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
8
- [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
9
- """
10
-
11
  import asyncio
12
  import os
13
  import textwrap
14
- from typing import List, Optional
15
 
16
  from openai import OpenAI
17
  from client import SQLDebugEnv, SQLDebugAction
18
 
19
- # ── Mandatory env vars (injected by evaluator on submission) ──────────────────
20
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
21
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
22
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
23
- MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
 
24
 
25
- # ── Task + run config ─────────────────────────────────────────────────────────
26
- TASK_NAME = os.getenv("SQL_ENV_TASK", "syntax_fix_001")
27
- BENCHMARK = "sql-debug-optimizer"
28
- MAX_STEPS = 8 # well under 20 min limit; each step is ~2s
29
- TEMPERATURE = 0.0 # deterministic = reproducible scores
30
- MAX_TOKENS = 400
31
- SUCCESS_THRESHOLD = 0.5 # reward >= 0.5 = success
32
 
33
 
34
- # ── Mandatory stdout loggers β€” DO NOT change field names or order ─────────────
35
 
36
  def log_start(task: str, env: str, model: str) -> None:
37
  print(f"[START] task={task} env={env} model={model}", flush=True)
38
 
39
 
40
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
41
- # action must be single-line β€” newlines break log parsing
42
  action_clean = action.replace("\n", " ").replace("\r", "").strip()
43
- error_val = error if error else "null"
44
- done_val = str(done).lower()
45
  print(
46
  f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
47
- f"done={done_val} error={error_val}",
48
  flush=True,
49
  )
50
 
@@ -58,124 +46,201 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
58
  )
59
 
60
 
61
- # ── Prompt design ─────────────────────────────────────────────────────────────
62
-
63
- SYSTEM_PROMPT = textwrap.dedent("""
64
- You are an expert SQL engineer helping debug and optimize SQL queries.
65
-
66
- Rules (follow exactly):
67
- - Respond with ONLY the corrected SQL query.
68
- - No markdown, no code fences (no ```sql), no explanation.
69
- - No comments inside the SQL.
70
- - If the query has a syntax error, fix it first.
71
- - If the query has a logic bug (wrong JOIN, wrong WHERE), fix the logic.
72
- - If asked to optimize, replace correlated subqueries with CTEs using WITH.
73
- - Output raw SQL only β€” it will be executed directly.
74
- """).strip()
75
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- def build_prompt(obs) -> str:
78
- """Build the user prompt from the current observation."""
79
- result_preview = str(obs.query_result[:3]) if obs.query_result else "empty / error"
80
  return textwrap.dedent(f"""
81
- TASK: {obs.target_description}
 
82
 
83
- DATABASE SCHEMA:
84
- {obs.schema_sql.strip()[:800]}
85
 
86
- CURRENT QUERY (this is broken or slow β€” fix it):
87
  {obs.current_query.strip()}
88
 
89
  ERROR: {obs.error_message or "none"}
90
- CURRENT RESULT (first 3 rows): {result_preview}
91
- STEP: {obs.step_count + 1} of {MAX_STEPS}
 
 
92
 
93
- Write the corrected SQL query:
94
  """).strip()
95
 
96
 
97
- def call_llm(client: OpenAI, obs) -> str:
98
- """Ask the LLM for a better SQL query. Returns clean SQL string."""
 
 
 
 
 
 
 
 
 
 
 
99
  try:
100
  completion = client.chat.completions.create(
101
  model=MODEL_NAME,
102
- messages=[
103
- {"role": "system", "content": SYSTEM_PROMPT},
104
- {"role": "user", "content": build_prompt(obs)},
105
- ],
106
  temperature=TEMPERATURE,
107
  max_tokens=MAX_TOKENS,
108
  stream=False,
109
  )
110
  raw = (completion.choices[0].message.content or "").strip()
111
-
112
- # Strip markdown code fences if model adds them despite instructions
113
  if "```" in raw:
114
- lines = raw.split("\n")
115
  raw = "\n".join(
116
- line for line in lines if not line.strip().startswith("```")
 
117
  ).strip()
118
 
119
- return raw if raw else "SELECT 1"
 
 
 
120
 
121
  except Exception as exc:
122
  print(f"[DEBUG] LLM call failed: {exc}", flush=True)
123
  return "SELECT 1"
124
 
125
 
126
- # ── Main loop ─────────────────────────────────────────────────────────────���───
127
 
128
  async def main() -> None:
129
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
130
 
131
- # Connect to the environment (Docker or local server)
132
- SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
133
- env = SQLDebugEnv(base_url=SERVER_URL)
134
-
135
- rewards: List[float] = []
136
- steps_taken = 0
137
- score = 0.0
138
- success = False
139
 
140
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
141
 
142
  try:
143
- # Reset β€” get the broken query and task info
144
- result = await env.reset(task_id=TASK_NAME)
 
 
 
 
 
145
  obs = result.observation
146
 
 
147
  for step in range(1, MAX_STEPS + 1):
148
  if result.done:
149
  break
150
 
151
- # Ask LLM for a better query
152
- sql_query = call_llm(client, obs)
 
 
 
 
153
 
154
- # Submit to environment
155
  result = await env.step(SQLDebugAction(query=sql_query))
156
- obs = result.observation
157
-
158
- reward = result.reward or 0.0
159
- done = result.done
160
- error = obs.error_message if obs.error_message else None
161
-
162
- rewards.append(reward)
 
 
 
 
 
 
 
 
 
 
 
163
  steps_taken = step
164
 
165
- log_step(
166
- step=step,
167
- action=sql_query,
168
- reward=reward,
169
- done=done,
170
- error=error,
171
- )
172
 
173
  if done:
174
  break
175
 
176
- # Score = best reward achieved (already 0.0–1.0 from grader)
177
- score = max(rewards) if rewards else 0.0
178
- score = min(max(score, 0.0), 1.0)
179
  success = score >= SUCCESS_THRESHOLD
180
 
181
  except Exception as exc:
@@ -187,7 +252,7 @@ async def main() -> None:
187
  except Exception as e:
188
  print(f"[DEBUG] env.close() error: {e}", flush=True)
189
 
190
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
191
 
192
 
193
  if __name__ == "__main__":
 
1
  ο»Ώ# inference.py
 
 
 
 
 
 
 
 
 
2
  import asyncio
3
  import os
4
  import textwrap
5
+ from typing import List, Optional, Dict
6
 
7
  from openai import OpenAI
8
  from client import SQLDebugEnv, SQLDebugAction
9
 
10
+ # ── Env vars ──────────────────────────────────────────────────────────────────
11
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
12
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
14
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
15
+ SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
16
 
17
+ TASK_NAME = os.getenv("SQL_ENV_TASK", "syntax_fix_001")
18
+ BENCHMARK = "sql-debug-optimizer"
19
+ MAX_STEPS = 8
20
+ TEMPERATURE = 0.3
21
+ MAX_TOKENS = 400
22
+ SUCCESS_THRESHOLD = 0.5
 
23
 
24
 
25
+ # ── Stdout loggers ────────────────────────────────────────────────────────────
26
 
27
  def log_start(task: str, env: str, model: str) -> None:
28
  print(f"[START] task={task} env={env} model={model}", flush=True)
29
 
30
 
31
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
 
32
  action_clean = action.replace("\n", " ").replace("\r", "").strip()
 
 
33
  print(
34
  f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
35
+ f"done={str(done).lower()} error={error or 'null'}",
36
  flush=True,
37
  )
38
 
 
46
  )
47
 
48
 
49
+ # ── Prompts ───────────────────────────────────────────────────────────────────
50
+
51
+ SYSTEM_PROMPT = """You are an expert SQL engineer fixing and optimizing SQL queries.
52
+
53
+ STRICT OUTPUT RULES:
54
+ - Output ONLY raw SQL. No markdown. No backticks. No explanation. No comments.
55
+ - Your output is executed directly against a SQLite database.
56
+ - If your previous attempt got negative reward, you made things worse β€” try differently.
57
+ - If reward is stalled (same score 2+ steps), change strategy significantly."""
58
+
59
+ TASK_CONTEXT = {
60
+ "syntax_fix_001": "The query has typographical errors in SQL keywords.",
61
+ "logic_fix_001": "The query runs but returns incorrect rows due to a logic error.",
62
+ "optimize_001": "The query is correct but slow. Rewrite it to be faster.",
63
+ }
64
+
65
+ GRADUATED_HINTS = {
66
+ "syntax_fix_001": [
67
+ "",
68
+ "Check the spelling of SQL keywords like SELECT, FROM, WHERE.",
69
+ "Compare each word: SELECT FROM WHERE ORDER BY GROUP BY β€” fix any typos.",
70
+ "The typos are: SELEC β†’ SELECT, FORM β†’ FROM, WERE β†’ WHERE.",
71
+ ],
72
+ "logic_fix_001": [
73
+ "",
74
+ "The query returns more rows than expected. Check your JOIN type.",
75
+ "LEFT JOIN includes rows even when no match exists. Consider INNER JOIN.",
76
+ "Change LEFT JOIN to INNER JOIN to exclude employees with no matching department.",
77
+ ],
78
+ "optimize_001": [
79
+ "",
80
+ "The query uses a subquery that runs once per row β€” this is slow.",
81
+ "Compute the per-user average once using GROUP BY, then JOIN the result.",
82
+ "Use: WITH user_avg AS (SELECT user_id, AVG(amount) AS avg FROM transactions GROUP BY user_id) SELECT t.* FROM transactions t JOIN user_avg u ON t.user_id = u.user_id WHERE t.amount > u.avg AND t.status = 'completed'",
83
+ ],
84
+ }
85
+
86
+
87
def get_hint_level(step: int, stall_count: int) -> int:
    """Pick a hint index (0 = no hint … 3 = full answer) from episode progress.

    Hints escalate as the agent burns steps or keeps stalling, so early
    attempts are unassisted and late ones get increasingly explicit help.
    """
    still_early = step <= 2 and stall_count < 2
    if still_early:
        return 0
    midway = step <= 4 and stall_count < 4
    if midway:
        return 1
    return 2 if step <= 6 else 3
95
+
96
+
97
+ def build_prompt(obs, step: int, stall_count: int, prev_delta: float) -> str:
98
+ context = TASK_CONTEXT.get(obs.task_id, "Fix the SQL query.")
99
+ hint_level = get_hint_level(step, stall_count)
100
+ hint = GRADUATED_HINTS.get(obs.task_id, [""] * 4)[hint_level]
101
+ result_preview = str(obs.query_result[:3]) if obs.query_result else "none"
102
+
103
+ # ← read feedback from metadata dict, not obs.feedback
104
+ meta = obs.metadata or {}
105
+ feedback = meta.get("feedback", "analyse the result yourself")
106
+
107
+ reward_context = ""
108
+ if step > 1:
109
+ if prev_delta > 0.01:
110
+ reward_context = f"Last change IMPROVED score (+{prev_delta:.2f}). Keep going."
111
+ elif prev_delta < -0.01:
112
+ reward_context = f"Last change WORSENED score ({prev_delta:.2f}). Revert and try differently."
113
+ else:
114
+ reward_context = f"Last change had NO EFFECT (delta={prev_delta:.2f}). Try a completely different approach."
115
+
116
+ hint_block = f"\nHINT: {hint}" if hint else ""
117
 
 
 
 
118
  return textwrap.dedent(f"""
119
+ TASK: {context}
120
+ {reward_context}{hint_block}
121
 
122
+ SCHEMA:
123
+ {obs.schema_sql.strip()[:600]}
124
 
125
+ CURRENT QUERY:
126
  {obs.current_query.strip()}
127
 
128
  ERROR: {obs.error_message or "none"}
129
+ RESULT (first 3 rows): {result_preview}
130
+ FEEDBACK: {feedback}
131
+ BEST SCORE SO FAR: {obs.reward_so_far:.3f}
132
+ STEP: {step} of {MAX_STEPS}
133
 
134
+ Write the corrected SQL:
135
  """).strip()
136
 
137
 
138
+ def call_llm(
139
+ client: OpenAI,
140
+ obs,
141
+ history: List[Dict],
142
+ step: int,
143
+ stall_count: int,
144
+ prev_delta: float,
145
+ ) -> str:
146
+ user_content = build_prompt(obs, step, stall_count, prev_delta)
147
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
148
+ messages.extend(history[-6:])
149
+ messages.append({"role": "user", "content": user_content})
150
+
151
  try:
152
  completion = client.chat.completions.create(
153
  model=MODEL_NAME,
154
+ messages=messages,
 
 
 
155
  temperature=TEMPERATURE,
156
  max_tokens=MAX_TOKENS,
157
  stream=False,
158
  )
159
  raw = (completion.choices[0].message.content or "").strip()
 
 
160
  if "```" in raw:
 
161
  raw = "\n".join(
162
+ l for l in raw.split("\n")
163
+ if not l.strip().startswith("```")
164
  ).strip()
165
 
166
+ result = raw if raw else "SELECT 1"
167
+ history.append({"role": "user", "content": user_content})
168
+ history.append({"role": "assistant", "content": result})
169
+ return result
170
 
171
  except Exception as exc:
172
  print(f"[DEBUG] LLM call failed: {exc}", flush=True)
173
  return "SELECT 1"
174
 
175
 
176
+ # ── Main ──────────────────────────────────────────────────────────────────────
177
 
178
  async def main() -> None:
179
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
180
+ env = SQLDebugEnv(base_url=SERVER_URL)
181
 
182
+ delta_rewards: List[float] = [] # per-step delta β€” logged in [STEP]
183
+ abs_scores: List[float] = [] # per-step absolute β€” used for final score
184
+ history: List[Dict] = []
185
+ stall_count = 0
186
+ prev_delta = 0.0
187
+ steps_taken = 0
188
+ score = 0.0
189
+ success = False
190
 
191
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
192
 
193
  try:
194
+ # ── Reset ─────────────────────────────────────────────────────────────
195
+ try:
196
+ result = await env.reset(task_id=TASK_NAME)
197
+ except Exception as e:
198
+ print(f"[DEBUG] reset() failed: {e}", flush=True)
199
+ raise
200
+
201
  obs = result.observation
202
 
203
+ # ── Episode loop ──────────────────────────────────────────────────────
204
  for step in range(1, MAX_STEPS + 1):
205
  if result.done:
206
  break
207
 
208
+ sql_query = call_llm(
209
+ client, obs, history,
210
+ step=step,
211
+ stall_count=stall_count,
212
+ prev_delta=prev_delta,
213
+ )
214
 
 
215
  result = await env.step(SQLDebugAction(query=sql_query))
216
+ obs = result.observation
217
+
218
+ # delta reward from grader (can be negative)
219
+ delta = result.reward or 0.0
220
+ # absolute score tracked via reward_so_far on observation
221
+ abs_s = obs.reward_so_far
222
+ done = result.done
223
+ error = obs.error_message if obs.error_message else None
224
+
225
+ # Stall detection β€” reset on any meaningful change
226
+ if abs(delta) < 0.01:
227
+ stall_count += 1
228
+ else:
229
+ stall_count = 0
230
+
231
+ prev_delta = delta
232
+ delta_rewards.append(delta)
233
+ abs_scores.append(abs_s)
234
  steps_taken = step
235
 
236
+ log_step(step=step, action=sql_query, reward=delta, done=done, error=error)
 
 
 
 
 
 
237
 
238
  if done:
239
  break
240
 
241
+ # Final score = best absolute score reached this episode
242
+ score = max(abs_scores) if abs_scores else 0.0
243
+ score = min(max(score, 0.0), 1.0)
244
  success = score >= SUCCESS_THRESHOLD
245
 
246
  except Exception as exc:
 
252
  except Exception as e:
253
  print(f"[DEBUG] env.close() error: {e}", flush=True)
254
 
255
+ log_end(success=success, steps=steps_taken, score=score, rewards=delta_rewards)
256
 
257
 
258
  if __name__ == "__main__":
models.py CHANGED
@@ -1,31 +1,14 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Data models for the SQL Debug & Optimizer Environment.
9
- """
10
-
11
- from typing import Any, Dict, List
12
  from pydantic import Field
13
  from openenv.core.env_server.types import Action, Observation
14
 
15
 
16
  class SQLDebugAction(Action):
17
- """
18
- What the agent submits each step β€” just a SQL query string.
19
- The environment will run it, grade it, and return a new observation.
20
- """
21
  query: str = Field(..., description="The SQL query the agent wants to try")
22
 
23
 
24
  class SQLDebugObservation(Observation):
25
- """
26
- What the agent sees after each step.
27
- Contains everything it needs to improve its next query.
28
- """
29
  task_id: str = Field(default="", description="Which task is active")
30
  schema_sql: str = Field(default="", description="CREATE TABLE statements for this task")
31
  current_query: str = Field(default="", description="Last query that was run")
 
1
+ # models.py
2
+ from typing import Any, Dict, List, Optional
 
 
 
 
 
 
 
 
 
3
  from pydantic import Field
4
  from openenv.core.env_server.types import Action, Observation
5
 
6
 
7
  class SQLDebugAction(Action):
 
 
 
 
8
  query: str = Field(..., description="The SQL query the agent wants to try")
9
 
10
 
11
  class SQLDebugObservation(Observation):
 
 
 
 
12
  task_id: str = Field(default="", description="Which task is active")
13
  schema_sql: str = Field(default="", description="CREATE TABLE statements for this task")
14
  current_query: str = Field(default="", description="Last query that was run")
openenv_sql_debug.egg-info/PKG-INFO CHANGED
@@ -6,6 +6,7 @@ Requires-Python: >=3.10
6
  Requires-Dist: openenv-core[core]>=0.2.2
7
  Requires-Dist: openai>=2.30.0
8
  Requires-Dist: uvicorn>=0.43.0
 
9
  Provides-Extra: dev
10
  Requires-Dist: pytest>=8.0.0; extra == "dev"
11
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
 
6
  Requires-Dist: openenv-core[core]>=0.2.2
7
  Requires-Dist: openai>=2.30.0
8
  Requires-Dist: uvicorn>=0.43.0
9
+ Requires-Dist: httpx>=0.28.1
10
  Provides-Extra: dev
11
  Requires-Dist: pytest>=8.0.0; extra == "dev"
12
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_sql_debug.egg-info/SOURCES.txt CHANGED
@@ -1,8 +1,13 @@
1
  README.md
 
 
 
 
2
  pyproject.toml
 
 
3
  ./__init__.py
4
  ./client.py
5
- ./grader.py
6
  ./inference.py
7
  ./models.py
8
  ./runner.py
 
1
  README.md
2
+ __init__.py
3
+ client.py
4
+ inference.py
5
+ models.py
6
  pyproject.toml
7
+ runner.py
8
+ test.py
9
  ./__init__.py
10
  ./client.py
 
11
  ./inference.py
12
  ./models.py
13
  ./runner.py
openenv_sql_debug.egg-info/requires.txt CHANGED
@@ -1,6 +1,7 @@
1
  openenv-core[core]>=0.2.2
2
  openai>=2.30.0
3
  uvicorn>=0.43.0
 
4
 
5
  [dev]
6
  pytest>=8.0.0
 
1
  openenv-core[core]>=0.2.2
2
  openai>=2.30.0
3
  uvicorn>=0.43.0
4
+ httpx>=0.28.1
5
 
6
  [dev]
7
  pytest>=8.0.0
pyproject.toml CHANGED
@@ -20,6 +20,7 @@ dependencies = [
20
  "openenv-core[core]>=0.2.2",
21
  "openai>=2.30.0",
22
  "uvicorn>=0.43.0",
 
23
  ]
24
 
25
  [project.optional-dependencies]
 
20
  "openenv-core[core]>=0.2.2",
21
  "openai>=2.30.0",
22
  "uvicorn>=0.43.0",
23
+ "httpx>=0.28.1",
24
  ]
25
 
26
  [project.optional-dependencies]
server/app.py CHANGED
@@ -37,7 +37,7 @@ try:
37
  from .sql_debug_environment import SQLDebugEnvironment
38
  except ModuleNotFoundError:
39
  from models import SQLDebugAction, SQLDebugObservation
40
- from sql_exp.server.sql_debug_environment import SQLDebugEnvironment
41
 
42
 
43
  app = create_app(
 
37
  from .sql_debug_environment import SQLDebugEnvironment
38
  except ModuleNotFoundError:
39
  from models import SQLDebugAction, SQLDebugObservation
40
+ from server.sql_debug_environment import SQLDebugEnvironment
41
 
42
 
43
  app = create_app(
server/requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  openenv[core]>=0.2.0
2
  fastapi>=0.115.0
3
- uvicorn>=0.24.0
 
 
1
  openenv[core]>=0.2.0
2
  fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ httpx>=0.28.1
server/sql_debug_environment.py CHANGED
@@ -10,7 +10,6 @@ SQL Debug & Optimizer Environment β€” server-side implementation.
10
  The server runs this. The agent never touches this file directly.
11
  It loads tasks, runs queries in SQLite, grades them, and returns observations.
12
  """
13
-
14
  from uuid import uuid4
15
  from openenv.core.env_server.interfaces import Environment
16
  from openenv.core.env_server.types import State
@@ -21,35 +20,34 @@ except ImportError:
21
  from models import SQLDebugAction, SQLDebugObservation
22
 
23
  from runner import run_query
24
- from grader import compute_reward
 
 
 
 
25
 
26
 
27
  def _load_all_tasks() -> dict:
28
- """Load every task from the tasks/ folder into a dict keyed by task_id."""
29
  from tasks.task_easy import TASK as EASY
30
  from tasks.task_medium import TASK as MEDIUM
31
  from tasks.task_hard import TASK as HARD
 
32
  return {
33
- EASY["task_id"]: EASY,
34
  MEDIUM["task_id"]: MEDIUM,
35
- HARD["task_id"]: HARD,
36
  }
37
 
38
 
39
- class SQLDebugEnvironment(Environment):
40
- """
41
- SQL Debug & Optimizer environment.
42
-
43
- The agent receives a broken or slow SQL query and must fix/optimize it.
44
- Each step the agent submits a new query β€” the environment runs it in
45
- SQLite, grades it (0.0–1.0), and returns the result as an observation.
46
 
47
- Three tasks:
48
- syntax_fix_001 (easy) β€” fix typos in SQL keywords
49
- logic_fix_001 (medium) β€” fix wrong JOIN type causing bad results
50
- # optimize_001 (hard) β€” rewrite correlated subquery as a CTE
51
- """
52
 
 
53
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
54
 
55
  def __init__(self):
@@ -57,32 +55,29 @@ class SQLDebugEnvironment(Environment):
57
  self._current_task = None
58
  self._state = State(episode_id=str(uuid4()), step_count=0)
59
  self._best_reward = 0.0
 
60
  self._current_query = ""
61
 
62
- # ── reset ────────────────────────────────────────────────────────────────
63
 
64
- def reset(self, task_id: str = None) -> SQLDebugObservation:
65
- """
66
- Start a new episode.
67
- Pass task_id to pick a specific task, or leave None for the default (easy).
68
- """
69
  if task_id is None:
70
- task_id = list(self._all_tasks.keys())[0] # default: easy
71
 
72
  if task_id not in self._all_tasks:
73
- # Unknown task β€” return error observation instead of crashing
74
  return SQLDebugObservation(
75
  task_id=task_id,
76
- error_message=f"Unknown task_id '{task_id}'. Available: {list(self._all_tasks.keys())}",
77
  available_tasks=list(self._all_tasks.keys()),
 
78
  )
79
 
80
  self._current_task = self._all_tasks[task_id]
81
  self._state = State(episode_id=str(uuid4()), step_count=0)
82
  self._best_reward = 0.0
 
83
  self._current_query = self._current_task["broken_query"]
84
 
85
- # Run the broken query so the agent sees the starting error
86
  run_result = run_query(
87
  self._current_task["schema_sql"],
88
  self._current_query,
@@ -101,45 +96,44 @@ class SQLDebugEnvironment(Environment):
101
  available_tasks=list(self._all_tasks.keys()),
102
  done=False,
103
  reward=0.0,
 
104
  )
105
 
106
- # ── step ─────────────────────────────────────────────────────────────────
107
-
108
  def step(self, action: SQLDebugAction) -> SQLDebugObservation:
109
- """
110
- Agent submits a query.
111
- We run it, grade it, and return the new observation + reward.
112
- """
113
  if self._current_task is None:
114
- return SQLDebugObservation(
115
- error_message="Call reset() before step()",
116
- available_tasks=list(self._all_tasks.keys()),
117
- done=True,
118
- reward=0.0,
119
- )
120
 
121
  self._state.step_count += 1
122
  self._current_query = action.query
123
 
124
- # Run the query in SQLite
125
  run_result = run_query(
126
  self._current_task["schema_sql"],
127
  action.query,
128
  )
129
 
130
- # Grade it (returns dict with value, syntax_ok, result_match_pct, etc.)
131
- reward_dict = compute_reward(self._current_task, action.query, run_result)
132
- reward_value = reward_dict["value"]
 
 
 
 
 
 
 
 
133
 
134
- # Track the best reward this episode
135
- self._best_reward = max(self._best_reward, reward_value)
136
 
137
- # Episode ends on perfect score or max steps
138
  max_steps = self._current_task.get("max_steps", 8)
139
- done = (reward_value >= 0.99) or (self._state.step_count >= max_steps)
 
 
140
 
141
  return SQLDebugObservation(
142
- task_id=self._current_task["task_id"],
143
  schema_sql=self._current_task["schema_sql"],
144
  current_query=action.query,
145
  error_message=run_result["error"] or "",
@@ -150,11 +144,17 @@ class SQLDebugEnvironment(Environment):
150
  reward_so_far=self._best_reward,
151
  available_tasks=list(self._all_tasks.keys()),
152
  done=done,
153
- reward=reward_value,
 
 
 
 
 
 
 
 
154
  )
155
 
156
- # ── state ─────────────────────────────────────────────────────────────────
157
-
158
  @property
159
  def state(self) -> State:
160
- return self._state
 
10
  The server runs this. The agent never touches this file directly.
11
  It loads tasks, runs queries in SQLite, grades them, and returns observations.
12
  """
 
13
  from uuid import uuid4
14
  from openenv.core.env_server.interfaces import Environment
15
  from openenv.core.env_server.types import State
 
20
  from models import SQLDebugAction, SQLDebugObservation
21
 
22
  from runner import run_query
23
+
24
+ # Import each task's dedicated grader
25
+ from graders.grader_easy import grade as grade_easy
26
+ from graders.grader_medium import grade as grade_medium
27
+ from graders.grader_hard import grade as grade_hard
28
 
29
 
30
  def _load_all_tasks() -> dict:
 
31
  from tasks.task_easy import TASK as EASY
32
  from tasks.task_medium import TASK as MEDIUM
33
  from tasks.task_hard import TASK as HARD
34
+
35
  return {
36
+ EASY["task_id"]: EASY,
37
  MEDIUM["task_id"]: MEDIUM,
38
+ HARD["task_id"]: HARD,
39
  }
40
 
41
 
42
+ # Maps each task_id to its dedicated grader function
43
+ TASK_GRADERS = {
44
+ "syntax_fix_001": grade_easy,
45
+ "logic_fix_001": grade_medium,
46
+ "optimize_001": grade_hard,
47
+ }
 
48
 
 
 
 
 
 
49
 
50
+ class SQLDebugEnvironment(Environment):
51
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
52
 
53
  def __init__(self):
 
55
  self._current_task = None
56
  self._state = State(episode_id=str(uuid4()), step_count=0)
57
  self._best_reward = 0.0
58
+ self._prev_absolute_score = 0.0 # used for delta computation
59
  self._current_query = ""
60
 
61
+ # ── reset ────────────────────────────────────────────────────────────────
62
 
63
+ def reset(self, task_id: str = None, **kwargs) -> SQLDebugObservation:
 
 
 
 
64
  if task_id is None:
65
+ task_id = list(self._all_tasks.keys())[0]
66
 
67
  if task_id not in self._all_tasks:
 
68
  return SQLDebugObservation(
69
  task_id=task_id,
70
+ error_message=f"Unknown task '{task_id}'. Available: {list(self._all_tasks.keys())}",
71
  available_tasks=list(self._all_tasks.keys()),
72
+ metadata={},
73
  )
74
 
75
  self._current_task = self._all_tasks[task_id]
76
  self._state = State(episode_id=str(uuid4()), step_count=0)
77
  self._best_reward = 0.0
78
+ self._prev_absolute_score = 0.0
79
  self._current_query = self._current_task["broken_query"]
80
 
 
81
  run_result = run_query(
82
  self._current_task["schema_sql"],
83
  self._current_query,
 
96
  available_tasks=list(self._all_tasks.keys()),
97
  done=False,
98
  reward=0.0,
99
+ metadata={"feedback": "", "status": "ready"}, # ← feedback in metadata
100
  )
101
 
 
 
102
  def step(self, action: SQLDebugAction) -> SQLDebugObservation:
103
+ # Auto-reset if not already initialized (handles session management issues)
 
 
 
104
  if self._current_task is None:
105
+ self.reset()
 
 
 
 
 
106
 
107
  self._state.step_count += 1
108
  self._current_query = action.query
109
 
 
110
  run_result = run_query(
111
  self._current_task["schema_sql"],
112
  action.query,
113
  )
114
 
115
+ task_id = self._current_task["task_id"]
116
+ grader_fn = TASK_GRADERS.get(task_id, grade_easy)
117
+
118
+ reward_dict = grader_fn(
119
+ task=self._current_task,
120
+ agent_query=action.query,
121
+ run_result=run_result,
122
+ prev_absolute_score=self._prev_absolute_score,
123
+ step_count=self._state.step_count,
124
+ max_steps=self._current_task.get("max_steps", 8),
125
+ )
126
 
127
+ self._prev_absolute_score = reward_dict["absolute_score"]
128
+ self._best_reward = max(self._best_reward, reward_dict["absolute_score"])
129
 
 
130
  max_steps = self._current_task.get("max_steps", 8)
131
+ done = (
132
+ reward_dict["absolute_score"] >= 0.99 or self._state.step_count >= max_steps
133
+ )
134
 
135
  return SQLDebugObservation(
136
+ task_id=task_id,
137
  schema_sql=self._current_task["schema_sql"],
138
  current_query=action.query,
139
  error_message=run_result["error"] or "",
 
144
  reward_so_far=self._best_reward,
145
  available_tasks=list(self._all_tasks.keys()),
146
  done=done,
147
+ reward=reward_dict["value"],
148
+ metadata={ # ← all extra data here
149
+ "feedback": reward_dict["feedback"],
150
+ "status": reward_dict["status"],
151
+ "absolute_score": reward_dict["absolute_score"],
152
+ "delta": reward_dict["delta"],
153
+ "result_score": reward_dict["result_score"],
154
+ "plan_score": reward_dict["plan_score"],
155
+ },
156
  )
157
 
 
 
158
  @property
159
  def state(self) -> State:
160
+ return self._state
tasks/task_hard.py CHANGED
@@ -4,12 +4,12 @@ import random
4
  def generate_schema(n_rows=5000, seed=42):
5
  """Generates schema + INSERT statements for n_rows transactions."""
6
  rng = random.Random(seed)
7
- statuses = ['completed', 'pending', 'failed']
8
  inserts = []
9
  for i in range(1, n_rows + 1):
10
  user_id = rng.randint(1, 100)
11
  amount = round(rng.uniform(10, 1000), 2)
12
- status = rng.choice(statuses)
13
  inserts.append(f"INSERT INTO transactions VALUES ({i}, {user_id}, {amount}, 'completed');")
14
  return (
15
  "CREATE TABLE transactions (id INTEGER, user_id INTEGER, amount REAL, ts TEXT, status TEXT);\n"
 
4
  def generate_schema(n_rows=5000, seed=42):
5
  """Generates schema + INSERT statements for n_rows transactions."""
6
  rng = random.Random(seed)
7
+ # statuses = ['completed', 'pending', 'failed']
8
  inserts = []
9
  for i in range(1, n_rows + 1):
10
  user_id = rng.randint(1, 100)
11
  amount = round(rng.uniform(10, 1000), 2)
12
+ # status = rng.choice(statuses)
13
  inserts.append(f"INSERT INTO transactions VALUES ({i}, {user_id}, {amount}, 'completed');")
14
  return (
15
  "CREATE TABLE transactions (id INTEGER, user_id INTEGER, amount REAL, ts TEXT, status TEXT);\n"
uv.lock CHANGED
@@ -1603,6 +1603,7 @@ name = "openenv-sql-debug"
1603
  version = "0.1.0"
1604
  source = { editable = "." }
1605
  dependencies = [
 
1606
  { name = "openai" },
1607
  { name = "openenv-core", extra = ["core"] },
1608
  { name = "uvicorn" },
@@ -1616,6 +1617,7 @@ dev = [
1616
 
1617
  [package.metadata]
1618
  requires-dist = [
 
1619
  { name = "openai", specifier = ">=2.30.0" },
1620
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
1621
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
 
1603
  version = "0.1.0"
1604
  source = { editable = "." }
1605
  dependencies = [
1606
+ { name = "httpx" },
1607
  { name = "openai" },
1608
  { name = "openenv-core", extra = ["core"] },
1609
  { name = "uvicorn" },
 
1617
 
1618
  [package.metadata]
1619
  requires-dist = [
1620
+ { name = "httpx", specifier = ">=0.28.1" },
1621
  { name = "openai", specifier = ">=2.30.0" },
1622
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
1623
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },