Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request | |
| from openenv_state import OPENENV_STATE, OpenEnvState | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from models import MLOpsAction, MLOpsObservation, MLOpsState | |
| from mlops_environment import MLOpsEnvironment | |
| app = FastAPI( | |
| title="MLOps Pipeline Debugger", | |
| description="OpenEnv environment: AI agent diagnoses broken ML training runs.", | |
| version="1.0.0", | |
| ) | |
| app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) | |
| _http_env: Optional[MLOpsEnvironment] = None | |
| class ResetRequest(BaseModel): | |
| task_id: Optional[str] = "easy" | |
| seed: Optional[int] = None | |
| task: Optional[str] = None # Support both task_id and task | |
| class StepResponse(BaseModel): | |
| observation: MLOpsObservation | |
| reward: float | |
| done: bool | |
| info: Dict[str, Any] | |
| async def reset(request: Request): | |
| try: | |
| body = await request.json() | |
| except Exception: | |
| body = {} | |
| task_id = body.get("task_id") or body.get("task") or "easy" | |
| seed = body.get("seed") | |
| global _http_env | |
| _http_env = MLOpsEnvironment(task_id=task_id) | |
| return _http_env.reset(seed=seed) | |
| async def root(): | |
| return { | |
| "message": "MLOps Pipeline Debugger API", | |
| "version": "1.0.0", | |
| "docs": "This is an OpenEnv-compatible RL environment", | |
| "endpoints": { | |
| "GET /": "This message", | |
| "GET /health": "Health check", | |
| "GET /tasks": "List available tasks", | |
| "GET /openenv/state": "OpenEnv state", | |
| "POST /reset": "Start a new episode", | |
| "POST /step": "Take an action", | |
| "GET /state": "Get current state", | |
| }, | |
| } | |
| async def health(): | |
| return {"status": "ok", "environment": "mlops_debug_env", "version": "1.0.0"} | |
| def openenv_state(): | |
| return OPENENV_STATE | |
| async def list_tasks(): | |
| return { | |
| "tasks": [ | |
| { | |
| "task_id": "easy", | |
| "name": "Config Error Diagnosis", | |
| "difficulty": "easy", | |
| "max_steps": 20, | |
| }, | |
| { | |
| "task_id": "medium", | |
| "name": "Data Leakage Detection", | |
| "difficulty": "medium", | |
| "max_steps": 30, | |
| }, | |
| { | |
| "task_id": "hard", | |
| "name": "Silent Evaluation Bug", | |
| "difficulty": "hard", | |
| "max_steps": 40, | |
| }, | |
| ] | |
| } | |
| async def step(request: Request): | |
| if _http_env is None: | |
| raise HTTPException(400, "Call /reset first.") | |
| # Get raw body as dict | |
| try: | |
| body = await request.json() | |
| except Exception: | |
| body = {} | |
| # Handle various input formats | |
| action = None | |
| try: | |
| if "action_type" in body: | |
| action = MLOpsAction(**body) | |
| elif "action" in body: | |
| action = MLOpsAction(**body["action"]) | |
| elif "message" in body: | |
| action = MLOpsAction(action_type=body["message"]) | |
| except Exception as e: | |
| raise HTTPException(422, f"Invalid action: {str(e)}") | |
| if action is None or action.action_type is None: | |
| raise HTTPException(422, "Field required: action_type") | |
| try: | |
| obs, reward, done, info = _http_env.step(action) | |
| return StepResponse(observation=obs, reward=reward, done=done, info=info) | |
| except Exception as e: | |
| raise HTTPException(500, f"Step error: {str(e)}") | |
| async def state(): | |
| if _http_env is None: | |
| raise HTTPException(400, "Call /reset first.") | |
| return _http_env.state | |
| async def ws_endpoint(websocket: WebSocket): | |
| await websocket.accept() | |
| env: Optional[MLOpsEnvironment] = None | |
| try: | |
| while True: | |
| msg = json.loads(await websocket.receive_text()) | |
| method = msg.get("method") | |
| if method == "reset": | |
| env = MLOpsEnvironment(task_id=msg.get("task_id", "easy")) | |
| obs = env.reset(seed=msg.get("seed")) | |
| await websocket.send_text( | |
| json.dumps({"method": "reset", "observation": obs.model_dump()}) | |
| ) | |
| elif method == "step": | |
| if env is None: | |
| await websocket.send_text(json.dumps({"error": "Call reset first"})) | |
| continue | |
| action = MLOpsAction(**msg.get("action", {})) | |
| obs, reward, done, info = env.step(action) | |
| await websocket.send_text( | |
| json.dumps( | |
| { | |
| "method": "step", | |
| "observation": obs.model_dump(), | |
| "reward": reward, | |
| "done": done, | |
| "info": info, | |
| } | |
| ) | |
| ) | |
| elif method == "state": | |
| if env is None: | |
| await websocket.send_text(json.dumps({"error": "Call reset first"})) | |
| continue | |
| await websocket.send_text( | |
| json.dumps({"method": "state", "state": env.state.model_dump()}) | |
| ) | |
| except WebSocketDisconnect: | |
| pass | |
| def main(): | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |
| if __name__ == "__main__": | |
| main() | |