Spaces:
Sleeping
Sleeping
| """ | |
| nl2sql-bench/models.py | |
| ====================== | |
| Typed contracts for the NL2SQL-Bench OpenEnv environment. | |
| Action : The SQL query the agent submits. | |
| Observation : What the agent sees after each step. | |
| State : Episode-level metadata (for state() endpoint). | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional | |
| from openenv.core.env_server import Action, Observation, State | |
| # --------------------------------------------------------------------------- | |
| # Action | |
| # --------------------------------------------------------------------------- | |
class NL2SQLAction(Action):
    """Action payload: the one SQL query an agent submits."""

    # SQL text of the submission; an empty string when nothing is supplied.
    query: str = ""
| # --------------------------------------------------------------------------- | |
| # Observation | |
| # --------------------------------------------------------------------------- | |
class NL2SQLObservation(Observation):
    """
    Observation returned to the agent so it can reason about its SQL
    query and iterate on it.

    The fields are grouped below as: task description, feedback on the
    last submitted query, episode progress, and scoring. Declaration
    order is kept stable on purpose.
    """

    # ----- Task description -------------------------------------------------
    # The natural-language question to answer.
    question: str = ""
    # Relevant table/column descriptions, delivered as one string block.
    schema_context: str = ""
    # Identifier of the current task (easy / medium / hard).
    task_name: str = ""

    # ----- Feedback on the last submitted query -----------------------------
    # SQL the agent submitted on the last step; empty right after a reset.
    last_query: str = ""
    # Up to 10 rows returned by the last query, each row a dict.
    last_result: List[Dict[str, Any]] = field(default_factory=list)
    # SQLite error string when the query failed, otherwise None.
    last_error: Optional[str] = None
    # Column names of the rows held in last_result.
    result_columns: List[str] = field(default_factory=list)

    # ----- Episode progress -------------------------------------------------
    # Current step number (1-indexed); the declared default is 0.
    step: int = 0
    # Maximum number of steps allowed per episode.
    max_steps: int = 5
    # True once the episode is over (success or steps exhausted).
    done: bool = False

    # ----- Scoring ----------------------------------------------------------
    # Reward for the most recent action; None on reset.
    reward: Optional[float] = None
    # Normalised cumulative score so far, in [0.0, 1.0].
    score: float = 0.0
| # --------------------------------------------------------------------------- | |
| # State | |
| # --------------------------------------------------------------------------- | |
class NL2SQLState(State):
    """
    Episode-level state, as returned by the /state endpoint.

    Fields
    ------
    episode_id        : Episode identifier; None by default.
    step_count        : Step counter; defaults to 0.
    task_name         : Name of the task for this episode.
    task_difficulty   : One of easy | medium | hard.
    question          : Question text for this episode.
    best_reward       : Highest reward seen this episode.
    cumulative_reward : Presumably the running total of rewards this
                        episode -- confirm against the environment code.
    solved            : True if exact match was achieved.
    """

    episode_id: Optional[str] = None
    step_count: int = 0
    task_name: str = ""
    task_difficulty: str = ""
    question: str = ""
    best_reward: float = 0.0
    cumulative_reward: float = 0.0
    solved: bool = False