Rockerleo commited on
Commit
3b074e8
·
verified ·
1 Parent(s): cf91c05

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ from typing import Any, Dict, Optional
4
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
5
+ from openenv_state import OPENENV_STATE, OpenEnvState
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+
9
+ from models import MLOpsAction, MLOpsObservation, MLOpsState
10
+ from mlops_environment import MLOpsEnvironment
11
+
12
+ app = FastAPI(
13
+ title="MLOps Pipeline Debugger",
14
+ description="OpenEnv environment: AI agent diagnoses broken ML training runs.",
15
+ version="1.0.0",
16
+ )
17
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
18
+
19
+ _http_env: Optional[MLOpsEnvironment] = None
20
+
21
+
22
+ class ResetRequest(BaseModel):
23
+ task_id: Optional[str] = "easy"
24
+ seed: Optional[int] = None
25
+ task: Optional[str] = None # Support both task_id and task
26
+
27
+
28
+ class StepResponse(BaseModel):
29
+ observation: MLOpsObservation
30
+ reward: float
31
+ done: bool
32
+ info: Dict[str, Any]
33
+
34
+
35
+ @app.post("/reset", response_model=MLOpsObservation)
36
+ async def reset(request: Request):
37
+ try:
38
+ body = await request.json()
39
+ except Exception:
40
+ body = {}
41
+ task_id = body.get("task_id") or body.get("task") or "easy"
42
+ seed = body.get("seed")
43
+ global _http_env
44
+ _http_env = MLOpsEnvironment(task_id=task_id)
45
+ return _http_env.reset(seed=seed)
46
+
47
+
48
+ @app.get("/")
49
+ async def root():
50
+ return {
51
+ "message": "MLOps Pipeline Debugger API",
52
+ "version": "1.0.0",
53
+ "docs": "This is an OpenEnv-compatible RL environment",
54
+ "endpoints": {
55
+ "GET /": "This message",
56
+ "GET /health": "Health check",
57
+ "GET /tasks": "List available tasks",
58
+ "GET /openenv/state": "OpenEnv state",
59
+ "POST /reset": "Start a new episode",
60
+ "POST /step": "Take an action",
61
+ "GET /state": "Get current state",
62
+ },
63
+ }
64
+
65
+
66
+ @app.get("/health")
67
+ async def health():
68
+ return {"status": "ok", "environment": "mlops_debug_env", "version": "1.0.0"}
69
+
70
+
71
+ @app.get("/openenv/state", response_model=OpenEnvState)
72
+ def openenv_state():
73
+ return OPENENV_STATE
74
+
75
+
76
+ @app.get("/tasks")
77
+ async def list_tasks():
78
+ return {
79
+ "tasks": [
80
+ {
81
+ "task_id": "easy",
82
+ "name": "Config Error Diagnosis",
83
+ "difficulty": "easy",
84
+ "max_steps": 20,
85
+ },
86
+ {
87
+ "task_id": "medium",
88
+ "name": "Data Leakage Detection",
89
+ "difficulty": "medium",
90
+ "max_steps": 30,
91
+ },
92
+ {
93
+ "task_id": "hard",
94
+ "name": "Silent Evaluation Bug",
95
+ "difficulty": "hard",
96
+ "max_steps": 40,
97
+ },
98
+ ]
99
+ }
100
+
101
+
102
+ @app.post("/step", response_model=StepResponse)
103
+ async def step(request: Request):
104
+ if _http_env is None:
105
+ raise HTTPException(400, "Call /reset first.")
106
+
107
+ # Get raw body as dict
108
+ try:
109
+ body = await request.json()
110
+ except Exception:
111
+ body = {}
112
+
113
+ # Handle various input formats
114
+ action = None
115
+ try:
116
+ if "action_type" in body:
117
+ action = MLOpsAction(**body)
118
+ elif "action" in body:
119
+ action = MLOpsAction(**body["action"])
120
+ elif "message" in body:
121
+ action = MLOpsAction(action_type=body["message"])
122
+ except Exception as e:
123
+ raise HTTPException(422, f"Invalid action: {str(e)}")
124
+
125
+ if action is None or action.action_type is None:
126
+ raise HTTPException(422, "Field required: action_type")
127
+
128
+ try:
129
+ obs, reward, done, info = _http_env.step(action)
130
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
131
+ except Exception as e:
132
+ raise HTTPException(500, f"Step error: {str(e)}")
133
+
134
+
135
+ @app.get("/state", response_model=MLOpsState)
136
+ async def state():
137
+ if _http_env is None:
138
+ raise HTTPException(400, "Call /reset first.")
139
+ return _http_env.state
140
+
141
+
142
+ @app.websocket("/ws")
143
+ async def ws_endpoint(websocket: WebSocket):
144
+ await websocket.accept()
145
+ env: Optional[MLOpsEnvironment] = None
146
+ try:
147
+ while True:
148
+ msg = json.loads(await websocket.receive_text())
149
+ method = msg.get("method")
150
+ if method == "reset":
151
+ env = MLOpsEnvironment(task_id=msg.get("task_id", "easy"))
152
+ obs = env.reset(seed=msg.get("seed"))
153
+ await websocket.send_text(
154
+ json.dumps({"method": "reset", "observation": obs.model_dump()})
155
+ )
156
+ elif method == "step":
157
+ if env is None:
158
+ await websocket.send_text(json.dumps({"error": "Call reset first"}))
159
+ continue
160
+ action = MLOpsAction(**msg.get("action", {}))
161
+ obs, reward, done, info = env.step(action)
162
+ await websocket.send_text(
163
+ json.dumps(
164
+ {
165
+ "method": "step",
166
+ "observation": obs.model_dump(),
167
+ "reward": reward,
168
+ "done": done,
169
+ "info": info,
170
+ }
171
+ )
172
+ )
173
+ elif method == "state":
174
+ if env is None:
175
+ await websocket.send_text(json.dumps({"error": "Call reset first"}))
176
+ continue
177
+ await websocket.send_text(
178
+ json.dumps({"method": "state", "state": env.state.model_dump()})
179
+ )
180
+ except WebSocketDisconnect:
181
+ pass