""" OpenEnv spec routes. POST /env/reset → Observation POST /env/step → {observation: Observation, reward: RewardInfo} GET /env/state → current episode state dict GET /env/tasks → list of task metadata GET /env/info → env metadata """ from __future__ import annotations import json import sys from fastapi import APIRouter, HTTPException from pydantic import BaseModel from typing import Optional from env.sql_env import get_env, Observation, Action, RewardInfo from env.tasks import get_all_tasks router = APIRouter() def _log(tag: str, payload: dict) -> None: """Emit a single structured log line to stdout: [TAG] """ print(f"[{tag}] {json.dumps(payload)}", flush=True) # ─── Request Models ─────────────────────────────────────────────── class ResetRequest(BaseModel): task_id: str = "simple_queries" question_id: Optional[str] = None class StepRequest(BaseModel): repair_action: str = "generate" custom_sql: Optional[str] = None # ─── Routes ─────────────────────────────────────────────────────── @router.post("/reset", response_model=Observation) async def env_reset(req: ResetRequest): """Reset the environment to start a new episode.""" env = get_env() if req.question_id: obs = env.reset_with_question(req.task_id, req.question_id) else: obs = env.reset(req.task_id) _log("START", { "task_id": obs.task_id, "task_difficulty": obs.task_difficulty, "question": obs.question, "max_attempts": obs.max_attempts, }) return obs @router.post("/step") async def env_step(req: StepRequest): """Execute one step in the current episode.""" env = get_env() try: action = Action( repair_action=req.repair_action, custom_sql=req.custom_sql, ) obs, reward = await env.step(action) _log("STEP", { "attempt": obs.attempt_number, "action": req.repair_action, "sql": obs.current_sql or "", "error": obs.error_message, "error_class": obs.error_class, "reward": round(reward.value, 4), "success": reward.success, "done": reward.done, }) if reward.done: ep = env._episode _log("END", { "success": reward.success, "attempts": obs.attempt_number, "total_reward": round( sum(s.reward for s in ep.steps) if ep and ep.steps else reward.value, 4 ), }) return { "observation": obs.model_dump(), "reward": reward.model_dump(), } except RuntimeError as e: raise HTTPException(status_code=400, detail=str(e)) @router.get("/state") async def env_state(): """Get the current episode state.""" env = get_env() return env.state() @router.get("/tasks") async def list_tasks(): """List all available tasks with metadata.""" tasks = get_all_tasks() return [ { "id": t.id, "name": t.name, "difficulty": t.difficulty, "description": t.description, "question_count": len(t.questions), "questions": [ { "id": q.id, "question": q.question, "hint_tables": q.hint_tables, } for q in t.questions ], } for t in tasks ] @router.get("/info") async def env_info(): """Return environment metadata (matches openenv.yaml spec).""" return { "name": "sql-agent-openenv", "version": "1.0.0", "description": "SQL generation and repair environment with RL-driven repair strategy selection.", "action_space": { "type": "discrete", "actions": [ "generate", "rewrite_full", "fix_column", "fix_table", "add_groupby", "rewrite_cte", "fix_syntax", "change_dialect", "relax_filter", ], }, "observation_space": { "type": "dict", "fields": [ "question", "schema_info", "current_sql", "error_message", "error_class", "attempt_number", "max_attempts", "task_id", "task_difficulty", ], }, "reward_range": [0.05, 0.95], "max_steps": 5, "tasks": ["simple_queries", "join_queries", "complex_queries"], "rl_algorithm": "LinUCB (contextual bandit)", "feature_dim": 20, "num_actions": 8, }