"""
OpenEnv typed models: Observation, Action, Reward.

All models are written against the Pydantic v2 API.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# ===========================================================================
# Observation -- payload handed to the agent
# ===========================================================================


class Observation(BaseModel):
    """What the agent sees at each step."""

    # --- Task identity and instructions (all required) ---
    task_id: int = Field(
        default=...,
        description="Which task (1=easy, 2=medium, 3=hard)",
    )
    task_name: str = Field(default=..., description="Human-readable task name")
    task_description: str = Field(
        default=...,
        description="What the agent must accomplish",
    )

    # --- SQL under review and the schema it runs against (required) ---
    query: str = Field(
        default=...,
        description="The SQL query the agent must fix / optimise",
    )
    schema_context: str = Field(
        default=...,
        description="DDL / schema description relevant to the query",
    )

    # --- Optional guidance; None when no hint is supplied ---
    hint: Optional[str] = Field(
        default=None,
        description="Optional natural-language hint for the current step",
    )

    # --- Episode progress; these fall back to their defaults if omitted ---
    step_number: int = Field(
        default=0,
        description="Current step within the episode (0-indexed)",
    )
    max_steps: int = Field(
        default=5,
        description="Maximum steps allowed per episode",
    )
    done: bool = Field(default=False, description="Whether the episode has ended")


# ===========================================================================
# Action -- payload submitted by the agent
# ===========================================================================


class Action(BaseModel):
    """What the agent submits at each step."""

    # Both text fields are required; only the completion flag has a default.
    rewritten_query: str = Field(
        default=...,
        description="The agent's rewritten / improved SQL query",
    )
    explanation: str = Field(
        default=...,
        description="Natural-language explanation of changes made",
    )
    is_done: bool = Field(
        default=False,
        description="Set True when the agent believes the query is fully optimised",
    )


# ===========================================================================
# Reward -- grader output
# ===========================================================================


# Per-dimension partial scores. Documented with comments rather than a class
# docstring so the generated JSON schema for this model is left unchanged.
class RewardBreakdown(BaseModel):
    # The three quality dimensions are each validated to lie within [0, 1].
    correctness: float = Field(default=0.0, ge=0.0, le=1.0)
    performance: float = Field(default=0.0, ge=0.0, le=1.0)
    style: float = Field(default=0.0, ge=0.0, le=1.0)
    # Penalty term: validation only accepts values <= 0.
    step_penalty: float = Field(default=0.0, le=0.0)


class Reward(BaseModel):
    """Reward returned after each step."""

    # Required scores, each validated to lie within [0, 1].
    score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Aggregate step reward",
    )
    grader_score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Raw grader score for the submitted query",
    )

    # A fresh RewardBreakdown (all zeros) is built per instance when omitted.
    breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-dimension partial scores",
    )
    feedback: str = Field(
        default="",
        description="Human-readable feedback from the grader",
    )

    # Running total; validated to the same [0, 1] range as the step scores.
    cumulative_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Total score accumulated over episode so far",
    )