# Interface Sketch: F001 - Core Environment Loop

## Types

```python
# --- models.py changes ---
# Standard-library imports used by this sketch. Action, Observation and Field
# are assumed to come from OpenEnv / pydantic — confirm the exact import paths.
import sqlite3
from dataclasses import dataclass, field


class SQLAction(Action):
    """Structured action from agent to environment."""

    action_type: str = Field(
        ..., description="One of: DESCRIBE, SAMPLE, QUERY, ANSWER"
    )
    argument: str = Field(
        ...,
        description="Table name (DESCRIBE/SAMPLE), SQL string (QUERY), or answer value (ANSWER)",
    )
    # Remove: action_description, tokens (tokens stay if OpenEnv requires them)


class SQLObservation(Observation):
    """Rich observation from environment to agent."""

    # Inherited: done (bool), reward (float | None)
    question: str = Field(..., description="The NL question to answer")
    schema_info: str = Field(..., description="Known schema info (table names initially)")
    result: str = Field(default="", description="Result of last action (truncated)")
    error: str = Field(default="", description="Error message if action failed")
    step_count: int = Field(default=0, description="Current step number")
    budget_remaining: int = Field(default=0, description="Steps left")
    action_history: list[str] = Field(
        default_factory=list, description="Summary of previous actions"
    )


# QuestionRecord must be defined before EpisodeContext, because EpisodeContext's
# field annotation refers to it. Without `from __future__ import annotations`,
# the reverse order raises NameError when the class body is evaluated.
@dataclass
class QuestionRecord:
    """One question from the dataset."""

    question_id: str
    question_text: str
    database_name: str
    gold_sql: str
    # NOTE(review): overlaps EpisodeContext.gold_answer, which reset() recomputes
    # by running gold_sql — decide which one is authoritative.
    gold_answer: str
    answer_type: str  # "integer" | "float" | "string" | "list"
    difficulty: str  # "easy" | "medium" | "hard"
    tables_involved: list[str]


@dataclass
class EpisodeContext:
    """Per-episode server-side state (never sent to agent)."""

    episode_id: str
    db_connection: sqlite3.Connection
    question_record: QuestionRecord
    step_count: int = 0
    budget: int = 15
    described_tables: set[str] = field(default_factory=set)
    action_log: list[str] = field(default_factory=list)
    done: bool = False
    gold_answer: str | None = None  # Computed at reset by running gold_sql
```

## Functions

```python
# --- server/sql_environment.py ---
# TODO: SQLState (the Environment's state type) is not sketched in Types yet.
class SQLEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    def __init__(self, questions_path: str, db_dir: str, tokenizer, step_budget: int = 15):
        """Initialize with path to questions JSON and database directory."""
        ...

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs,
    ) -> SQLObservation:
        """Pick random question, open read-only SQLite, return initial observation."""
        ...

    def step(self, action: SQLAction, *, timeout_s: float = 30, **kwargs) -> SQLObservation:
        """Dispatch to handler, update episode context, return observation."""
        ...

    # --- Action handlers (private) ---

    def _handle_describe(self, table_name: str) -> str:
        """Return column names, types, row count for table. Error if table not found."""
        ...

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Execute SELECT * FROM table LIMIT N, return formatted rows.

        table_name comes from the agent, and identifiers cannot be bound as SQL
        parameters. Check it against the database's known tables and quote it
        as an identifier before building the SQL string.
        """
        ...

    def _handle_query(self, sql: str) -> str:
        """Validate SELECT-only, execute with timeout, truncate to 20 rows."""
        ...

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare to gold answer, return (correct, reward)."""
        ...

    # --- Infrastructure (private) ---

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Sandboxed execution: read-only, timeout, SELECT-only."""
        ...

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open read-only SQLite connection for a Spider database."""
        ...

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load and parse question JSON into QuestionRecord list."""
        ...

    def _build_observation(self) -> SQLObservation:
        """Construct observation from current episode context."""
        ...
```

## Data Flow

```
RESET FLOW
──────────

Client.reset()
    │
    ▼
SQLEnvironment.reset()
    │
    ├── Pick random QuestionRecord
    ├── _open_db(question.database_name) ──→ sqlite3.Connection
    ├── Execute gold_sql to compute gold_answer
    ├── Create EpisodeContext
    └── _build_observation() ──→ SQLObservation
          (question, table names only, budget=15)


STEP FLOW
─────────

Client.step(SQLAction)
    │
    ▼
SQLEnvironment.step(action)
    │
    ├── Validate action_type ∈ {DESCRIBE, SAMPLE, QUERY, ANSWER}
    │
    ├─→ DESCRIBE ──→ _handle_describe(table_name)
    │                  └── _get_table_schema() via sqlite3
    │
    ├─→ SAMPLE ──→ _handle_sample(table_name)
    │                └── _execute_sql("SELECT * ... LIMIT 5")
    │
    ├─→ QUERY ──→ _handle_query(sql)
    │               ├── SELECT-only check
    │               └── _execute_sql(sql, timeout=5s)
    │                     └── Truncate to 20 rows
    │
    ├─→ ANSWER ──→ _handle_answer(value)
    │                ├── Compare to gold_answer
    │                └── done=True, reward=1.0|0.0
    │
    ├── Update EpisodeContext (step_count++, budget--)
    ├── Check budget exhaustion → done=True if budget==0
    └── _build_observation() ──→ SQLObservation


SQLITE SANDBOXING
─────────────────

_execute_sql(sql, timeout_s)
    │
    ├── Check: sql.strip().upper().startswith("SELECT")
    │     └── Reject non-SELECT → error message
    │
    ├── Execute via read-only sqlite3.Connection
    │     └── URI: "file:{path}?mode=ro"  (requires sqlite3.connect(..., uri=True))
    │
    ├── Timeout: Connection.set_progress_handler or thread timeout
    │     └── Kill query after timeout_s → timeout error
    │
    └── Truncate results to max_rows (20)
          └── Append "... (N more rows)" if truncated
```

Note: the prefix check also rejects legitimate read-only queries, such as CTEs (`WITH ... SELECT`) or SQL that begins with a comment. Writes are actually prevented by the read-only (`mode=ro`) connection.

## Open Questions

- Should `_execute_sql` enforce the timeout with a progress handler (`Connection.set_progress_handler`, whose callback aborts the running query by returning a non-zero value)? Or should it use a separate thread that calls `Connection.interrupt()` once `timeout_s` has elapsed? The progress handler is simpler but SQLite-specific.
- Should we keep the `tokens` field in SQLAction/SQLObservation for backward compatibility, or make a clean break? Rich observations may make tokens redundant.
- How should `message_to_action()` be handled? Is it required by OpenEnv's client protocol, or can we remove it?