Spaces:
Sleeping
Sleeping
File size: 7,604 Bytes
a8a3c90 3867c62 a8a3c90 3867c62 a8a3c90 3867c62 a8a3c90 3867c62 a8a3c90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | """
QueryForge SQL Environment — server-side implementation.
The agent interacts with a SQL debugging and optimisation challenge:
reset() → next task in round-robin rotation
reset(task_id="x") → pin to a specific task by ID (built-in or custom)
step() → grade the submitted query, return scored observation
state → episode_id + step count
Reward scale:
0.00 syntax error
0.15 syntax valid, runtime error
0.30 executes, wrong / empty results
0.30–0.80 partial row correctness (deterministic, DuckDB)
0.80–1.00 correct results + AI quality assessment (Anthropic)
Episode ends when:
- score >= 0.90 (correct + high-quality solution); the threshold drops to 0.80 when ANTHROPIC_API_KEY is unset
- best_score has not improved for 2 consecutive steps (early stopping)
- max_steps for the task is exhausted
"""
import logging
import os
from typing import Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import SQLAction, SQLObservation
from ..tasks import REGISTRY, SQLTask
from ..judge import grade
except ImportError:
from models import SQLAction, SQLObservation
from tasks import REGISTRY, SQLTask
from judge import grade
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# True when ANTHROPIC_API_KEY is set to a non-empty value in the process
# environment; evaluated once, at import time.
_AI_JUDGE_ACTIVE = bool(os.getenv("ANTHROPIC_API_KEY"))

# One-off startup message reporting the judge mode and the matching
# done-threshold label.
logger.info(
    "QueryForge environment loaded | AI judge: %s | done_threshold: %s",
    *(
        ("ACTIVE (scores up to 1.0)", "0.90")
        if _AI_JUDGE_ACTIVE
        else ("OFFLINE β deterministic only (max score 0.80)", "0.80")
    ),
)
class QueryforgeEnvironment(Environment):
    """
    SQL Query Debugger & Optimiser environment.

    Built-in tasks (cycled in order by default):
        1. easy   - fix three misspelled SQL keywords
        2. medium - fix a missing JOIN condition causing a cartesian product
        3. hard   - rewrite a correlated subquery as a CTE

    Custom tasks can be registered at runtime via POST /tasks and then
    requested by passing task_id to reset():

        env.reset(task_id="my_custom_task")

    Each episode ends when:
        - The agent achieves score >= DONE_THRESHOLD (0.90 with the AI judge
          active, 0.80 when ANTHROPIC_API_KEY is unset), or
        - best_score has not improved for EARLY_STOP_STEPS consecutive steps
          (early stopping), or
        - The maximum steps for the current task is exhausted.

    Supports concurrent WebSocket sessions (each client gets its own instance).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode ends when score >= this threshold.
    # Falls back to 0.80 when ANTHROPIC_API_KEY is unset (AI judge offline,
    # deterministic scoring caps at 0.80). Derived from the module-level flag
    # so the threshold and the logged judge status share one source of truth.
    DONE_THRESHOLD: float = 0.90 if _AI_JUDGE_ACTIVE else 0.80

    # Episode ends when best_score has not improved for this many consecutive steps
    EARLY_STOP_STEPS: int = 2

    def __init__(self) -> None:
        """Create an idle environment; no task is active until reset()."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Optional[SQLTask] = None
        self._best_score: float = 0.0
        self._attempt: int = 0
        self._stale_steps: int = 0  # consecutive steps with no best_score improvement

    # -- OpenEnv interface ---------------------------------------------------

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> SQLObservation:
        """
        Start a new episode.

        Args:
            task_id: Pin to a specific task by ID. If None, the registry
                cycles round-robin through all registered tasks.
            seed: Ignored (reserved for future use).
            episode_id: Optional custom episode identifier; a fresh UUID is
                generated when omitted.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation for the selected task, or an error
            observation (done=True, reward=0.0) listing the available task
            IDs when task_id is not found in the registry.
        """
        ep_id = episode_id or str(uuid4())
        self._state = State(episode_id=ep_id, step_count=0)
        self._best_score = 0.0
        self._attempt = 0
        self._stale_steps = 0

        logger.info(
            "reset() | task_id=%s | AI judge: %s",
            task_id or "round-robin",
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )

        if task_id is None:
            self._current_task = REGISTRY.cycle_next()
        else:
            # Drop the previous episode's task before the lookup: if the ID is
            # unknown, a later step() must not grade against a stale task.
            self._current_task = None
            try:
                self._current_task = REGISTRY.get(task_id)
            except KeyError as exc:
                # Unknown task_id: return an error observation so the caller
                # gets clear feedback instead of a silent 500.
                logger.warning("reset() | unknown task_id=%s", task_id)
                return SQLObservation(
                    feedback=str(exc),
                    hint=f"Available task IDs: {', '.join(REGISTRY.ids())}",
                    done=True,
                    reward=0.0,
                )

        task = self._current_task
        return SQLObservation(
            task_id=task.id,
            task_level=task.level,
            task_title=task.title,
            task_description=task.description,
            syntax_valid=False,
            execution_success=False,
            execution_error=None,
            rows_returned=0,
            feedback="New task loaded. Submit your fixed/optimised SQL query.",
            hint=task.hint,
            attempt=0,
            best_score=0.0,
            done=False,
            reward=0.0,
        )

    def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
        """
        Grade the submitted SQL query and return a scored observation.

        Args:
            action: The agent's action; its ``sql`` attribute is passed to
                the grader together with the current task.

        Returns:
            An observation carrying the grader's feedback, the score as
            ``reward``, the running ``best_score`` and the ``done`` flag.
            When no task is active, an error observation (done=True,
            reward=0.0) is returned and no counters are advanced.
        """
        task = self._current_task
        if task is None:
            # Reject before touching the counters so steps taken without an
            # active episode do not inflate step_count / attempt.
            return SQLObservation(
                feedback="No task active. Call reset() first.",
                hint="Call reset() to start a new episode.",
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1
        self._attempt += 1

        logger.info(
            "step() | task=%s | attempt=%d | AI judge: %s",
            task.id,
            self._attempt,
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )

        score, feedback, details = grade(task, action.sql)

        # Early stopping: count consecutive steps that fail to beat best_score.
        if score > self._best_score:
            self._stale_steps = 0
            self._best_score = score
        else:
            self._stale_steps += 1

        solved = score >= self.DONE_THRESHOLD
        done = (
            solved
            or self._stale_steps >= self.EARLY_STOP_STEPS
            or self._state.step_count >= task.max_steps
        )

        return SQLObservation(
            task_id=task.id,
            task_level=task.level,
            task_title=task.title,
            task_description=task.description,
            syntax_valid=bool(details.get("syntax_valid", False)),
            execution_success=bool(details.get("execution_success", False)),
            execution_error=details.get("execution_error"),
            # `or 0` also covers a key that is present but set to None.
            rows_returned=int(details.get("rows_returned") or 0),
            feedback=feedback,
            # Withhold the hint once the task counts as solved; tied to
            # DONE_THRESHOLD so the offline (0.80) mode behaves consistently.
            hint="" if solved else task.hint,
            attempt=self._attempt,
            best_score=self._best_score,
            done=done,
            reward=score,
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode_id and step count)."""
        return self._state
|