Spaces:
Sleeping
Sleeping
File size: 1,772 Bytes
d103a0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | from pydantic import BaseModel, Field
from typing import Optional, List, Any
class Action(BaseModel):
    """What the agent can do each step.

    An action is expected to carry exactly one of ``sql_query`` or
    ``submit_answer``. That constraint is NOT enforced when the model is
    constructed; call :meth:`is_valid` to check it.
    """

    sql_query: Optional[str] = Field(
        None,
        description="A SQL SELECT query to execute against the database"
    )
    submit_answer: Optional[str] = Field(
        None,
        description="Final answer to submit. Ends the episode."
    )

    def is_valid(self) -> bool:
        """Return True when exactly one of the two fields is meaningfully set.

        A field counts as set only if it is not ``None`` and contains at
        least one non-whitespace character. A blank string such as ``"   "``
        is therefore treated like an omitted value, not as a runnable query
        or a submittable answer.

        Returns:
            True if exactly one of ``sql_query`` / ``submit_answer`` is set,
            False if neither or both are.
        """
        has_query = bool(self.sql_query and self.sql_query.strip())
        has_answer = bool(self.submit_answer and self.submit_answer.strip())
        # Exactly one of the two must be set (logical XOR of the two flags).
        return has_query != has_answer
class QueryResult(BaseModel):
    """Outcome of running a single SQL query.

    Bundles the returned column names and row values together with an
    optional error string and truncation bookkeeping.
    """

    # Result column names; empty by default (a fresh list per instance).
    columns: List[str] = Field(default_factory=list)
    # Result rows, one inner list per row — presumably aligned with
    # ``columns``; confirm against the query executor.
    rows: List[List[Any]] = Field(default_factory=list)
    # Error text; None unless something went wrong (per field name — verify).
    error: Optional[str] = None
    # Flag for a cut-short result set — semantics set by the executor; confirm.
    truncated: bool = False
    # Row count reported by the executor; presumably the pre-truncation total.
    total_rows: int = 0
class Observation(BaseModel):
    """Snapshot handed to the agent after each step.

    ``schema_summary`` and ``question`` are required; every other field has
    a default and describes the latest query and episode progress.
    """

    # Required context for the episode.
    schema_summary: str = Field(description="Compact DB schema")
    question: str = Field(description="Business question to answer")

    # Latest query and its outcome; all default to None.
    last_query: Optional[str] = Field(default=None)
    last_result: Optional[QueryResult] = Field(default=None)
    last_error: Optional[str] = Field(default=None)

    # Episode progress.
    step: int = Field(default=0)
    max_steps: int = Field(default=20)
    hints: List[str] = Field(default_factory=list)
    done: bool = Field(default=False)
class StepResult(BaseModel):
    """Everything returned by step() for a single transition."""

    # The agent's view after the step (required).
    observation: Observation
    # Scalar reward for this step; 0.0 when not supplied.
    reward: float = 0.0
    # Episode-finished flag; False when not supplied.
    done: bool = False
    # Free-form extra data; no schema is enforced here. Fresh dict per instance.
    info: dict = Field(default_factory=dict)
class EnvState(BaseModel):
    """Environment-wide state, as returned by state()."""

    # Required: these have no defaults and must be supplied by the caller.
    task_id: str
    difficulty: str
    step: int
    max_steps: int

    # Optional bookkeeping with defaults.
    # Queries issued so far — presumably in execution order; verify with caller.
    query_history: List[str] = Field(default_factory=list)
    total_reward: float = 0.0
    done: bool = False
|