# sql-query-optimizer / env / models.py
# Uploaded by Param20h via huggingface_hub (commit 210535c, verified).
"""
OpenEnv typed models — Observation, Action, Reward.
All models are Pydantic v2 compliant.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class Observation(BaseModel):
    """What the agent sees at each step of an episode.

    Bundles the task metadata, the SQL query the agent must rewrite, the
    schema it runs against, an optional hint, and the episode's progress
    counters (``step_number`` / ``max_steps`` / ``done``).
    """

    task_id: int = Field(..., description="Which task (1=easy, 2=medium, 3=hard)")
    task_name: str = Field(..., description="Human-readable task name")
    task_description: str = Field(..., description="What the agent must accomplish")
    query: str = Field(..., description="The SQL query the agent must fix / optimise")
    schema_context: str = Field(
        ..., description="DDL / schema description relevant to the query"
    )
    hint: Optional[str] = Field(
        None, description="Optional natural-language hint for the current step"
    )
    # A 0-indexed counter can never be negative; reject a bogus value at
    # construction time instead of handing it to the agent.
    step_number: int = Field(
        0, ge=0, description="Current step within the episode (0-indexed)"
    )
    # An episode must allow at least one step, otherwise the agent can never act.
    max_steps: int = Field(5, ge=1, description="Maximum steps allowed per episode")
    done: bool = Field(False, description="Whether the episode has ended")
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class Action(BaseModel):
    """The agent's submission for a single step.

    Attributes:
        rewritten_query: The agent's proposed replacement for the observed SQL
            (required).
        explanation: Free-text rationale for the changes (required).
        is_done: The agent's own signal that it considers the query fully
            optimised; defaults to False, so the agent must opt in explicitly.
    """

    # A Field with no default is required in Pydantic v2, equivalent to `...`.
    rewritten_query: str = Field(
        description="The agent's rewritten / improved SQL query",
    )
    explanation: str = Field(
        description="Natural-language explanation of changes made",
    )
    is_done: bool = Field(
        default=False,
        description="Set True when the agent believes the query is fully optimised",
    )
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
class RewardBreakdown(BaseModel):
    """Per-dimension components of a step reward.

    The three quality dimensions are each validated to lie in [0, 1]; the
    step penalty is validated to be non-positive. Every field defaults to 0.0.
    """

    correctness: float = Field(default=0.0, ge=0.0, le=1.0)
    performance: float = Field(default=0.0, ge=0.0, le=1.0)
    style: float = Field(default=0.0, ge=0.0, le=1.0)
    # Upper bound only: it is a penalty, so it must never be positive. How it
    # is combined into the aggregate score is decided by the grader, not here.
    step_penalty: float = Field(default=0.0, le=0.0)
class Reward(BaseModel):
    """Reward returned after each step.

    ``score`` and ``grader_score`` are required and validated to lie in
    [0, 1]. ``breakdown``, ``feedback`` and ``cumulative_score`` are optional
    and default to an all-zero breakdown, an empty string and 0.0 respectively.
    """

    # A Field with no default is required in Pydantic v2, equivalent to `...`.
    score: float = Field(ge=0.0, le=1.0, description="Aggregate step reward")
    grader_score: float = Field(
        ge=0.0, le=1.0, description="Raw grader score for the submitted query"
    )
    breakdown: RewardBreakdown = Field(
        # default_factory builds a fresh RewardBreakdown per Reward instance.
        default_factory=RewardBreakdown,
        description="Per-dimension partial scores",
    )
    feedback: str = Field(
        default="", description="Human-readable feedback from the grader"
    )
    # NOTE(review): capped at 1.0 like the per-step scores even though it is
    # described as an accumulated total; presumably the caller normalises it
    # (e.g. averages over steps) — confirm, since a raw sum over several steps
    # would fail validation.
    cumulative_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Total score accumulated over episode so far",
    )