File size: 7,552 Bytes
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2541228
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2541228
 
 
 
 
 
 
 
 
 
 
210535c
 
 
 
35c8316
 
 
 
210535c
 
35c8316
 
 
 
 
 
 
210535c
 
 
733e095
72a3502
210535c
 
72a3502
 
210535c
 
 
 
 
 
733e095
210535c
 
 
 
 
 
 
 
 
 
733e095
210535c
 
 
 
 
 
733e095
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733e095
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
733e095
210535c
 
2541228
 
210535c
2541228
 
 
210535c
 
2541228
210535c
 
 
2541228
210535c
 
2541228
210535c
 
 
 
2541228
210535c
2541228
 
 
 
 
210535c
2541228
210535c
 
2541228
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
FastAPI server exposing the OpenEnv SQL Optimizer environment.

Endpoints:
  GET  /             → health check
  POST /reset        → Observation
  POST /step         → {observation, reward, done, info}
  GET  /state        → state dict
  GET  /tasks        → list of tasks + action schema
  GET  /grader       → grader score for last completed episode
  POST /baseline     → trigger baseline inference on all 3 tasks

Every endpoint is also served under a ``/web`` prefix (hidden from the schema).
"""
from __future__ import annotations

import os
import subprocess
import sys
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

from env.environment import SQLOptimizerEnv
from env.models import Action, Observation, Reward
from env.tasks import TASKS

# FastAPI application; title/description/version are surfaced in the
# generated OpenAPI schema and interactive docs.
app = FastAPI(
    title="SQL Query Optimizer — OpenEnv",
    description=(
        "An OpenEnv-compliant environment where AI agents learn to rewrite "
        "and optimise SQL queries across three difficulty levels."
    ),
    version="1.0.0",
)

# Fully permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single shared environment instance (stateful, per-process). Every endpoint
# in this module reads or mutates this one object, so all clients share the
# same episode.
# NOTE(review): the endpoints are plain ``def`` functions, which FastAPI may
# run concurrently in a threadpool; nothing here serializes access to _env —
# confirm single-client use or add a lock.
_env = SQLOptimizerEnv()


# ──────────────────────────────────────────────────────────────────────────────
# Request / Response schemas
# ──────────────────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Body of ``POST /reset``; the whole body may be omitted by the client."""

    task_id: int = 1  # task to load (1=easy, 2=medium, 3=hard per ``reset``); defaults to 1


class StepResponse(BaseModel):
    """Response of ``POST /step``: the 4-tuple from ``_env.step`` as named fields."""

    observation: Observation  # observation after applying the action
    reward: Reward  # reward object produced by the environment for this step
    done: bool  # True once the episode has finished
    info: Dict[str, Any]  # free-form diagnostics dict returned by the environment


class GraderResponse(BaseModel):
    """Response of ``GET /grader``, assembled from ``_env.state()``."""

    task_id: Optional[int]  # task of the current/last episode; None if state has no task_id
    grader_score: float  # state's ``last_grader_score`` (0.0 when absent)
    cumulative_score: float  # state's ``cumulative_score`` (0.0 when absent)
    done: bool  # whether the episode has finished (False when absent)


class TaskInfo(BaseModel):
    """One entry in the ``GET /tasks`` listing."""

    id: int  # task id accepted by ``POST /reset``
    name: str
    difficulty: str
    description: str
    action_schema: Dict[str, Any]  # JSON schema of ``Action`` (same for every task)


class BaselineResponse(BaseModel):
    """Response of ``POST /baseline`` after the inference script completes."""

    # Scores from the script's ``[END]`` payload ``task_results`` entry;
    # presumably keyed by task id as a string — confirm against inference.py.
    task_results: Dict[str, float]
    message: str  # human-readable status line


def _parse_end_payload(stdout: str) -> Dict[str, Any]:
    for line in reversed(stdout.splitlines()):
        if not line.startswith("[END] "):
            continue
        payload_text = line[len("[END] ") :].strip()
        import json

        return json.loads(payload_text)
    raise ValueError("Could not find [END] payload in inference output")


# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────

def _health_payload() -> Dict[str, str]:
    return {"status": "ok", "environment": "sql-query-optimizer", "version": "1.0.0"}


@app.get("/", summary="Health check")
def health() -> Dict[str, str]:
    """Liveness probe served at the root path."""
    payload = _health_payload()
    return payload


@app.get("/web", include_in_schema=False)
@app.get("/web/", include_in_schema=False)
def web_health() -> Dict[str, str]:
    """Liveness probe for the ``/web`` prefix (with or without trailing slash)."""
    payload = _health_payload()
    return payload


@app.post("/reset", response_model=Observation, summary="Start / restart an episode")
@app.post("/web/reset", response_model=Observation, include_in_schema=False)
def reset(req: Optional[ResetRequest] = None) -> Observation:
    """Reset the environment for a given task_id (1=easy, 2=medium, 3=hard).

    The request body is optional; when it is omitted the defaults declared on
    ``ResetRequest`` apply.

    Raises:
        HTTPException: 400 if the environment rejects the task id with a
            ``ValueError``.
    """
    if req is None:
        # Take defaults from the request model itself so they live in one place.
        req = ResetRequest()
    try:
        obs = _env.reset(task_id=req.task_id)
    except ValueError as exc:
        # Chain the cause so the original error stays attached as __cause__.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return obs


@app.post("/step", response_model=StepResponse, summary="Submit an action")
@app.post("/web/step", response_model=StepResponse, include_in_schema=False)
def step(action: Action) -> StepResponse:
    """Advance the environment by submitting an Action.

    Raises:
        HTTPException: 400 when the environment raises ``RuntimeError``
            (e.g. stepping in a state the environment refuses).
    """
    try:
        obs, reward, done, info = _env.step(action)
    except RuntimeError as exc:
        # Chain the cause so the original error stays attached as __cause__.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return StepResponse(observation=obs, reward=reward, done=done, info=info)


@app.get("/state", summary="Return current internal state")
@app.get("/web/state", include_in_schema=False)
def state() -> Dict[str, Any]:
    """Expose the environment's internal state dict unchanged."""
    snapshot = _env.state()
    return snapshot


@app.get("/tasks", response_model=list[TaskInfo], summary="List tasks + action schema")
@app.get("/web/tasks", response_model=list[TaskInfo], include_in_schema=False)
def list_tasks() -> list[TaskInfo]:
    """List every registered task, each paired with the Action JSON schema."""
    # The schema does not depend on the task, so build it once and share it.
    schema = Action.model_json_schema()

    def _describe(task: Any) -> TaskInfo:
        # Project one TASKS entry onto the public TaskInfo shape.
        return TaskInfo(
            id=task.id,
            name=task.name,
            difficulty=task.difficulty,
            description=task.description,
            action_schema=schema,
        )

    return [_describe(task) for task in TASKS.values()]


@app.get("/grader", response_model=GraderResponse, summary="Grader score for last episode")
@app.get("/web/grader", response_model=GraderResponse, include_in_schema=False)
def grader() -> GraderResponse:
    """Report grader and cumulative scores for the current or most recent episode.

    Raises:
        HTTPException: 400 if the state reports that no episode has started.
    """
    snapshot = _env.state()
    not_started = snapshot.get("status") == "not_started"
    if not_started:
        raise HTTPException(status_code=400, detail="No episode started. Call /reset first.")
    # Missing keys fall back to neutral defaults rather than failing.
    fields = {
        "task_id": snapshot.get("task_id"),
        "grader_score": snapshot.get("last_grader_score", 0.0),
        "cumulative_score": snapshot.get("cumulative_score", 0.0),
        "done": snapshot.get("done", False),
    }
    return GraderResponse(**fields)


@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
@app.post("/web/baseline", response_model=BaselineResponse, include_in_schema=False)
def baseline() -> BaselineResponse:
    """
    Trigger the baseline inference script (inference.py) and return scores.
    Requires API_BASE_URL, MODEL_NAME, and HF_TOKEN to be set in the environment.

    The script runs as a child process of this interpreter and inherits the
    server's environment variables. Its stdout must contain an ``[END] {json}``
    line; the object's ``task_results`` entry (default ``{}``) is returned.

    Raises:
        HTTPException: 400 when a required variable is missing; 500 when the
            script exits non-zero, times out, or its output cannot be parsed.
    """
    required_vars = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"]
    missing = [name for name in required_vars if not os.getenv(name)]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required environment variables: {', '.join(missing)}",
        )
    timeout_s = 1200
    try:
        # NOTE(review): "inference.py" is resolved against the process's
        # current working directory, not this file's location — confirm the
        # server is always launched from the project root.
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
        if result.returncode != 0:
            raise HTTPException(
                status_code=500,
                detail=f"Inference script failed:\n{result.stderr}",
            )
        payload = _parse_end_payload(result.stdout)
        return BaselineResponse(
            task_results=payload.get("task_results", {}),
            message="Baseline completed successfully.",
        )
    except HTTPException:
        # Already a well-formed HTTP error (the non-zero-exit case above):
        # let it through unchanged. Otherwise the generic handler below would
        # re-wrap it and replace its detail with the exception's str() form.
        raise
    except subprocess.TimeoutExpired as exc:
        raise HTTPException(
            status_code=500, detail=f"Inference script timed out after {timeout_s}s."
        ) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


def main() -> None:
    """Serve the app with uvicorn on all interfaces at port 7860."""
    # Passed as an import string, so uvicorn imports ``server.app`` itself;
    # this presumably requires the CWD to contain the ``server`` package — confirm.
    target, host, port = "server.app:app", "0.0.0.0", 7860
    uvicorn.run(target, host=host, port=port)


if __name__ == "__main__":
    main()