# nl2sql-bench / models.py
# ritvik360's picture
# Upload folder using huggingface_hub
# a39d8ef verified
"""
nl2sql-bench/models.py
======================
Typed contracts for the NL2SQL-Bench OpenEnv environment.
Action : The SQL query the agent submits.
Observation : What the agent sees after each step.
State : Episode-level metadata (for state() endpoint).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from openenv.core.env_server import Action, Observation, State
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class NL2SQLAction(Action):
    """Action payload carrying one SQL query submitted by the agent."""

    # SQL text of the submission; empty string when nothing is supplied.
    query: str = ""
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class NL2SQLObservation(Observation):
    """
    Observation handed back to the agent so it can reason about and refine
    its SQL query.

    Each field is described by the comment directly above its declaration.
    """

    # NOTE(review): the list defaults below rely on the base class machinery
    # (dataclass or model metaclass) interpreting ``field(default_factory=...)``
    # -- confirm against the installed openenv version.

    # --- Task description -------------------------------------------------
    # The natural-language question to answer.
    question: str = ""
    # Relevant table/column descriptions as a string block.
    schema_context: str = ""
    # Identifier of the current task (easy / medium / hard).
    task_name: str = ""

    # --- Feedback on the previous submission ------------------------------
    # The SQL the agent submitted on the last step (empty on reset).
    last_query: str = ""
    # Up to 10 rows returned by the last query, each row a dict.
    last_result: List[Dict[str, Any]] = field(default_factory=list)
    # SQLite error string if the query failed, else None.
    last_error: Optional[str] = None
    # Column names of the rows in ``last_result``.
    result_columns: List[str] = field(default_factory=list)

    # --- Episode progress -------------------------------------------------
    # Current step number (documented as 1-indexed; the default is 0).
    step: int = 0
    # Maximum steps allowed per episode.
    max_steps: int = 5
    # True when the episode is over (success or steps exhausted).
    done: bool = False
    # Reward for the most recent action (None on reset).
    reward: Optional[float] = None
    # Normalised cumulative score so far, in [0.0, 1.0].
    score: float = 0.0
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class NL2SQLState(State):
    """Episode-level metadata, as returned by the /state endpoint."""

    # Episode identity and progress counter.
    episode_id: Optional[str] = None
    step_count: int = 0

    # Task currently being attempted.
    task_name: str = ""
    task_difficulty: str = ""  # one of: easy | medium | hard
    question: str = ""

    # Reward bookkeeping for the episode.
    best_reward: float = 0.0  # highest reward seen this episode
    cumulative_reward: float = 0.0
    solved: bool = False  # True if exact match was achieved