# environment/models.py
# Typed Pydantic models for OpenEnv interface
# Implements Action, Observation, and Reward schemas
from typing import Optional
from pydantic import BaseModel, model_validator
class Action(BaseModel):
    """
    A single step taken by the agent in the SQL Analyst environment.

    An action carries one of two payloads, never both and never neither:
    - sql_query: SQL text to execute against the database
    - submit_answer: the final answer to hand in for grading

    A field counts as "provided" only when it is a string containing at
    least one non-whitespace character; ``None``, ``""`` and blank strings
    are all treated as absent. The rule is enforced by a Pydantic
    ``model_validator`` that runs after field validation.
    """

    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    @model_validator(mode='after')
    def validate_exactly_one_action(self) -> 'Action':
        """
        Reject actions that supply both payloads or neither of them.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If both ``sql_query`` and ``submit_answer`` hold
                non-blank text, or if neither does.
        """
        # Collect the names of the payload fields that hold real content.
        # `or ""` folds None into an empty string so .strip() is always safe.
        filled = [
            field_name
            for field_name in ("sql_query", "submit_answer")
            if (getattr(self, field_name) or "").strip()
        ]

        if len(filled) == 2:
            raise ValueError(
                "Invalid action: Provide exactly ONE of 'sql_query' or 'submit_answer', not both."
            )
        if not filled:
            raise ValueError(
                "Invalid action: Must provide exactly ONE of 'sql_query' or 'submit_answer'."
            )
        return self
class Observation(BaseModel):
    """
    Snapshot of the environment state that is shown to the agent.

    All four fields are required strings; none has a default value.

    Fields:
    - schema_info: Description of the database schema (tables, columns, types)
    - current_question: The task question the agent is expected to answer
    - last_query_result: Output of the most recently executed SQL query
    - error_message: Error raised by the last action; by convention an
      empty string when there was none (not enforced by this model)
    """

    # Database schema description presented to the agent.
    schema_info: str
    # The question currently being worked on.
    current_question: str
    # Result text of the latest SQL query.
    last_query_result: str
    # Error text from the latest action.
    error_message: str
class Reward(BaseModel):
    """
    Scalar reward returned to the agent, wrapped in a model.

    The model only requires ``value`` to be a float; it does not restrict
    the value to a particular set. According to the PRD reward-shaping
    rules, the environment is expected to emit:
    - +0.1: SQL query executed successfully without error
    - -0.1: SQLite syntax error
    - -1.0: Destructive action detected (done=True)
    - -0.5: Step count >= 15 (infinite loop shield, done=True)
    """

    # The reward amount for the step.
    value: float