File size: 7,604 Bytes
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3867c62
 
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3867c62
 
 
 
 
 
 
 
 
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3867c62
 
 
 
 
 
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3867c62
 
 
 
 
 
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
QueryForge SQL Environment β€” server-side implementation.

The agent interacts with a SQL debugging and optimisation challenge:
  reset()              β†’ next task in round-robin rotation
  reset(task_id="x")  β†’ pin to a specific task by ID (built-in or custom)
  step()               β†’ grade the submitted query, return scored observation
  state                β†’ episode_id + step count

Reward scale:
  0.00        syntax error
  0.15        syntax valid, runtime error
  0.30        executes, wrong / empty results
  0.30–0.80   partial row correctness (deterministic, DuckDB)
  0.80–1.00   correct results + AI quality assessment (Anthropic)

Episode ends when:
  - score >= 0.90  (correct + high-quality solution)
  - best_score has not improved for 2 consecutive steps (early stopping)
  - max_steps for the task is exhausted
"""

import logging
import os
from typing import Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import SQLAction, SQLObservation
    from ..tasks import REGISTRY, SQLTask
    from ..judge import grade
except ImportError:
    from models import SQLAction, SQLObservation
    from tasks import REGISTRY, SQLTask
    from judge import grade

logger = logging.getLogger(__name__)
_AI_JUDGE_ACTIVE = bool(os.environ.get("ANTHROPIC_API_KEY"))

logger.info(
    "QueryForge environment loaded | AI judge: %s | done_threshold: %s",
    "ACTIVE (scores up to 1.0)" if _AI_JUDGE_ACTIVE else "OFFLINE β€” deterministic only (max score 0.80)",
    "0.90" if _AI_JUDGE_ACTIVE else "0.80",
)


class QueryforgeEnvironment(Environment):
    """
    SQL Query Debugger & Optimiser environment.

    Built-in tasks (cycled in order by default):
      1. easy   -- fix three misspelled SQL keywords
      2. medium -- fix a missing JOIN condition causing a cartesian product
      3. hard   -- rewrite a correlated subquery as a CTE

    Custom tasks can be registered at runtime via POST /tasks and then
    requested by passing task_id to reset():
      env.reset(task_id="my_custom_task")

    Each episode ends when:
      - The agent achieves score >= DONE_THRESHOLD (correct + high-quality), or
      - best_score has not improved for EARLY_STOP_STEPS consecutive steps
        (early stopping), or
      - The maximum steps for the current task is exhausted.

    Supports concurrent WebSocket sessions (each client gets its own instance).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode ends when a step's score reaches this threshold.
    # 0.90 when the Anthropic AI judge is available; otherwise 0.80, because
    # deterministic scoring caps at 0.80 and a 0.90 threshold would be
    # unreachable.  (`os` is already imported at module level; the previous
    # `__import__("os")` indirection was unnecessary.)
    DONE_THRESHOLD: float = 0.90 if os.environ.get("ANTHROPIC_API_KEY") else 0.80
    # Episode also ends when best_score has not improved for this many
    # consecutive steps.
    EARLY_STOP_STEPS: int = 2

    def __init__(self) -> None:
        # Per-session episode bookkeeping; all of it is re-initialised by reset().
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Optional[SQLTask] = None
        self._best_score: float = 0.0  # highest score seen this episode
        self._attempt: int = 0         # number of step() calls this episode
        self._stale_steps: int = 0     # consecutive steps with no best_score improvement

    # -- OpenEnv interface ----------------------------------------------------

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> SQLObservation:
        """
        Start a new episode.

        Args:
            task_id:    Pin to a specific task by ID.  If None, the registry
                        cycles round-robin through all registered tasks.
            seed:       Ignored (reserved for future use).
            episode_id: Optional custom episode identifier.

        Returns:
            An SQLObservation describing the loaded task, or an error
            observation (done=True, reward=0.0) if task_id is unknown.
        """
        ep_id = episode_id or str(uuid4())
        self._state = State(episode_id=ep_id, step_count=0)
        self._best_score = 0.0
        self._attempt = 0
        self._stale_steps = 0
        # Clear any task left over from a previous episode *before* the lookup
        # below: if the lookup fails, step() must not silently grade against a
        # stale task from the prior episode.
        self._current_task = None

        logger.info(
            "reset() | task_id=%s | AI judge: %s",
            task_id or "round-robin",
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )

        if task_id is not None:
            try:
                self._current_task = REGISTRY.get(task_id)
            except KeyError as exc:
                # Unknown task_id -- return an error observation so the caller
                # gets clear feedback instead of a silent 500.
                return SQLObservation(
                    feedback=str(exc),
                    hint=f"Available task IDs: {', '.join(REGISTRY.ids())}",
                    done=True,
                    reward=0.0,
                )
        else:
            self._current_task = REGISTRY.cycle_next()

        return SQLObservation(
            task_id=self._current_task.id,
            task_level=self._current_task.level,
            task_title=self._current_task.title,
            task_description=self._current_task.description,
            syntax_valid=False,
            execution_success=False,
            execution_error=None,
            rows_returned=0,
            feedback="New task loaded. Submit your fixed/optimised SQL query.",
            hint=self._current_task.hint,
            attempt=0,
            best_score=0.0,
            done=False,
            reward=0.0,
        )

    def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
        """
        Grade the submitted SQL query and return a scored observation.

        Args:
            action: The agent's action carrying the submitted SQL in `action.sql`.

        Returns:
            An SQLObservation with the graded reward, execution details, and
            done flag.  If no task is active (reset() was never called, or the
            last reset() failed to resolve a task), returns an error
            observation with done=True.
        """
        self._state.step_count += 1
        self._attempt += 1

        if self._current_task is None:
            return SQLObservation(
                feedback="No task active. Call reset() first.",
                hint="Call reset() to start a new episode.",
                done=True,
                reward=0.0,
            )

        logger.info(
            "step() | task=%s | attempt=%d | AI judge: %s",
            self._current_task.id,
            self._attempt,
            "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
        )
        score, feedback, details = grade(self._current_task, action.sql)

        # Early stopping: count consecutive steps with no best_score improvement.
        if score > self._best_score:
            self._stale_steps = 0
        else:
            self._stale_steps += 1
        self._best_score = max(self._best_score, score)

        # Episode terminates when solved, stagnated, or out of steps.
        done = (
            score >= self.DONE_THRESHOLD
            or self._stale_steps >= self.EARLY_STOP_STEPS
            or self._state.step_count >= self._current_task.max_steps
        )

        return SQLObservation(
            task_id=self._current_task.id,
            task_level=self._current_task.level,
            task_title=self._current_task.title,
            task_description=self._current_task.description,
            syntax_valid=bool(details.get("syntax_valid", False)),
            execution_success=bool(details.get("execution_success", False)),
            execution_error=details.get("execution_error"),
            rows_returned=int(details.get("rows_returned", 0)),
            feedback=feedback,
            # Suppress the hint once the score clears the done threshold.
            # (Previously a hardcoded 0.9, which disagreed with the 0.80
            # threshold used when the AI judge is offline.)
            hint="" if score >= self.DONE_THRESHOLD else self._current_task.hint,
            attempt=self._attempt,
            best_score=self._best_score,
            done=done,
            reward=score,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode_id + step count)."""
        return self._state