sql_env / specs /F001-INTERFACE_SKETCH.md
hjerpe's picture
Upload folder using huggingface_hub
5dd1bb4 verified

Interface Sketch: F001 - Core Environment Loop

Types

# --- models.py changes ---

class SQLAction(Action):
    """A single agent command: an action verb plus one string argument.

    How ``argument`` is interpreted depends on ``action_type``: a table name
    for DESCRIBE/SAMPLE, SQL text for QUERY, and the final value for ANSWER.
    ``action_type`` is a plain ``str``; per the data flow, membership in the
    four allowed verbs is checked by the environment's ``step()``.
    """

    action_type: str = Field(
        default=...,
        description="One of: DESCRIBE, SAMPLE, QUERY, ANSWER",
    )
    argument: str = Field(
        default=...,
        description="Table name (DESCRIBE/SAMPLE), SQL string (QUERY), or answer value (ANSWER)",
    )
    # Dropped relative to the previous model: action_description and tokens.
    # (tokens stays only if OpenEnv requires it — see Open Questions.)


class SQLObservation(Observation):
    """Everything the agent sees after ``reset()`` or ``step()``.

    ``done`` (bool) and ``reward`` (float | None) come from the
    ``Observation`` base class. Only ``question`` and ``schema_info`` are
    required; every other field falls back to an empty/zero default.
    """

    # Required task context.
    question: str = Field(default=..., description="The NL question to answer")
    schema_info: str = Field(
        default=..., description="Known schema info (table names initially)"
    )

    # Outcome of the most recent action.
    result: str = Field("", description="Result of last action (truncated)")
    error: str = Field("", description="Error message if action failed")

    # Episode progress.
    step_count: int = Field(0, description="Current step number")
    budget_remaining: int = Field(0, description="Steps left")
    action_history: list[str] = Field(
        default_factory=list,
        description="Summary of previous actions",
    )


@dataclass
class EpisodeContext:
    """Per-episode server-side state (never sent to agent)."""
    episode_id: str
    db_connection: sqlite3.Connection
    question_record: QuestionRecord
    step_count: int = 0
    budget: int = 15
    described_tables: set[str] = field(default_factory=set)
    action_log: list[str] = field(default_factory=list)
    done: bool = False
    gold_answer: str | None = None  # Computed at reset by running gold_sql


@dataclass
class QuestionRecord:
    """A single dataset entry: one NL question plus its reference solution."""

    # Identity and the question shown to the agent.
    question_id: str
    question_text: str
    # Target database, reference SQL, and the expected answer as text.
    database_name: str
    gold_sql: str
    gold_answer: str
    # Expected kind of answer: "integer", "float", "string" or "list".
    answer_type: str
    # Difficulty bucket: "easy", "medium" or "hard".
    difficulty: str
    # Tables involved in the question (presumably those the gold SQL reads).
    tables_involved: list[str]

Functions

# --- server/sql_environment.py ---

class SQLEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    """Server-side SQL exploration environment (interface sketch).

    One episode answers one question. ``reset()`` picks a question and opens
    its database read-only. Each ``step()`` runs one DESCRIBE / SAMPLE /
    QUERY / ANSWER action and returns an ``SQLObservation``. Per-episode
    state is kept server-side in an ``EpisodeContext``.

    NOTE(review): ``SQLState`` is used as a type parameter but is not defined
    in the Types section of this sketch — add it or confirm it comes from
    OpenEnv.
    """

    def __init__(
        self,
        questions_path: str,
        db_dir: str,
        tokenizer,
        step_budget: int = 15,
    ) -> None:
        """Initialize with path to questions JSON and database directory.

        Args:
            questions_path: Path to the questions JSON (parsed by
                ``_load_questions``).
            db_dir: Directory holding the Spider SQLite databases.
            tokenizer: Not specified by this sketch.
                NOTE(review): document its expected type and purpose.
            step_budget: Maximum number of steps per episode.
        """
        ...

    def reset(self, *, seed: int | None = None, episode_id: str | None = None, **kwargs) -> SQLObservation:
        """Pick random question, open read-only SQLite, return initial observation.

        Per the reset flow: open the question's database, run ``gold_sql`` to
        compute the gold answer, create a fresh ``EpisodeContext``, and return
        an observation with the question and table names only.

        NOTE(review): close the previous episode's connection here so SQLite
        handles are not leaked across episodes; presumably ``seed`` seeds the
        question choice — confirm.
        """
        ...

    def step(self, action: SQLAction, *, timeout_s: float = 30, **kwargs) -> SQLObservation:
        """Dispatch to handler, update episode context, return observation.

        Validates ``action.action_type`` against DESCRIBE/SAMPLE/QUERY/ANSWER,
        calls the matching ``_handle_*`` method, updates step count/budget, and
        sets ``done`` on ANSWER or when the budget is exhausted.

        NOTE(review): this 30 s default disagrees with ``_execute_sql``'s 5 s
        default and the 5 s timeout in the data flow; decide which governs a
        single query.
        """
        ...

    # --- Action handlers (private) ---

    def _handle_describe(self, table_name: str) -> str:
        """Return column names, types, row count for table. Error if table not found.

        ``table_name`` comes from the agent: check it against the database's
        real table names before using it in any SQL/PRAGMA text.
        """
        ...

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Execute SELECT * FROM table LIMIT N, return formatted rows.

        ``table_name`` comes from the agent and cannot be bound as a SQL
        parameter, so validate it against the known tables and quote it as an
        identifier before interpolating; pass ``limit`` as a bound parameter.
        """
        ...

    def _handle_query(self, sql: str) -> str:
        """Validate SELECT-only, execute with timeout, truncate to 20 rows.

        Formats the rows returned by ``_execute_sql`` and appends
        "... (N more rows)" when truncated.
        """
        ...

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare to gold answer, return (correct, reward).

        Per the step flow, the episode ends (``done=True``) and the reward is
        1.0 or 0.0.
        """
        ...

    # --- Infrastructure (private) ---

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Sandboxed execution: read-only, timeout, SELECT-only.

        NOTE(review): this returns raw rows (``list[tuple]``), so the
        "... (N more rows)" marker from the sandboxing diagram cannot be part
        of its result; the caller must report truncation. Also, a
        ``startswith("SELECT")`` check rejects ``WITH ... SELECT`` (CTE)
        queries — the read-only connection is the real write guard.
        """
        ...

    def _get_table_schema(self, table_name: str) -> list[tuple]:
        """Fetch column metadata for ``table_name`` via sqlite3 (used by DESCRIBE).

        Referenced by the step data flow but missing from the original
        interface. Row shape is TBD (e.g. ``PRAGMA table_info`` rows) —
        TODO(review) confirm.
        """
        ...

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open read-only SQLite connection for a Spider database.

        Per the sandboxing flow, use the URI form ``file:{path}?mode=ro``.
        """
        ...

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load and parse question JSON into QuestionRecord list."""
        ...

    def _build_observation(self) -> SQLObservation:
        """Construct observation from current episode context."""
        ...

Data Flow

┌───────────────────────────────────────────────────────────────────────┐
│                         RESET FLOW                                    │
│                                                                       │
│  Client.reset()                                                       │
│       │                                                               │
│       ▼                                                               │
│  SQLEnvironment.reset()                                               │
│       │                                                               │
│       ├── Pick random QuestionRecord                                  │
│       ├── _open_db(question.database_name) ──→ sqlite3.Connection     │
│       ├── Execute gold_sql to compute gold_answer                     │
│       ├── Create EpisodeContext                                       │
│       └── _build_observation() ──→ SQLObservation                     │
│            (question, table names only, budget=15)                    │
└───────────────────────────────────────────────────────────────────────┘

┌───────────────────────────────────────────────────────────────────────┐
│                         STEP FLOW                                     │
│                                                                       │
│  Client.step(SQLAction)                                               │
│       │                                                               │
│       ▼                                                               │
│  SQLEnvironment.step(action)                                          │
│       │                                                               │
│       ├── Validate action_type ∈ {DESCRIBE, SAMPLE, QUERY, ANSWER}    │
│       │                                                               │
│       ├─→ DESCRIBE ──→ _handle_describe(table_name)                   │
│       │                    └── _get_table_schema() via sqlite3        │
│       │                                                               │
│       ├─→ SAMPLE ──→ _handle_sample(table_name)                       │
│       │                  └── _execute_sql("SELECT * ... LIMIT 5")     │
│       │                                                               │
│       ├─→ QUERY ──→ _handle_query(sql)                                │
│       │                 ├── SELECT-only check                         │
│       │                 └── _execute_sql(sql, timeout=5s)             │
│       │                       └── Truncate to 20 rows                 │
│       │                                                               │
│       ├─→ ANSWER ──→ _handle_answer(value)                            │
│       │                  ├── Compare to gold_answer                   │
│       │                  └── done=True, reward=1.0|0.0                │
│       │                                                               │
│       ├── Update EpisodeContext (step_count++, budget--)              │
│       ├── Check budget exhaustion → done=True if budget==0            │
│       └── _build_observation() ──→ SQLObservation                     │
└───────────────────────────────────────────────────────────────────────┘

┌───────────────────────────────────────────────────────────────────────┐
│                    SQLITE SANDBOXING                                  │
│                                                                       │
│  _execute_sql(sql, timeout_s)                                         │
│       │                                                               │
│       ├── Check: sql.strip().upper().startswith("SELECT")             │
│       │     ├── Reject non-SELECT → error message                     │
│       │     └── Caveat: also rejects WITH (CTE) queries               │
│       │                                                               │
│       ├── Execute via read-only sqlite3.Connection                    │
│       │     └── URI: "file:{path}?mode=ro"                            │
│       │                                                               │
│       ├── Timeout: sqlite3 progress_handler or thread timeout         │
│       │     └── Kill query after timeout_s → timeout error            │
│       │                                                               │
│       └── Truncate results to max_rows (20)                           │
│             └── Append "... (N more rows)" if truncated               │
└───────────────────────────────────────────────────────────────────────┘

Open Questions

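  • Should _execute_sql enforce its timeout with a progress handler (Connection.set_progress_handler, a callback-based interrupt) or with a watchdog thread? The progress handler is simpler but SQLite-specific. A watchdog thread cannot stop the query by itself; it would still have to call Connection.interrupt().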
  • Should we keep the tokens field in SQLAction/SQLObservation for backward compat, or do a clean break? Rich observations may make tokens redundant.
  • How to handle message_to_action() — is it required by OpenEnv's client protocol, or can we remove it?