""" nl2sql-bench/models.py ====================== Typed contracts for the NL2SQL-Bench OpenEnv environment. Action : The SQL query the agent submits. Observation : What the agent sees after each step. State : Episode-level metadata (for state() endpoint). """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from openenv.core.env_server import Action, Observation, State # --------------------------------------------------------------------------- # Action # --------------------------------------------------------------------------- class NL2SQLAction(Action): """A single SQL query submitted by the agent.""" query: str = "" # --------------------------------------------------------------------------- # Observation # --------------------------------------------------------------------------- class NL2SQLObservation(Observation): """ Everything the agent needs to reason about and iterate its SQL query. Fields ------ question : The natural-language question to answer. schema_context : Relevant table/column descriptions as a string block. task_name : Identifier of the current task (easy / medium / hard). last_query : The SQL the agent submitted on the last step (empty on reset). last_result : Up to 10 rows returned by the last query (list of dicts). last_error : SQLite error string if the query failed, else None. result_columns : Column names of last_result rows. step : Current step number (1-indexed). max_steps : Maximum steps allowed per episode. done : True when the episode is over (success or step exhausted). reward : Reward for the most recent action (None on reset). score : Normalised cumulative score so far [0.0, 1.0]. """ question: str = "" schema_context: str = "" task_name: str = "" last_query: str = "" last_result: List[Dict[str, Any]] = field(default_factory=list) last_error: Optional[str] = None result_columns: List[str] = field(default_factory=list) step: int = 0 max_steps: int = 5 done: bool = False reward: Optional[float] = None score: float = 0.0 # --------------------------------------------------------------------------- # State # --------------------------------------------------------------------------- class NL2SQLState(State): """Episode-level state (returned by the /state endpoint).""" episode_id: Optional[str] = None step_count: int = 0 task_name: str = "" task_difficulty: str = "" # easy | medium | hard question: str = "" best_reward: float = 0.0 # highest reward seen this episode cumulative_reward: float = 0.0 solved: bool = False # True if exact match was achieved