Vikaspandey582003 commited on
Commit
09024bf
Β·
verified Β·
1 Parent(s): 26ba066

revert: restore server/app.py to last working state

Browse files
Files changed (1) hide show
  1. server/app.py +120 -92
server/app.py CHANGED
@@ -1,82 +1,105 @@
1
  """
2
- ECHO ULTIMATE β€” OpenEnv-Compliant FastAPI Server.
3
-
4
- Built with openenv.core.create_fastapi_app so the environment is exposed through
5
- the standard OpenEnv HTTP protocol:
6
-
7
- POST /reset β†’ EchoObservation (OpenEnv standard)
8
- POST /step β†’ EchoObservation (OpenEnv standard)
9
- GET /state β†’ EchoState (OpenEnv standard)
10
- GET /health β†’ health status
11
- GET /schema/action β†’ JSON schema
12
- GET /schema/observation β†’ JSON schema
13
-
14
- Additional ECHO-specific endpoints:
15
- GET /tasks β†’ task definitions
16
- GET /metrics β†’ CalibrationReport (ECE, Brier, MCE …)
17
- GET /metrics/{domain}
18
- GET /fingerprint
19
- GET /history
20
- POST /advance_phase
21
- GET /ui β†’ Gradio demo (mounted)
22
-
23
  Runs on port 7860 (HuggingFace Space public port).
24
  """
25
 
26
  import logging
27
  import os
 
28
  import sys
29
 
30
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
31
 
32
  from contextlib import asynccontextmanager
33
- from typing import Optional
34
 
35
  from fastapi import FastAPI, HTTPException
36
  from fastapi.middleware.cors import CORSMiddleware
37
- from fastapi.responses import RedirectResponse
 
38
 
39
  from config import cfg
40
  from core.tasks import TASKS
41
- from env.openenv_env import EchoOpenEnv
42
  from env.reward import RewardHistory
43
  from env.task_bank import TaskBank
44
- from models import EchoAction, EchoObservation
45
 
46
  logger = logging.getLogger(__name__)
47
 
48
- # ── Singleton environment (stateful, shared across all HTTP requests) ─────────
49
 
50
  _task_bank: Optional[TaskBank] = None
 
51
  _history: Optional[RewardHistory] = None
52
- _env: Optional[EchoOpenEnv] = None
53
 
54
 
55
- def _get_env() -> EchoOpenEnv:
56
  if _env is None:
57
- raise RuntimeError("Environment not initialised β€” server startup incomplete.")
58
  return _env
59
 
60
 
61
- def _env_factory() -> EchoOpenEnv:
62
- """
63
- Singleton factory required by create_fastapi_app.
64
- Returns the shared instance so state persists across reset/step calls.
65
- gym.Env.close() is a no-op, so the OpenEnv server's cleanup call is harmless.
66
- """
67
- return _get_env()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # ── Create FastAPI app ────────────────────────────────────────────────────────
 
71
 
72
  app = FastAPI(
73
  title="ECHO ULTIMATE β€” Epistemic Calibration RL Environment",
74
  description=(
75
  "OpenEnv-compliant training environment for LLM metacognitive calibration. "
76
- "7 domains Β· 3 curriculum phases Β· 5 calibration metrics Β· Epistemic fingerprint. "
77
- "Trains models to know what they don't know via GRPO + Brier-score rewards."
78
  ),
79
  version="2.0.0",
 
80
  )
81
 
82
  app.add_middleware(
@@ -88,95 +111,100 @@ app.add_middleware(
88
  )
89
 
90
 
91
- # ── Startup: initialise singleton env ─────────────────────────────────────────
92
 
93
- @app.on_event("startup")
94
- async def _startup():
95
- global _task_bank, _history, _env
96
- logger.info("ECHO ULTIMATE server starting…")
97
- _task_bank = TaskBank()
98
- _task_bank.ensure_loaded()
99
- _history = RewardHistory()
100
- _env = EchoOpenEnv(task_bank=_task_bank, reward_history=_history, phase=3)
101
- _env._gym_reset()
102
- logger.info("ECHO ULTIMATE ready βœ… (7 domains, 3 tasks)")
103
- print("βœ… ECHO ULTIMATE server ready β€” http://0.0.0.0:7860/docs")
104
 
105
 
106
- # ── OpenEnv standard endpoints ────────────────────────────────────────────────
 
 
 
 
107
 
108
- @app.get("/health", tags=["Health"])
109
- async def health():
110
- return {"status": "ok", "environment": "ECHO-ULTIMATE", "version": "2.0.0", "domains": 7, "tasks": 3}
 
 
 
111
 
112
 
113
  @app.post("/reset", tags=["Environment"])
114
- async def reset(body: dict = {}):
115
  env = _get_env()
116
- task_id = body.get("task_id") if body else None
117
- obs_dict, info = env._gym_reset(options={"task_id": task_id} if task_id else None)
118
- task = env._current_task or {}
119
- return {**obs_dict, "question": task.get("question", obs_dict.get("question", "")), "info": info}
 
 
 
120
 
121
 
122
- @app.post("/step", tags=["Environment"])
123
- async def step(body: dict):
124
  env = _get_env()
125
- action = body.get("action") or body.get("response", "")
126
- obs_dict, reward, terminated, truncated, info = env._gym_step(action)
127
- return {"reward": round(reward, 4), "terminated": terminated or truncated, "info": info, **obs_dict}
128
 
129
 
130
- @app.get("/state", tags=["Environment"])
131
- async def get_state():
132
  env = _get_env()
133
- task = env._current_task or {}
134
- return {
135
- "current_question": task.get("question", ""),
136
- "domain": task.get("domain", ""),
137
- "difficulty": task.get("difficulty", ""),
138
- "phase": env.phase,
139
- }
140
-
141
-
142
- # ── ECHO-specific extra endpoints ─────────────────────────────────────────────
143
-
144
- @app.get("/", tags=["Health"])
145
- async def root():
146
- return RedirectResponse(url="/ui")
 
147
 
148
 
149
- @app.get("/tasks", tags=["Tasks"])
150
- async def list_tasks():
151
- return _get_env().list_tasks()
152
 
153
 
154
  @app.get("/metrics", tags=["Metrics"])
155
  async def get_metrics():
156
- return _get_env().get_metrics().to_dict()
 
157
 
158
 
159
  @app.get("/metrics/{domain}", tags=["Metrics"])
160
  async def get_domain_metrics(domain: str):
161
  if domain not in cfg.DOMAINS:
162
  raise HTTPException(404, f"Unknown domain '{domain}'. Valid: {cfg.DOMAINS}")
163
- return _get_env().get_metrics(domain=domain).to_dict()
 
164
 
165
 
166
  @app.get("/fingerprint", tags=["Metrics"])
167
- async def get_fingerprint():
168
  env = _get_env()
169
  profiles = env.reward_history.get_domain_profiles()
170
  return {
171
- "domain_scores": {d: round(1.0 - r.ece, 3) for d, r in profiles.items()},
172
- "domain_ece": {d: round(r.ece, 3) for d, r in profiles.items()},
173
  "domain_accuracy": {d: round(r.accuracy, 3) for d, r in profiles.items()},
174
- "overall_ece": round(env.get_metrics().ece, 3),
175
  }
176
 
177
 
178
  @app.get("/history", tags=["Metrics"])
179
- async def get_history():
180
  env = _get_env()
181
  df = env.reward_history.to_dataframe()
182
  records = df.tail(100).to_dict(orient="records") if len(df) > 0 else []
@@ -217,7 +245,7 @@ except Exception as _e:
217
  print(f"⚠️ Gradio UI not mounted: {_e}")
218
 
219
 
220
- # ── Direct runner ─────────────────────────────────────────────────────────────
221
 
222
  if __name__ == "__main__":
223
  import uvicorn
 
1
  """
2
+ ECHO ULTIMATE β€” FastAPI OpenEnv-Compliant Server.
3
+ Pure FastAPI: no openenv package dependency.
4
+ Mounts Gradio UI at /ui.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Runs on port 7860 (HuggingFace Space public port).
6
  """
7
 
8
  import logging
9
  import os
10
+ import random
11
  import sys
12
 
13
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
 
15
  from contextlib import asynccontextmanager
16
+ from typing import Any, Optional
17
 
18
  from fastapi import FastAPI, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import JSONResponse
21
+ from pydantic import BaseModel, Field
22
 
23
  from config import cfg
24
  from core.tasks import TASKS
25
+ from env.echo_env import EchoEnv
26
  from env.reward import RewardHistory
27
  from env.task_bank import TaskBank
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
+ # ── App state ─────────────────────────────────────────────────────────────────
32
 
33
  _task_bank: Optional[TaskBank] = None
34
+ _env: Optional[EchoEnv] = None
35
  _history: Optional[RewardHistory] = None
 
36
 
37
 
38
+ def _get_env() -> EchoEnv:
39
  if _env is None:
40
+ raise HTTPException(400, "No active episode. POST /reset first.")
41
  return _env
42
 
43
 
44
+ # ── Pydantic schemas ──────────────────────────────────────────────────────────
45
+
46
+ class ResetRequest(BaseModel):
47
+ task_id: Optional[str] = Field(None, description="Specific task ID to load")
48
+ adversarial: Optional[bool] = Field(False, description="Use adversarial questions")
49
+
50
+
51
+ class StepRequest(BaseModel):
52
+ action: Optional[str] = Field(None, description="Legacy: action string")
53
+ response: Optional[str] = Field(None, description="Agent response with confidence and answer tags")
54
+
55
+ def get_response(self) -> str:
56
+ """Accept either 'response' or 'action' field."""
57
+ return self.response or self.action or ""
58
+
59
+
60
+ class TaskInfo(BaseModel):
61
+ id: str
62
+ name: str
63
+ description: str
64
+ pass_threshold: float
65
+ n_episodes: int
66
+
67
+
68
+ class StepResponse(BaseModel):
69
+ state: dict
70
+ reward: float
71
+ terminated: bool
72
+ truncated: bool
73
+ info: dict
74
+
75
+
76
+ # ── Lifespan ──────────────────────────────────────────────────────────────────
77
 
78
+ @asynccontextmanager
79
+ async def lifespan(app: FastAPI):
80
+ global _task_bank, _env, _history
81
+ logger.info("ECHO ULTIMATE server starting…")
82
+ _task_bank = TaskBank()
83
+ _task_bank.ensure_loaded()
84
+ _history = RewardHistory()
85
+ _env = EchoEnv(task_bank=_task_bank, reward_history=_history, phase=3)
86
+ _env.reset()
87
+ logger.info("ECHO ULTIMATE ready βœ… (7 domains, 3 tasks)")
88
+ print("βœ… ECHO ULTIMATE server ready β€” http://0.0.0.0:7860/docs")
89
+ yield
90
+ logger.info("ECHO ULTIMATE server shutting down.")
91
 
92
+
93
+ # ── App ───────────────────────────────────────────────────────────────────────
94
 
95
  app = FastAPI(
96
  title="ECHO ULTIMATE β€” Epistemic Calibration RL Environment",
97
  description=(
98
  "OpenEnv-compliant training environment for LLM metacognitive calibration. "
99
+ "7 domains Β· 3 curriculum phases Β· 5 calibration metrics Β· Epistemic fingerprint."
 
100
  ),
101
  version="2.0.0",
102
+ lifespan=lifespan,
103
  )
104
 
105
  app.add_middleware(
 
111
  )
112
 
113
 
114
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
115
 
116
+ @app.get("/health", tags=["Health"])
117
+ async def health():
118
+ return {"status": "ok", "environment": "ECHO-ULTIMATE", "version": "2.0.0",
119
+ "domains": 7, "tasks": 3}
 
 
 
 
 
 
 
120
 
121
 
122
+ @app.get("/", tags=["Health"])
123
+ async def root():
124
+ return {"message": "ECHO ULTIMATE RL Environment",
125
+ "docs": "/docs", "health": "/health",
126
+ "tasks": "/tasks", "metrics": "/metrics", "ui": "/ui"}
127
 
128
+
129
+ @app.get("/tasks", response_model=list[TaskInfo], tags=["Tasks"])
130
+ async def list_tasks():
131
+ return [TaskInfo(id=t.id, name=t.name, description=t.description,
132
+ pass_threshold=t.pass_threshold, n_episodes=t.n_episodes)
133
+ for t in TASKS]
134
 
135
 
136
  @app.post("/reset", tags=["Environment"])
137
+ async def reset(req: ResetRequest = ResetRequest()) -> dict:
138
  env = _get_env()
139
+ opts = {}
140
+ if req.task_id:
141
+ opts["task_id"] = req.task_id
142
+ if req.adversarial:
143
+ opts["adversarial"] = True
144
+ state, info = env.reset(options=opts if opts else None)
145
+ return state
146
 
147
 
148
+ @app.post("/reset/{task_id}", tags=["Environment"])
149
+ async def reset_task(task_id: str) -> dict:
150
  env = _get_env()
151
+ state, _ = env.reset(options={"task_id": task_id})
152
+ return state
 
153
 
154
 
155
+ @app.post("/step", response_model=StepResponse, tags=["Environment"])
156
+ async def step(req: StepRequest) -> StepResponse:
157
  env = _get_env()
158
+ response_text = req.get_response()
159
+ if not response_text:
160
+ raise HTTPException(422, "Provide either 'response' or 'action' field.")
161
+ try:
162
+ state, reward, terminated, truncated, info = env.step(response_text)
163
+ except Exception as exc:
164
+ logger.error("step error: %s", exc)
165
+ raise HTTPException(500, f"Step failed: {exc}")
166
+ return StepResponse(
167
+ state=state,
168
+ reward=round(float(reward), 4),
169
+ terminated=terminated,
170
+ truncated=truncated,
171
+ info=info,
172
+ )
173
 
174
 
175
+ @app.get("/state", tags=["Environment"])
176
+ async def get_state() -> dict:
177
+ return _get_env()._build_obs()
178
 
179
 
180
  @app.get("/metrics", tags=["Metrics"])
181
  async def get_metrics():
182
+ rep = _get_env().get_metrics()
183
+ return rep.to_dict()
184
 
185
 
186
  @app.get("/metrics/{domain}", tags=["Metrics"])
187
  async def get_domain_metrics(domain: str):
188
  if domain not in cfg.DOMAINS:
189
  raise HTTPException(404, f"Unknown domain '{domain}'. Valid: {cfg.DOMAINS}")
190
+ rep = _get_env().get_metrics(domain=domain)
191
+ return rep.to_dict()
192
 
193
 
194
  @app.get("/fingerprint", tags=["Metrics"])
195
+ async def get_fingerprint() -> dict:
196
  env = _get_env()
197
  profiles = env.reward_history.get_domain_profiles()
198
  return {
199
+ "domain_scores": {d: round(1.0 - r.ece, 3) for d, r in profiles.items()},
200
+ "domain_ece": {d: round(r.ece, 3) for d, r in profiles.items()},
201
  "domain_accuracy": {d: round(r.accuracy, 3) for d, r in profiles.items()},
202
+ "overall_ece": round(env.get_metrics().ece, 3),
203
  }
204
 
205
 
206
  @app.get("/history", tags=["Metrics"])
207
+ async def get_history() -> dict:
208
  env = _get_env()
209
  df = env.reward_history.to_dataframe()
210
  records = df.tail(100).to_dict(orient="records") if len(df) > 0 else []
 
245
  print(f"⚠️ Gradio UI not mounted: {_e}")
246
 
247
 
248
+ # ── Direct runner ──────────────────────────────────────────────────────────────
249
 
250
  if __name__ == "__main__":
251
  import uvicorn