Sayed223 commited on
Commit
2f3df7f
Β·
verified Β·
1 Parent(s): c8ae1d4

Create app.py

Browse files
Files changed (1) hide show
  1. server/app.py +479 -0
server/app.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server.py β€” FastAPI/OpenEnv server wrapper for CustomerSupportEnv.
3
+
4
+ Exposes the environment as REST endpoints compatible with OpenEnv specification.
5
+ Handles session management and action validation.
6
+
7
+ Endpoints:
8
+ POST /reset β†’ Initialize new episode, return initial observation
9
+ POST /step β†’ Apply action, return (obs, reward, done)
10
+ GET /state β†’ Get current environment state
11
+ GET /tasks β†’ List all tasks
12
+ POST /grade β†’ Grade current episode
13
+ GET /health β†’ Health check
14
+ GET /openenv.yaml β†’ Spec file
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ import sys
21
+ import traceback
22
+ from typing import Any, Dict, Optional
23
+ from pathlib import Path
24
+
25
+ # FastAPI imports
26
+ try:
27
+ from fastapi import FastAPI, HTTPException, Request, Body
28
+ from fastapi.responses import FileResponse, JSONResponse
29
+ from pydantic import BaseModel, ConfigDict
30
+ import uvicorn
31
+ except ImportError as e:
32
+ print(f"[ERROR] Missing FastAPI dependency: {e}", flush=True)
33
+ print("Run: pip install fastapi uvicorn pydantic", flush=True)
34
+ sys.exit(1)
35
+
36
+ # Local env imports
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+ try:
39
+ from env.environment import CustomerSupportEnv, TASKS
40
+ from env.models import Action, ActionType, Observation, Reward
41
+ from graders.graders import grade
42
+ except ImportError as e:
43
+ print(f"[ERROR] Missing local env module: {e}", flush=True)
44
+ traceback.print_exc()
45
+ sys.exit(1)
46
+
47
+ # ── FastAPI App ──────────────────────────────────────────────────────────────
48
+ app = FastAPI(
49
+ title="CustomerSupportEnv",
50
+ description="OpenEnv-compatible customer support RL environment",
51
+ version="1.0.0"
52
+ )
53
+
54
+ # ── Session Storage (in-memory for single deployment) ───────────────────────
55
+ _sessions: Dict[str, Dict[str, Any]] = {}
56
+ _session_counter = 0
57
+
58
+
59
+ def new_session_id() -> str:
60
+ """Generate a unique session ID."""
61
+ global _session_counter
62
+ _session_counter += 1
63
+ return f"session_{_session_counter:06d}"
64
+
65
+
66
+ # ── Pydantic Models ──────────────────────────────────────────────────────────
67
+ class ResetRequest(BaseModel):
68
+ model_config = ConfigDict(extra="allow")
69
+ task_id: Optional[str] = None
70
+ seed: Optional[int] = None
71
+
72
+
73
+ class StepRequest(BaseModel):
74
+ session_id: str
75
+ action_type: str
76
+ payload: Optional[str] = None
77
+
78
+
79
+ class GradeRequest(BaseModel):
80
+ session_id: str
81
+
82
+
83
+ # ── Helper: Make JSON serializable ──────────────────────────────────────────
84
+ def to_json_serializable(obj: Any) -> Any:
85
+ """Convert any object to JSON-serializable format."""
86
+ if obj is None:
87
+ return None
88
+ elif isinstance(obj, (str, int, float, bool)):
89
+ return obj
90
+ elif isinstance(obj, dict):
91
+ return {k: to_json_serializable(v) for k, v in obj.items()}
92
+ elif isinstance(obj, (list, tuple)):
93
+ return [to_json_serializable(item) for item in obj]
94
+ elif hasattr(obj, 'dict') and callable(obj.dict):
95
+ # Pydantic model
96
+ return to_json_serializable(obj.dict())
97
+ elif hasattr(obj, '__dict__'):
98
+ # Regular object with attributes
99
+ return to_json_serializable(obj.__dict__)
100
+ else:
101
+ # Fallback to string representation
102
+ return str(obj)
103
+
104
+
105
+ def serialize_obs(obs: Observation) -> Dict[str, Any]:
106
+ """Convert Observation dataclass to JSON-serializable dict."""
107
+ # Convert all fields to JSON-serializable format
108
+ return {
109
+ "ticket_id": to_json_serializable(obs.ticket_id),
110
+ "task_id": to_json_serializable(obs.task_id),
111
+ "status": to_json_serializable(obs.status),
112
+ "sentiment": to_json_serializable(obs.sentiment),
113
+ "priority": to_json_serializable(obs.priority),
114
+ "category": to_json_serializable(obs.category),
115
+ "turn": to_json_serializable(obs.turn),
116
+ "max_turns": to_json_serializable(obs.max_turns),
117
+ "history": to_json_serializable(obs.history),
118
+ "kb_results": to_json_serializable(obs.kb_results),
119
+ "kb_searched": to_json_serializable(obs.kb_searched),
120
+ "empathized": to_json_serializable(obs.empathized),
121
+ "clarified": to_json_serializable(obs.clarified),
122
+ "solution_offered": to_json_serializable(obs.solution_offered),
123
+ "escalated": to_json_serializable(obs.escalated),
124
+ "cumulative_reward": to_json_serializable(obs.cumulative_reward),
125
+ "done": to_json_serializable(obs.done),
126
+ }
127
+
128
+
129
+ def serialize_reward(reward: Reward) -> Dict[str, Any]:
130
+ """Convert Reward dataclass to JSON-serializable dict."""
131
+ return {
132
+ "total": to_json_serializable(reward.total),
133
+ "breakdown": to_json_serializable(reward.breakdown),
134
+ "reason": to_json_serializable(reward.reason),
135
+ }
136
+
137
+
138
+ # ── OpenEnv Endpoints ────────────────────────────────────────────────────────
139
+
140
+ @app.post("/reset")
141
+ async def reset(request: Optional[Dict[str, Any]] = Body(default=None)) -> JSONResponse:
142
+ """
143
+ Reset environment and start a new episode.
144
+
145
+ Accepts both empty POST and JSON body with optional parameters.
146
+
147
+ Args:
148
+ task_id: One of task_1, task_2, task_3 (optional, defaults to task_1)
149
+ seed: Optional random seed (defaults to 42)
150
+
151
+ Returns:
152
+ {
153
+ "session_id": str,
154
+ "observation": {...},
155
+ "info": {...}
156
+ }
157
+ """
158
+ try:
159
+ # Default values
160
+ task_id = "task_1"
161
+ seed = 42
162
+
163
+ # Override with request values if provided
164
+ if request is not None and isinstance(request, dict):
165
+ if "task_id" in request and request["task_id"]:
166
+ task_id = request["task_id"]
167
+ if "seed" in request and request["seed"] is not None:
168
+ seed = request["seed"]
169
+
170
+ print(f"[RESET] task_id={task_id}, seed={seed}", flush=True)
171
+
172
+ # Validate task_id
173
+ if task_id not in TASKS:
174
+ raise ValueError(f"Invalid task_id '{task_id}'. Must be one of: {list(TASKS.keys())}")
175
+
176
+ # Create and reset environment
177
+ env = CustomerSupportEnv(task_id=task_id, seed=seed)
178
+ obs = env.reset()
179
+
180
+ # Store session
181
+ session_id = new_session_id()
182
+ _sessions[session_id] = {
183
+ "env": env,
184
+ "task_id": task_id,
185
+ "observation": obs,
186
+ "steps": 0,
187
+ "done": False,
188
+ }
189
+
190
+ print(f"[RESET] Created session {session_id}", flush=True)
191
+
192
+ # Serialize observation to ensure JSON compatibility
193
+ obs_json = serialize_obs(obs)
194
+
195
+ return JSONResponse(
196
+ status_code=200,
197
+ content={
198
+ "session_id": session_id,
199
+ "observation": obs_json,
200
+ "info": {
201
+ "task_id": task_id,
202
+ "difficulty": TASKS[task_id].difficulty,
203
+ "description": TASKS[task_id].description,
204
+ }
205
+ }
206
+ )
207
+
208
+ except ValueError as e:
209
+ print(f"[RESET ERROR] Validation error: {e}", flush=True)
210
+ raise HTTPException(status_code=400, detail=str(e))
211
+ except Exception as e:
212
+ print(f"[RESET ERROR] {type(e).__name__}: {e}", flush=True)
213
+ traceback.print_exc()
214
+ raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}")
215
+
216
+
217
+ @app.post("/step")
218
+ async def step(request: StepRequest) -> JSONResponse:
219
+ """
220
+ Apply an action and step the environment.
221
+
222
+ Args:
223
+ session_id: Session ID from /reset
224
+ action_type: One of [search_kb, empathize, ask_clarify, offer_solution, escalate, resolve, send_message]
225
+ payload: Optional action payload (required for some action types)
226
+
227
+ Returns:
228
+ {
229
+ "observation": {...},
230
+ "reward": {...},
231
+ "done": bool,
232
+ "info": {...}
233
+ }
234
+ """
235
+ try:
236
+ session_id = request.session_id
237
+ action_type = request.action_type
238
+ payload = request.payload
239
+
240
+ if session_id not in _sessions:
241
+ raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
242
+
243
+ session = _sessions[session_id]
244
+ env = session["env"]
245
+
246
+ if session["done"]:
247
+ raise HTTPException(status_code=400, detail="Episode already done. Call /reset to start new episode.")
248
+
249
+ # Create action
250
+ action = Action(action_type=action_type, payload=payload)
251
+
252
+ # Step environment
253
+ result = env.step(action)
254
+
255
+ # Update session
256
+ session["observation"] = result.observation
257
+ session["steps"] += 1
258
+ session["done"] = result.observation.done
259
+
260
+ # Serialize for JSON compatibility
261
+ obs_json = serialize_obs(result.observation)
262
+ reward_json = serialize_reward(result.reward)
263
+
264
+ return JSONResponse(
265
+ status_code=200,
266
+ content={
267
+ "observation": obs_json,
268
+ "reward": reward_json,
269
+ "done": result.observation.done,
270
+ "info": {
271
+ "step": session["steps"],
272
+ "action": action_type,
273
+ }
274
+ }
275
+ )
276
+
277
+ except HTTPException:
278
+ raise
279
+ except Exception as e:
280
+ traceback.print_exc()
281
+ raise HTTPException(status_code=500, detail=f"Step failed: {str(e)}")
282
+
283
+
284
+ @app.get("/state")
285
+ async def state_endpoint(session_id: str) -> JSONResponse:
286
+ """
287
+ Get current environment state without stepping.
288
+
289
+ Args:
290
+ session_id: Session ID from /reset
291
+
292
+ Returns:
293
+ {
294
+ "observation": {...},
295
+ "info": {...}
296
+ }
297
+ """
298
+ try:
299
+ if session_id not in _sessions:
300
+ raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
301
+
302
+ session = _sessions[session_id]
303
+ obs = session["observation"]
304
+
305
+ obs_json = serialize_obs(obs)
306
+
307
+ return JSONResponse(
308
+ status_code=200,
309
+ content={
310
+ "observation": obs_json,
311
+ "info": {
312
+ "task_id": session["task_id"],
313
+ "steps": session["steps"],
314
+ "done": session["done"],
315
+ }
316
+ }
317
+ )
318
+
319
+ except HTTPException:
320
+ raise
321
+ except Exception as e:
322
+ traceback.print_exc()
323
+ raise HTTPException(status_code=500, detail=f"State query failed: {str(e)}")
324
+
325
+
326
+ @app.get("/tasks")
327
+ async def tasks_endpoint() -> JSONResponse:
328
+ """
329
+ List all available tasks.
330
+
331
+ Returns:
332
+ {
333
+ "tasks": [
334
+ {
335
+ "id": "task_1",
336
+ "name": "...",
337
+ "difficulty": "easy|medium|hard",
338
+ "description": "...",
339
+ "max_turns": int
340
+ },
341
+ ...
342
+ ]
343
+ }
344
+ """
345
+ try:
346
+ task_list = []
347
+ for task_id, task_obj in TASKS.items():
348
+ task_list.append({
349
+ "id": task_id,
350
+ "name": task_obj.name,
351
+ "difficulty": task_obj.difficulty,
352
+ "description": task_obj.description,
353
+ "max_turns": task_obj.max_turns,
354
+ })
355
+
356
+ return JSONResponse(
357
+ status_code=200,
358
+ content={"tasks": task_list}
359
+ )
360
+
361
+ except Exception as e:
362
+ traceback.print_exc()
363
+ raise HTTPException(status_code=500, detail=f"Tasks query failed: {str(e)}")
364
+
365
+
366
+ @app.post("/grade")
367
+ async def grade_endpoint(request: GradeRequest) -> JSONResponse:
368
+ """
369
+ Grade the current episode.
370
+
371
+ Args:
372
+ session_id: Session ID from /reset
373
+
374
+ Returns:
375
+ {
376
+ "score": float (0.0 to 1.0),
377
+ "passed": bool,
378
+ "breakdown": {...},
379
+ "reason": str
380
+ }
381
+ """
382
+ try:
383
+ session_id = request.session_id
384
+
385
+ if session_id not in _sessions:
386
+ raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
387
+
388
+ session = _sessions[session_id]
389
+ env = session["env"]
390
+ task_id = session["task_id"]
391
+
392
+ # Get final state
393
+ final_obs = env.state()
394
+
395
+ # Grade
396
+ grader_result = grade(task_id, final_obs)
397
+
398
+ return JSONResponse(
399
+ status_code=200,
400
+ content={
401
+ "score": grader_result.score,
402
+ "passed": grader_result.passed,
403
+ "breakdown": to_json_serializable(grader_result.breakdown),
404
+ "reason": grader_result.reason,
405
+ }
406
+ )
407
+
408
+ except HTTPException:
409
+ raise
410
+ except Exception as e:
411
+ traceback.print_exc()
412
+ raise HTTPException(status_code=500, detail=f"Grading failed: {str(e)}")
413
+
414
+
415
+ @app.get("/health")
416
+ async def health() -> JSONResponse:
417
+ """Health check endpoint."""
418
+ return JSONResponse(
419
+ status_code=200,
420
+ content={
421
+ "status": "healthy",
422
+ "service": "CustomerSupportEnv",
423
+ "version": "1.0.0",
424
+ "sessions_active": len(_sessions),
425
+ }
426
+ )
427
+
428
+
429
+ @app.get("/openenv.yaml")
430
+ async def openenv_spec() -> FileResponse:
431
+ """Serve OpenEnv specification."""
432
+ spec_path = Path(__file__).parent / "openenv.yaml"
433
+ if not spec_path.exists():
434
+ raise HTTPException(status_code=404, detail="openenv.yaml not found")
435
+ return FileResponse(spec_path, media_type="text/yaml")
436
+
437
+
438
+ # ── Root endpoint ────────────────────────────────────────────────────────────
439
+ @app.get("/")
440
+ async def root() -> JSONResponse:
441
+ """Root endpoint."""
442
+ return JSONResponse(
443
+ status_code=200,
444
+ content={
445
+ "service": "CustomerSupportEnv OpenEnv Server",
446
+ "version": "1.0.0",
447
+ "endpoints": {
448
+ "POST /reset": "Initialize new episode",
449
+ "POST /step": "Apply action",
450
+ "GET /state": "Get current state",
451
+ "GET /tasks": "List tasks",
452
+ "POST /grade": "Grade episode",
453
+ "GET /health": "Health check",
454
+ "GET /openenv.yaml": "Specification",
455
+ }
456
+ }
457
+ )
458
+
459
+
460
+ # ── Startup/Shutdown ─────────────────────────────────────────────────────────
461
+ @app.on_event("startup")
462
+ async def startup_event():
463
+ """Log startup."""
464
+ print("[INFO] CustomerSupportEnv server started", flush=True)
465
+
466
+
467
+ @app.on_event("shutdown")
468
+ async def shutdown_event():
469
+ """Log shutdown."""
470
+ print("[INFO] CustomerSupportEnv server shutdown", flush=True)
471
+
472
+
473
+ # ── Main ─────────────────────────────────────────────────────────────────────
474
+ if __name__ == "__main__":
475
+ port = int(os.environ.get("PORT", 7860))
476
+ host = os.environ.get("HOST", "0.0.0.0")
477
+
478
+ print(f"[INFO] Starting server on {host}:{port}", flush=True)
479
+ uvicorn.run(app, host=host, port=port, log_level="info")