Spaces:
Sleeping
Sleeping
| """ | |
| QueryForge SQL Environment — server-side implementation. | |
| The agent interacts with a SQL debugging and optimisation challenge: | |
| reset() → next task in round-robin rotation | |
| reset(task_id="x") → pin to a specific task by ID (built-in or custom) | |
| step() → grade the submitted query, return scored observation | |
| state → episode_id + step count | |
| Reward scale: | |
| 0.00 syntax error | |
| 0.15 syntax valid, runtime error | |
| 0.30 executes, wrong / empty results | |
| 0.30–0.80 partial row correctness (deterministic, DuckDB) | |
| 0.80–1.00 correct results + AI quality assessment (Anthropic) | |
| Episode ends when: | |
| - score >= 0.90 (correct + high-quality solution); the threshold is 0.80 when ANTHROPIC_API_KEY is unset | |
| - best_score has not improved for 2 consecutive steps (early stopping) | |
| - max_steps for the task is exhausted | |
| """ | |
| import logging | |
| import os | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import SQLAction, SQLObservation | |
| from ..tasks import REGISTRY, SQLTask | |
| from ..judge import grade | |
| except ImportError: | |
| from models import SQLAction, SQLObservation | |
| from tasks import REGISTRY, SQLTask | |
| from judge import grade | |
logger = logging.getLogger(__name__)

# True when ANTHROPIC_API_KEY is set to a non-empty value. The flag is read
# once at import time; changing the environment afterwards has no effect here.
_AI_JUDGE_ACTIVE = bool(os.environ.get("ANTHROPIC_API_KEY"))

# Logged once at import so the chosen judge mode and the matching episode
# completion threshold are visible in the server logs.
logger.info(
    "QueryForge environment loaded | AI judge: %s | done_threshold: %s",
    "ACTIVE (scores up to 1.0)"
    if _AI_JUDGE_ACTIVE
    else "OFFLINE — deterministic only (max score 0.80)",
    "0.90" if _AI_JUDGE_ACTIVE else "0.80",
)
class QueryforgeEnvironment(Environment):
    """
    SQL Query Debugger & Optimiser environment.

    Built-in tasks (cycled in order by default):
        1. easy   — fix three misspelled SQL keywords
        2. medium — fix a missing JOIN condition causing a cartesian product
        3. hard   — rewrite a correlated subquery as a CTE

    Custom tasks can be registered at runtime via POST /tasks and then
    requested by passing task_id to reset():

        env.reset(task_id="my_custom_task")

    Each episode ends when:
        - The agent achieves score ≥ DONE_THRESHOLD (0.90, or 0.80 when
          ANTHROPIC_API_KEY is unset), or
        - best_score has not improved for EARLY_STOP_STEPS consecutive
          steps (early stopping), or
        - The maximum steps for the current task is exhausted.

    Supports concurrent WebSocket sessions (each client gets its own instance).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode ends when score >= this threshold.
    # Falls back to 0.80 when ANTHROPIC_API_KEY is unset (AI judge offline,
    # deterministic scoring caps at 0.80). Reuses the module-level flag so the
    # threshold always agrees with the judge mode reported in the logs.
    DONE_THRESHOLD: float = 0.90 if _AI_JUDGE_ACTIVE else 0.80

    # Episode ends when best_score has not improved for this many consecutive steps
    EARLY_STOP_STEPS: int = 2

    def __init__(self) -> None:
        """Create an environment with a fresh episode id and no task loaded."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Optional[SQLTask] = None
        self._best_score: float = 0.0
        self._attempt: int = 0
        self._stale_steps: int = 0  # consecutive steps with no best_score improvement

    # ── OpenEnv interface ─────────────────────────────────────────────────────

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> SQLObservation:
        """
        Start a new episode.

        Args:
            task_id: Pin to a specific task by ID. If None, the registry
                cycles round-robin through all registered tasks.
            seed: Ignored (reserved for future use).
            episode_id: Optional custom episode identifier; a random UUID is
                generated when omitted or empty.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation for the selected task, or a terminal
            observation (done=True, reward=0.0) listing the available task
            IDs when task_id is not found in the registry.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._best_score = 0.0
        self._attempt = 0
        self._stale_steps = 0

        logger.info(
            "reset() | task_id=%s | AI judge: %s",
            task_id or "round-robin",
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )

        if task_id is None:
            self._current_task = REGISTRY.cycle_next()
        else:
            try:
                self._current_task = REGISTRY.get(task_id)
            except KeyError as exc:
                # Drop any task left over from a previous episode; otherwise
                # a following step() would be graded against the old task
                # even though this reset reported failure.
                self._current_task = None
                # Unknown task_id — return an error observation so the caller
                # gets clear feedback instead of a silent 500.
                return SQLObservation(
                    feedback=str(exc),
                    hint=f"Available task IDs: {', '.join(REGISTRY.ids())}",
                    done=True,
                    reward=0.0,
                )

        task = self._current_task
        return SQLObservation(
            task_id=task.id,
            task_level=task.level,
            task_title=task.title,
            task_description=task.description,
            syntax_valid=False,
            execution_success=False,
            execution_error=None,
            rows_returned=0,
            feedback="New task loaded. Submit your fixed/optimised SQL query.",
            hint=task.hint,
            attempt=0,
            best_score=0.0,
            done=False,
            reward=0.0,
        )

    def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
        """
        Grade the submitted SQL query and return a scored observation.

        Args:
            action: The agent's action; its ``sql`` attribute is passed to the
                grader together with the current task.

        Returns:
            An observation whose ``reward`` is this step's score and whose
            ``best_score`` is the highest score seen in the episode. ``done``
            is True when the score reaches DONE_THRESHOLD, when best_score
            has stalled for EARLY_STOP_STEPS consecutive steps, or when the
            task's max_steps is reached. If no task is active, a terminal
            observation asking the caller to reset() is returned instead.
        """
        task = self._current_task
        if task is None:
            # Checked before touching the counters so that calls which grade
            # nothing are not recorded as steps or attempts.
            return SQLObservation(
                feedback="No task active. Call reset() first.",
                hint="Call reset() to start a new episode.",
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1
        self._attempt += 1

        logger.info(
            "step() | task=%s | attempt=%d | AI judge: %s",
            task.id,
            self._attempt,
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )

        score, feedback, details = grade(task, action.sql)

        # Early stopping bookkeeping: only a strict improvement over the best
        # score so far resets the stale counter.
        if score > self._best_score:
            self._stale_steps = 0
            self._best_score = score
        else:
            self._stale_steps += 1

        solved = score >= self.DONE_THRESHOLD
        done = (
            solved
            or self._stale_steps >= self.EARLY_STOP_STEPS
            or self._state.step_count >= task.max_steps
        )

        return SQLObservation(
            task_id=task.id,
            task_level=task.level,
            task_title=task.title,
            task_description=task.description,
            syntax_valid=bool(details.get("syntax_valid", False)),
            execution_success=bool(details.get("execution_success", False)),
            execution_error=details.get("execution_error"),
            # `or 0` also covers a key that is present but set to None.
            rows_returned=int(details.get("rows_returned") or 0),
            feedback=feedback,
            # The hint is withheld once the submission meets the same
            # threshold that ends the episode, so this stays consistent when
            # the AI judge is offline and the threshold is lower.
            hint="" if solved else task.hint,
            attempt=self._attempt,
            best_score=self._best_score,
            done=done,
            reward=score,
        )

    # NOTE(review): the module docstring lists this as `state` (no call
    # parentheses) and OpenEnv's Environment interface appears to expose
    # `state` as a property — confirm whether a `@property` decorator is
    # missing here. Left as a plain method to keep the visible call interface.
    def state(self) -> State:
        """Return the current episode state (episode_id and step_count)."""
        return self._state