File size: 1,772 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from pydantic import BaseModel, Field
from typing import Optional, List, Any

class Action(BaseModel):
    """What the agent can do each step."""
    # A well-formed action carries exactly one payload (see is_valid):
    # either a query to run, or a final answer to submit.
    sql_query: Optional[str] = Field(
        default=None,
        description="A SQL SELECT query to execute against the database",
    )
    submit_answer: Optional[str] = Field(
        default=None,
        description="Final answer to submit. Ends the episode.",
    )

    def is_valid(self) -> bool:
        """Report whether exactly one of the two payload fields is populated.

        ``None`` and the empty string both count as "not populated", so an
        action with neither field set, or with both set, is invalid.

        Returns:
            True if exactly one of ``sql_query`` / ``submit_answer`` is truthy.
        """
        populated = [
            payload
            for payload in (self.sql_query, self.submit_answer)
            if payload
        ]
        return len(populated) == 1


class QueryResult(BaseModel):
    """Result of executing a SQL query."""
    # Column names, and the returned rows as lists of values.
    # List defaults are safe on a pydantic model: each instance gets its own copy.
    columns: List[str] = []
    rows: List[List[Any]] = []
    # Error text if the query failed, otherwise None.
    # NOTE(review): presumably populated instead of rows on failure — confirm with the executor.
    error: Optional[str] = None
    # NOTE(review): presumably `truncated` is True when `rows` was capped and
    # `total_rows` is the pre-cap row count — confirm with the executor.
    truncated: bool = False
    total_rows: int = 0


class Observation(BaseModel):
    """What the agent sees after each step."""
    # Episode context; both fields are required when constructing an observation.
    schema_summary: str = Field(default=..., description="Compact DB schema")
    question: str = Field(default=..., description="Business question to answer")
    # Feedback from the previous step; each defaults to None.
    last_query: Optional[str] = Field(default=None)
    last_result: Optional[QueryResult] = Field(default=None)
    last_error: Optional[str] = Field(default=None)
    # Progress counters: current step (starts at 0) and step budget (default 20).
    step: int = 0
    max_steps: int = 20
    # Free-form hint strings; empty by default (pydantic copies the default per instance).
    hints: List[str] = []
    done: bool = False


class StepResult(BaseModel):
    """Full result returned by step()."""
    # The observation is required; reward and done default to 0.0 / False.
    observation: Observation
    reward: float = 0.0
    done: bool = False
    # Untyped extra data; defaults to an empty dict (copied per instance by pydantic).
    info: dict = {}


class EnvState(BaseModel):
    """Full environment state returned by state()."""
    # Required: which task is running, its difficulty label, and step counters.
    task_id: str
    difficulty: str
    step: int
    max_steps: int
    # Accumulated episode data: queries issued so far, summed reward, and the
    # terminal flag. Defaults: empty history, 0.0 reward, not done.
    query_history: List[str] = []
    total_reward: float = 0.0
    done: bool = False