Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI server for SQL Data Analyst OpenEnv. | |
| Provides REST and WebSocket endpoints for HuggingFace Spaces deployment. | |
| """ | |
import asyncio
import json
from typing import Any, Dict, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from env import Action, SQLAnalystEnv
# The ASGI application; the __main__ guard at the bottom of the file serves it
# with uvicorn.
app = FastAPI(
    title="SQL Data Analyst Environment",
)

# Live environments keyed by session id. This is a plain in-process dict: in
# the handlers of this module, entries are added by the reset handlers,
# overwritten when the same key is reset again, and removed by delete_session
# or a websocket "close" message.
envs: Dict[str, SQLAnalystEnv] = dict()
class ResetRequest(BaseModel):
    """Body of a reset request: which task the new episode should run."""

    # Falls back to the "monthly_signups" task when the client omits it.
    task_id: str = "monthly_signups"
class StepRequest(BaseModel):
    """Body of a step request: the target session plus the agent's action.

    Both action fields are optional; the step handler copies them into an
    ``Action`` as-is.
    """

    # Id returned by a previous reset call.
    session_id: str
    # Candidate SQL query, passed through as ``Action.sql_query``.
    sql_query: Optional[str] = None
    # Answer to submit, passed through as ``Action.submit_answer``.
    submit_answer: Optional[str] = None
class StateRequest(BaseModel):
    """Body that identifies an existing session by its id.

    Shared by the state-lookup and session-deletion handlers.
    """

    # Id returned by a previous reset call.
    session_id: str
async def root():
    """Report basic metadata identifying this environment server."""
    metadata = dict(
        name="sql-data-analyst",
        version="1.0.0",
        description="SQL Data Analyst OpenEnv - RL environment for SQL query generation",
    )
    return metadata
async def reset(req: ResetRequest) -> Dict[str, Any]:
    """Start a new episode for ``req.task_id`` and return its first observation.

    A fresh ``SQLAnalystEnv`` is created and reset, then stored in ``envs`` so
    that later calls can find it by the returned ``session_id``.

    Returns:
        A dict with ``session_id``, the initial ``observation`` (schema
        summary, question, step counters, hints, done flag), the ``reward``
        and the ``done`` flag from ``env.reset()``.
    """
    # NOTE(review): the session id is simply the task id, so two clients that
    # reset the same task share (and overwrite) a single environment — confirm
    # this is intended before serving concurrent users.
    key = req.task_id
    fresh_env = SQLAnalystEnv(task_id=req.task_id)
    first = fresh_env.reset()
    envs[key] = fresh_env

    initial_obs = first.observation
    observation_payload = dict(
        schema_summary=initial_obs.schema_summary,
        question=initial_obs.question,
        step=initial_obs.step,
        max_steps=initial_obs.max_steps,
        hints=initial_obs.hints,
        done=initial_obs.done,
    )
    return dict(
        session_id=key,
        observation=observation_payload,
        reward=first.reward,
        done=first.done,
    )
async def step(req: StepRequest) -> Dict[str, Any]:
    """Apply one action to an existing session and return the transition.

    ``req.sql_query`` and ``req.submit_answer`` are copied unchanged into an
    ``Action`` and passed to the session's ``SQLAnalystEnv.step``.

    Returns:
        A dict with the new ``observation`` (including the last query, its
        result and any error), the ``reward``, the ``done`` flag and the
        environment's ``info``. If ``req.session_id`` is not registered, a
        404 ``JSONResponse`` whose body is ``{"error": ...}`` is returned
        instead.
    """
    env = envs.get(req.session_id)
    if env is None:
        # Same body as before, but with a 404 status so HTTP clients can detect
        # the failure from the status code instead of inspecting the payload.
        # FastAPI passes Response instances through untouched, so the dict
        # return annotation can stay as it is.
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found. Call /reset first."},
        )

    action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)
    result = env.step(action)
    obs = result.observation
    last = obs.last_result
    return {
        "observation": {
            "schema_summary": obs.schema_summary,
            "question": obs.question,
            "last_query": obs.last_query,
            # One guard is enough: the nested fields are only read when a
            # result exists.
            "last_result": (
                {"columns": last.columns, "rows": last.rows, "error": last.error}
                if last
                else None
            ),
            "last_error": obs.last_error,
            "step": obs.step,
            "max_steps": obs.max_steps,
            "hints": obs.hints,
            "done": obs.done,
        },
        "reward": result.reward,
        "done": result.done,
        "info": result.info,
    }
async def state(req: StateRequest) -> Dict[str, Any]:
    """Return the bookkeeping snapshot of an existing session.

    Returns:
        A dict built from ``env.state()``: task id, difficulty, step counters,
        query history, accumulated reward and done flag. If
        ``req.session_id`` is not registered, a 404 ``JSONResponse`` whose
        body is ``{"error": ...}`` is returned instead.
    """
    env = envs.get(req.session_id)
    if env is None:
        # Same body as before, but with a 404 status so HTTP clients can detect
        # the failure from the status code instead of inspecting the payload.
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found. Call /reset first."},
        )

    # Named `snapshot` rather than `state` to avoid shadowing this handler.
    snapshot = env.state()
    return {
        "task_id": snapshot.task_id,
        "difficulty": snapshot.difficulty,
        "step": snapshot.step,
        "max_steps": snapshot.max_steps,
        "query_history": snapshot.query_history,
        "total_reward": snapshot.total_reward,
        "done": snapshot.done,
    }
async def delete_session(req: StateRequest) -> Dict[str, str]:
    """Forget a session's environment.

    Returns:
        ``{"status": "deleted", ...}`` when the id was registered and has been
        removed, otherwise ``{"status": "not_found", ...}``; both echo the
        requested ``session_id``.
    """
    sid = req.session_id
    if sid not in envs:
        return {"status": "not_found", "session_id": sid}
    envs.pop(sid)
    return {"status": "deleted", "session_id": sid}
async def websocket_endpoint(websocket: WebSocket):
    """Serve one environment session over a WebSocket.

    Every client message is a JSON object whose ``type`` field selects the
    operation:

    * ``"reset"`` -- create a fresh ``SQLAnalystEnv`` for ``task_id``
      (default ``"monthly_signups"``), register it in ``envs`` under that
      task id, and reply with the initial observation.
    * ``"step"``  -- build an ``Action`` from ``sql_query`` / ``submit_answer``,
      apply it, and reply with the observation, reward, done flag and info.
    * ``"state"`` -- reply with the environment's ``state()`` snapshot.
    * ``"close"`` -- unregister this connection's environment and stop.

    Observation payloads mirror the fields sent by the REST reset/step
    handlers. Malformed messages (invalid JSON, a non-object payload, or an
    unknown ``type``) are answered with an ``{"error": ...}`` message and the
    connection stays open. Any other exception is reported to the client on a
    best-effort basis and ends the handler.
    """
    await websocket.accept()
    session_id = None
    env = None
    try:
        while True:
            data = await websocket.receive_text()
            # One malformed frame should not tear down the whole session.
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON message"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json({"error": "Message must be a JSON object"})
                continue
            action_type = message.get("type")

            if action_type == "reset":
                task_id = message.get("task_id", "monthly_signups")
                env = SQLAnalystEnv(task_id=task_id)
                result = env.reset()
                # Registered under the task id, the same key the REST reset uses.
                session_id = task_id
                envs[session_id] = env
                obs = result.observation
                await websocket.send_json(
                    {
                        "type": "reset",
                        "observation": {
                            "schema_summary": obs.schema_summary,
                            "question": obs.question,
                            "step": obs.step,
                            "max_steps": obs.max_steps,
                            "hints": obs.hints,
                            "done": obs.done,
                        },
                        "reward": result.reward,
                        "done": result.done,
                    }
                )
            elif action_type == "step":
                if env is None:
                    await websocket.send_json({"error": "Call reset first"})
                    continue
                action = Action(
                    sql_query=message.get("sql_query"),
                    submit_answer=message.get("submit_answer"),
                )
                result = env.step(action)
                obs = result.observation
                last = obs.last_result
                await websocket.send_json(
                    {
                        "type": "step",
                        "observation": {
                            "schema_summary": obs.schema_summary,
                            "question": obs.question,
                            "last_query": obs.last_query,
                            "last_result": (
                                {
                                    "columns": last.columns,
                                    "rows": last.rows,
                                    "error": last.error,
                                }
                                if last
                                else None
                            ),
                            "last_error": obs.last_error,
                            "step": obs.step,
                            "max_steps": obs.max_steps,
                            "hints": obs.hints,
                            "done": obs.done,
                        },
                        "reward": result.reward,
                        "done": result.done,
                        "info": result.info,
                    }
                )
            elif action_type == "state":
                if env is None:
                    await websocket.send_json({"error": "Call reset first"})
                    continue
                snapshot = env.state()
                await websocket.send_json(
                    {
                        "type": "state",
                        "task_id": snapshot.task_id,
                        "difficulty": snapshot.difficulty,
                        "step": snapshot.step,
                        "max_steps": snapshot.max_steps,
                        "query_history": snapshot.query_history,
                        "total_reward": snapshot.total_reward,
                        "done": snapshot.done,
                    }
                )
            elif action_type == "close":
                # Only drop the entry if it is still this connection's env: a
                # later reset of the same task (REST or another socket) may have
                # replaced it under the same key.
                if env is not None and envs.get(session_id) is env:
                    del envs[session_id]
                break
            else:
                await websocket.send_json(
                    {"error": f"Unknown message type: {action_type!r}"}
                )
    except WebSocketDisconnect:
        # Client went away; there is nobody left to reply to.
        pass
    except Exception as e:
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            # Best effort only: the socket itself may be what failed.
            pass
if __name__ == "__main__":
    # Local entry point. Port 7860 presumably matches the Hugging Face Spaces
    # deployment named in the module docstring — confirm before changing it.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)