Param20h's picture
Upload folder using huggingface_hub
72a3502 verified
"""
FastAPI server exposing the OpenEnv SQL Optimizer environment.
Endpoints:
POST /reset β†’ Observation
POST /step β†’ {observation, reward, done, info}
GET /state β†’ state dict
GET /tasks β†’ list of tasks + action schema
GET /grader β†’ grader score for last completed episode
POST /baseline β†’ trigger baseline inference on all 3 tasks
"""
from __future__ import annotations
import os
import subprocess
import sys
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from env.environment import SQLOptimizerEnv
from env.models import Action, Observation, Reward
from env.tasks import TASKS
app = FastAPI(
title="SQL Query Optimizer β€” OpenEnv",
description=(
"An OpenEnv-compliant environment where AI agents learn to rewrite "
"and optimise SQL queries across three difficulty levels."
),
version="1.0.0",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Single shared environment instance (stateful, per-process)
_env = SQLOptimizerEnv()
# ──────────────────────────────────────────────────────────────────────────────
# Request / Response schemas
# ──────────────────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
task_id: int = 1
class StepResponse(BaseModel):
observation: Observation
reward: Reward
done: bool
info: Dict[str, Any]
class GraderResponse(BaseModel):
task_id: Optional[int]
grader_score: float
cumulative_score: float
done: bool
class TaskInfo(BaseModel):
id: int
name: str
difficulty: str
description: str
action_schema: Dict[str, Any]
class BaselineResponse(BaseModel):
task_results: Dict[str, float]
message: str
def _parse_end_payload(stdout: str) -> Dict[str, Any]:
for line in reversed(stdout.splitlines()):
if not line.startswith("[END] "):
continue
payload_text = line[len("[END] ") :].strip()
import json
return json.loads(payload_text)
raise ValueError("Could not find [END] payload in inference output")
# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────
def _health_payload() -> Dict[str, str]:
return {"status": "ok", "environment": "sql-query-optimizer", "version": "1.0.0"}
@app.get("/", summary="Health check")
def health() -> Dict[str, str]:
return _health_payload()
@app.get("/web", include_in_schema=False)
@app.get("/web/", include_in_schema=False)
def web_health() -> Dict[str, str]:
return _health_payload()
@app.post("/reset", response_model=Observation, summary="Start / restart an episode")
@app.post("/web/reset", response_model=Observation, include_in_schema=False)
def reset(req: Optional[ResetRequest] = None) -> Observation:
"""Reset the environment for a given task_id (1=easy, 2=medium, 3=hard)."""
try:
task_id = req.task_id if req is not None else 1
obs = _env.reset(task_id=task_id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return obs
@app.post("/step", response_model=StepResponse, summary="Submit an action")
@app.post("/web/step", response_model=StepResponse, include_in_schema=False)
def step(action: Action) -> StepResponse:
"""Advance the environment by submitting an Action."""
try:
obs, reward, done, info = _env.step(action)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return StepResponse(observation=obs, reward=reward, done=done, info=info)
@app.get("/state", summary="Return current internal state")
@app.get("/web/state", include_in_schema=False)
def state() -> Dict[str, Any]:
"""Return the current internal state of the environment."""
return _env.state()
@app.get("/tasks", response_model=list[TaskInfo], summary="List tasks + action schema")
@app.get("/web/tasks", response_model=list[TaskInfo], include_in_schema=False)
def list_tasks() -> list[TaskInfo]:
"""Return all tasks with descriptions and the action schema."""
action_schema = Action.model_json_schema()
return [
TaskInfo(
id=t.id,
name=t.name,
difficulty=t.difficulty,
description=t.description,
action_schema=action_schema,
)
for t in TASKS.values()
]
@app.get("/grader", response_model=GraderResponse, summary="Grader score for last episode")
@app.get("/web/grader", response_model=GraderResponse, include_in_schema=False)
def grader() -> GraderResponse:
"""Return the grader score after the current/last episode."""
s = _env.state()
if s.get("status") == "not_started":
raise HTTPException(status_code=400, detail="No episode started. Call /reset first.")
return GraderResponse(
task_id=s.get("task_id"),
grader_score=s.get("last_grader_score", 0.0),
cumulative_score=s.get("cumulative_score", 0.0),
done=s.get("done", False),
)
@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
@app.post("/web/baseline", response_model=BaselineResponse, include_in_schema=False)
def baseline() -> BaselineResponse:
"""
Trigger the baseline inference script (inference.py) and return scores.
Requires API_BASE_URL, MODEL_NAME, and HF_TOKEN to be set in the environment.
"""
required_vars = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"]
missing = [name for name in required_vars if not os.getenv(name)]
if missing:
raise HTTPException(
status_code=400,
detail=f"Missing required environment variables: {', '.join(missing)}",
)
try:
result = subprocess.run(
[sys.executable, "inference.py"],
capture_output=True,
text=True,
timeout=1200,
)
if result.returncode != 0:
raise HTTPException(
status_code=500,
detail=f"Inference script failed:\n{result.stderr}",
)
payload = _parse_end_payload(result.stdout)
return BaselineResponse(
task_results=payload.get("task_results", {}),
message="Baseline completed successfully.",
)
except subprocess.TimeoutExpired:
raise HTTPException(status_code=500, detail="Inference script timed out after 1200s.")
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
def main() -> None:
uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
if __name__ == "__main__":
main()