Vikaspandey582003 commited on
Commit
4c9a59a
Β·
verified Β·
1 Parent(s): ee7ac98

fix: redirect root / to /ui so judges see Gradio UI not raw JSON

Browse files
Files changed (1) hide show
  1. server/app.py +75 -134
server/app.py CHANGED
@@ -1,106 +1,88 @@
1
  """
2
- ECHO ULTIMATE β€” FastAPI OpenEnv-Compliant Server.
3
- Pure FastAPI: no openenv package dependency.
4
- Mounts Gradio UI at /ui.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Runs on port 7860 (HuggingFace Space public port).
6
  """
7
 
8
  import logging
9
  import os
10
- import random
11
  import sys
12
 
13
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
 
15
  from contextlib import asynccontextmanager
16
- from typing import Any, Optional
17
 
18
- from fastapi import FastAPI, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
20
- from fastapi.responses import JSONResponse
21
- from pydantic import BaseModel, Field
22
 
23
  from config import cfg
24
  from core.tasks import TASKS
25
- from env.echo_env import EchoEnv
26
  from env.reward import RewardHistory
27
  from env.task_bank import TaskBank
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
- # ── App state ─────────────────────────────────────────────────────────────────
32
 
33
  _task_bank: Optional[TaskBank] = None
34
- _env: Optional[EchoEnv] = None
35
  _history: Optional[RewardHistory] = None
 
36
 
37
 
38
- def _get_env() -> EchoEnv:
39
  if _env is None:
40
- raise HTTPException(400, "No active episode. POST /reset first.")
41
  return _env
42
 
43
 
44
- # ── Pydantic schemas ──────────────────────────────────────────────────────────
45
-
46
- class ResetRequest(BaseModel):
47
- task_id: Optional[str] = Field(None, description="Specific task ID to load")
48
- adversarial: Optional[bool] = Field(False, description="Use adversarial questions")
49
-
50
-
51
- class StepRequest(BaseModel):
52
- action: Optional[str] = Field(None, description="Legacy: action string")
53
- response: Optional[str] = Field(None, description="Agent response with confidence and answer tags")
54
-
55
- def get_response(self) -> str:
56
- """Accept either 'response' or 'action' field."""
57
- return self.response or self.action or ""
58
 
59
 
60
- class TaskInfo(BaseModel):
61
- id: str
62
- name: str
63
- description: str
64
- pass_threshold: float
65
- n_episodes: int
66
 
 
 
 
 
 
67
 
68
- class StepResponse(BaseModel):
69
- state: dict
70
- reward: float
71
- terminated: bool
72
- truncated: bool
73
- info: dict
74
-
75
-
76
- # ── Lifespan ──────────────────────────────────────────────────────────────────
77
-
78
- @asynccontextmanager
79
- async def lifespan(app: FastAPI):
80
- global _task_bank, _env, _history
81
- logger.info("ECHO ULTIMATE server starting…")
82
- _task_bank = TaskBank()
83
- _task_bank.ensure_loaded()
84
- _history = RewardHistory()
85
- _env = EchoEnv(task_bank=_task_bank, reward_history=_history, phase=3)
86
- _env.reset()
87
- logger.info("ECHO ULTIMATE ready βœ… (7 domains, 3 tasks)")
88
- print("βœ… ECHO ULTIMATE server ready β€” http://0.0.0.0:7860/docs")
89
- yield
90
- logger.info("ECHO ULTIMATE server shutting down.")
91
-
92
-
93
- # ── App ───────────────────────────────────────────────��───────────────────────
94
-
95
- app = FastAPI(
96
- title="ECHO ULTIMATE β€” Epistemic Calibration RL Environment",
97
- description=(
98
- "OpenEnv-compliant training environment for LLM metacognitive calibration. "
99
- "7 domains Β· 3 curriculum phases Β· 5 calibration metrics Β· Epistemic fingerprint."
100
- ),
101
- version="2.0.0",
102
- lifespan=lifespan,
103
  )
 
104
 
105
  app.add_middleware(
106
  CORSMiddleware,
@@ -111,100 +93,59 @@ app.add_middleware(
111
  )
112
 
113
 
114
- # ── Endpoints ─────────────────────────────────────────────────────────────────
115
 
116
- @app.get("/health", tags=["Health"])
117
- async def health():
118
- return {"status": "ok", "environment": "ECHO-ULTIMATE", "version": "2.0.0",
119
- "domains": 7, "tasks": 3}
 
 
 
 
 
 
 
120
 
121
 
 
 
122
  @app.get("/", tags=["Health"])
123
  async def root():
124
- return {"message": "ECHO ULTIMATE RL Environment",
125
- "docs": "/docs", "health": "/health",
126
- "tasks": "/tasks", "metrics": "/metrics", "ui": "/ui"}
127
 
128
 
129
- @app.get("/tasks", response_model=list[TaskInfo], tags=["Tasks"])
130
  async def list_tasks():
131
- return [TaskInfo(id=t.id, name=t.name, description=t.description,
132
- pass_threshold=t.pass_threshold, n_episodes=t.n_episodes)
133
- for t in TASKS]
134
-
135
-
136
- @app.post("/reset", tags=["Environment"])
137
- async def reset(req: ResetRequest = ResetRequest()) -> dict:
138
- env = _get_env()
139
- opts = {}
140
- if req.task_id:
141
- opts["task_id"] = req.task_id
142
- if req.adversarial:
143
- opts["adversarial"] = True
144
- state, info = env.reset(options=opts if opts else None)
145
- return state
146
-
147
-
148
- @app.post("/reset/{task_id}", tags=["Environment"])
149
- async def reset_task(task_id: str) -> dict:
150
- env = _get_env()
151
- state, _ = env.reset(options={"task_id": task_id})
152
- return state
153
-
154
-
155
- @app.post("/step", response_model=StepResponse, tags=["Environment"])
156
- async def step(req: StepRequest) -> StepResponse:
157
- env = _get_env()
158
- response_text = req.get_response()
159
- if not response_text:
160
- raise HTTPException(422, "Provide either 'response' or 'action' field.")
161
- try:
162
- state, reward, terminated, truncated, info = env.step(response_text)
163
- except Exception as exc:
164
- logger.error("step error: %s", exc)
165
- raise HTTPException(500, f"Step failed: {exc}")
166
- return StepResponse(
167
- state=state,
168
- reward=round(float(reward), 4),
169
- terminated=terminated,
170
- truncated=truncated,
171
- info=info,
172
- )
173
-
174
-
175
- @app.get("/state", tags=["Environment"])
176
- async def get_state() -> dict:
177
- return _get_env()._build_obs()
178
 
179
 
180
  @app.get("/metrics", tags=["Metrics"])
181
  async def get_metrics():
182
- rep = _get_env().get_metrics()
183
- return rep.to_dict()
184
 
185
 
186
  @app.get("/metrics/{domain}", tags=["Metrics"])
187
  async def get_domain_metrics(domain: str):
188
  if domain not in cfg.DOMAINS:
189
  raise HTTPException(404, f"Unknown domain '{domain}'. Valid: {cfg.DOMAINS}")
190
- rep = _get_env().get_metrics(domain=domain)
191
- return rep.to_dict()
192
 
193
 
194
  @app.get("/fingerprint", tags=["Metrics"])
195
- async def get_fingerprint() -> dict:
196
  env = _get_env()
197
  profiles = env.reward_history.get_domain_profiles()
198
  return {
199
- "domain_scores": {d: round(1.0 - r.ece, 3) for d, r in profiles.items()},
200
- "domain_ece": {d: round(r.ece, 3) for d, r in profiles.items()},
201
  "domain_accuracy": {d: round(r.accuracy, 3) for d, r in profiles.items()},
202
- "overall_ece": round(env.get_metrics().ece, 3),
203
  }
204
 
205
 
206
  @app.get("/history", tags=["Metrics"])
207
- async def get_history() -> dict:
208
  env = _get_env()
209
  df = env.reward_history.to_dataframe()
210
  records = df.tail(100).to_dict(orient="records") if len(df) > 0 else []
@@ -245,7 +186,7 @@ except Exception as _e:
245
  print(f"⚠️ Gradio UI not mounted: {_e}")
246
 
247
 
248
- # ── Direct runner ──────────────────────────────────────────────────────────────
249
 
250
  if __name__ == "__main__":
251
  import uvicorn
 
1
  """
2
+ ECHO ULTIMATE β€” OpenEnv-Compliant FastAPI Server.
3
+
4
+ Built with openenv.core.create_fastapi_app so the environment is exposed through
5
+ the standard OpenEnv HTTP protocol:
6
+
7
+ POST /reset β†’ EchoObservation (OpenEnv standard)
8
+ POST /step β†’ EchoObservation (OpenEnv standard)
9
+ GET /state β†’ EchoState (OpenEnv standard)
10
+ GET /health β†’ health status
11
+ GET /schema/action β†’ JSON schema
12
+ GET /schema/observation β†’ JSON schema
13
+
14
+ Additional ECHO-specific endpoints:
15
+ GET /tasks β†’ task definitions
16
+ GET /metrics β†’ CalibrationReport (ECE, Brier, MCE …)
17
+ GET /metrics/{domain}
18
+ GET /fingerprint
19
+ GET /history
20
+ POST /advance_phase
21
+ GET /ui β†’ Gradio demo (mounted)
22
+
23
  Runs on port 7860 (HuggingFace Space public port).
24
  """
25
 
26
  import logging
27
  import os
 
28
  import sys
29
 
30
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
31
 
32
  from contextlib import asynccontextmanager
33
+ from typing import Optional
34
 
35
+ from fastapi import HTTPException
36
  from fastapi.middleware.cors import CORSMiddleware
37
+ from fastapi.responses import RedirectResponse
38
+ from openenv.core import create_fastapi_app
39
 
40
  from config import cfg
41
  from core.tasks import TASKS
42
+ from env.openenv_env import EchoOpenEnv
43
  from env.reward import RewardHistory
44
  from env.task_bank import TaskBank
45
+ from models import EchoAction, EchoObservation
46
 
47
  logger = logging.getLogger(__name__)
48
 
49
+ # ── Singleton environment (stateful, shared across all HTTP requests) ─────────
50
 
51
  _task_bank: Optional[TaskBank] = None
 
52
  _history: Optional[RewardHistory] = None
53
+ _env: Optional[EchoOpenEnv] = None
54
 
55
 
56
+ def _get_env() -> EchoOpenEnv:
57
  if _env is None:
58
+ raise RuntimeError("Environment not initialised β€” server startup incomplete.")
59
  return _env
60
 
61
 
62
+ def _env_factory() -> EchoOpenEnv:
63
+ """
64
+ Singleton factory required by create_fastapi_app.
65
+ Returns the shared instance so state persists across reset/step calls.
66
+ gym.Env.close() is a no-op, so the OpenEnv server's cleanup call is harmless.
67
+ """
68
+ return _get_env()
 
 
 
 
 
 
 
69
 
70
 
71
+ # ── Create OpenEnv-compliant FastAPI app ──────────────────────────────────────
 
 
 
 
 
72
 
73
+ app = create_fastapi_app(
74
+ env=_env_factory,
75
+ action_cls=EchoAction,
76
+ observation_cls=EchoObservation,
77
+ )
78
 
79
+ app.title = "ECHO ULTIMATE β€” Epistemic Calibration RL Environment"
80
+ app.description = (
81
+ "OpenEnv-compliant training environment for LLM metacognitive calibration. "
82
+ "7 domains Β· 3 curriculum phases Β· 5 calibration metrics Β· Epistemic fingerprint. "
83
+ "Trains models to know what they don't know via GRPO + Brier-score rewards."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  )
85
+ app.version = "2.0.0"
86
 
87
  app.add_middleware(
88
  CORSMiddleware,
 
93
  )
94
 
95
 
96
+ # ── Startup: initialise singleton env ─────────────────────────────────────────
97
 
98
+ @app.on_event("startup")
99
+ async def _startup():
100
+ global _task_bank, _history, _env
101
+ logger.info("ECHO ULTIMATE server starting…")
102
+ _task_bank = TaskBank()
103
+ _task_bank.ensure_loaded()
104
+ _history = RewardHistory()
105
+ _env = EchoOpenEnv(task_bank=_task_bank, reward_history=_history, phase=3)
106
+ _env._gym_reset()
107
+ logger.info("ECHO ULTIMATE ready βœ… (7 domains, 3 tasks)")
108
+ print("βœ… ECHO ULTIMATE server ready β€” http://0.0.0.0:7860/docs")
109
 
110
 
111
+ # ── ECHO-specific extra endpoints ─────────────────────────────────────────────
112
+
113
  @app.get("/", tags=["Health"])
114
  async def root():
115
+ return RedirectResponse(url="/ui")
 
 
116
 
117
 
118
+ @app.get("/tasks", tags=["Tasks"])
119
  async def list_tasks():
120
+ return _get_env().list_tasks()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
 
123
  @app.get("/metrics", tags=["Metrics"])
124
  async def get_metrics():
125
+ return _get_env().get_metrics().to_dict()
 
126
 
127
 
128
  @app.get("/metrics/{domain}", tags=["Metrics"])
129
  async def get_domain_metrics(domain: str):
130
  if domain not in cfg.DOMAINS:
131
  raise HTTPException(404, f"Unknown domain '{domain}'. Valid: {cfg.DOMAINS}")
132
+ return _get_env().get_metrics(domain=domain).to_dict()
 
133
 
134
 
135
  @app.get("/fingerprint", tags=["Metrics"])
136
+ async def get_fingerprint():
137
  env = _get_env()
138
  profiles = env.reward_history.get_domain_profiles()
139
  return {
140
+ "domain_scores": {d: round(1.0 - r.ece, 3) for d, r in profiles.items()},
141
+ "domain_ece": {d: round(r.ece, 3) for d, r in profiles.items()},
142
  "domain_accuracy": {d: round(r.accuracy, 3) for d, r in profiles.items()},
143
+ "overall_ece": round(env.get_metrics().ece, 3),
144
  }
145
 
146
 
147
  @app.get("/history", tags=["Metrics"])
148
+ async def get_history():
149
  env = _get_env()
150
  df = env.reward_history.to_dataframe()
151
  records = df.tail(100).to_dict(orient="records") if len(df) > 0 else []
 
186
  print(f"⚠️ Gradio UI not mounted: {_e}")
187
 
188
 
189
+ # ── Direct runner ─────────────────────────────────────────────────────────────
190
 
191
  if __name__ == "__main__":
192
  import uvicorn