"""
FastAPI server for SQL Data Analyst OpenEnv.

Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
"""

import asyncio
import json
import logging
from typing import Any, Dict, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, Field

from env import SQLAnalystEnv, Action

logger = logging.getLogger(__name__)

app = FastAPI(title="SQL Data Analyst Environment")

# Live environments keyed by session id. Shared by the REST and WebSocket
# endpoints and held only in this process's memory.
envs: Dict[str, SQLAnalystEnv] = {}

# Task used when a client does not name one.
DEFAULT_TASK_ID = "monthly_signups"


# NOTE(review): the request models below are not referenced by the endpoints
# in this file, which read query parameters instead. Kept for compatibility.
class ResetRequest(BaseModel):
    task_id: str = Field(default=DEFAULT_TASK_ID)


class StepRequest(BaseModel):
    session_id: str
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None


class StateRequest(BaseModel):
    session_id: str


def _session_not_found() -> Dict[str, str]:
    """Return a fresh error payload for an unknown session id.

    Returned as a normal (HTTP 200) JSON body, as the endpoints always have.
    """
    return {"error": "Session not found. Call /reset first."}


def _release_session(
    session_id: Optional[str], env: Optional[SQLAnalystEnv]
) -> None:
    """Remove ``session_id`` from ``envs`` only if it still maps to ``env``.

    Session ids default to the task id, so another client may have replaced
    the entry with its own environment; the identity check keeps us from
    deleting a session we no longer own.
    """
    if session_id is None or env is None:
        return
    if envs.get(session_id) is env:
        del envs[session_id]


def _serialize_last_result(last_result: Any) -> Optional[Dict[str, Any]]:
    """Convert an observation's ``last_result`` to a JSON-ready dict (or None)."""
    if not last_result:
        return None
    return {
        "columns": last_result.columns,
        "rows": last_result.rows,
        "error": last_result.error,
    }


def _serialize_reset(result: Any) -> Dict[str, Any]:
    """Build the observation/reward/done payload returned after ``env.reset()``."""
    obs = result.observation
    return {
        "observation": {
            "schema_summary": obs.schema_summary,
            "question": obs.question,
            "step": obs.step,
            "max_steps": obs.max_steps,
            "hints": obs.hints,
            "done": obs.done,
        },
        "reward": result.reward,
        "done": result.done,
    }


def _serialize_step(result: Any) -> Dict[str, Any]:
    """Build the observation/reward/done/info payload returned after ``env.step()``."""
    obs = result.observation
    return {
        "observation": {
            "schema_summary": obs.schema_summary,
            "question": obs.question,
            "last_query": obs.last_query,
            "last_result": _serialize_last_result(obs.last_result),
            "last_error": obs.last_error,
            "step": obs.step,
            "max_steps": obs.max_steps,
            "hints": obs.hints,
            "done": obs.done,
        },
        "reward": result.reward,
        "done": result.done,
        "info": result.info,
    }


def _serialize_state(snapshot: Any) -> Dict[str, Any]:
    """Convert the object returned by ``env.state()`` to a JSON-ready dict."""
    return {
        "task_id": snapshot.task_id,
        "difficulty": snapshot.difficulty,
        "step": snapshot.step,
        "max_steps": snapshot.max_steps,
        "query_history": snapshot.query_history,
        "total_reward": snapshot.total_reward,
        "done": snapshot.done,
    }


@app.get("/")
async def root():
    """Return static metadata describing this environment."""
    return {
        "name": "sql-data-analyst",
        "version": "1.0.0",
        "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation",
    }


@app.post("/reset")
async def reset(
    task_id: str = DEFAULT_TASK_ID,
    session_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a fresh environment for ``task_id`` and register it.

    Args:
        task_id: Task to load into the new environment.
        session_id: Optional id to register the environment under. When
            omitted (or empty) the task id is used, which is the historical
            behaviour; pass a unique value to avoid sharing a session with
            other clients running the same task.

    Returns:
        The session id plus the initial observation, reward and done flag.
    """
    sid = session_id or task_id
    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    envs[sid] = env
    return {"session_id": sid, **_serialize_reset(result)}


@app.post("/step")
async def step(
    session_id: str,
    sql_query: Optional[str] = None,
    submit_answer: Optional[str] = None,
) -> Dict[str, Any]:
    """Apply one action (a SQL query and/or a submitted answer) to a session.

    Returns the step payload, or an ``{"error": ...}`` body if the session
    id is unknown.
    """
    env = envs.get(session_id)
    if env is None:
        return _session_not_found()
    result = env.step(Action(sql_query=sql_query, submit_answer=submit_answer))
    return _serialize_step(result)


@app.post("/state")
async def state(session_id: str) -> Dict[str, Any]:
    """Return the current state snapshot of a session, or an error body."""
    env = envs.get(session_id)
    if env is None:
        return _session_not_found()
    return _serialize_state(env.state())


@app.post("/delete")
async def delete_session(session_id: str) -> Dict[str, str]:
    """Unregister a session; reports whether it existed."""
    if envs.pop(session_id, None) is not None:
        return {"status": "deleted", "session_id": session_id}
    return {"status": "not_found", "session_id": session_id}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one environment session over a WebSocket.

    Each incoming text frame is a JSON object with a ``"type"`` of
    ``"reset"``, ``"step"``, ``"state"`` or ``"close"``; other types are
    ignored. The environment created by ``"reset"`` is also registered in
    ``envs`` (under ``"session_id"`` if given, else the task id). It is
    unregistered when the socket closes, unless another client has since
    re-registered that id.
    """
    await websocket.accept()
    session_id: Optional[str] = None
    env: Optional[SQLAnalystEnv] = None
    try:
        while True:
            data = await websocket.receive_text()

            # A malformed frame is reported but does not end the session.
            try:
                message = json.loads(data)
            except json.JSONDecodeError as exc:
                await websocket.send_json({"error": f"Invalid JSON: {exc}"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json({"error": "Message must be a JSON object"})
                continue

            action_type = message.get("type")

            if action_type == "reset":
                task_id = message.get("task_id", DEFAULT_TASK_ID)
                requested_id = message.get("session_id")
                new_session_id = str(requested_id) if requested_id else task_id
                new_env = SQLAnalystEnv(task_id=task_id)
                result = new_env.reset()
                # Drop the environment this socket registered previously so
                # repeated resets do not accumulate entries in `envs`.
                _release_session(session_id, env)
                session_id, env = new_session_id, new_env
                envs[session_id] = env
                await websocket.send_json(
                    {"type": "reset", "session_id": session_id, **_serialize_reset(result)}
                )

            elif action_type == "step":
                if env is None:
                    await websocket.send_json({"error": "Call reset first"})
                    continue
                action = Action(
                    sql_query=message.get("sql_query"),
                    submit_answer=message.get("submit_answer"),
                )
                result = env.step(action)
                await websocket.send_json({"type": "step", **_serialize_step(result)})

            elif action_type == "state":
                if env is None:
                    await websocket.send_json({"error": "Call reset first"})
                    continue
                await websocket.send_json({"type": "state", **_serialize_state(env.state())})

            elif action_type == "close":
                break

            # Any other message type is silently ignored (unchanged behaviour).

    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.exception("WebSocket session %r failed", session_id)
        try:
            await websocket.send_json({"error": str(exc)})
        except Exception:
            # Best effort: the socket may already be closed.
            logger.debug("Could not deliver error to WebSocket client", exc_info=True)
    finally:
        _release_session(session_id, env)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)