hannan2859r committed on
Commit
f9f5e0d
·
verified ·
1 Parent(s): a5ae22e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -26
app.py CHANGED
@@ -10,12 +10,13 @@ Endpoints:
10
  GET /tasks → list of all tasks
11
  GET /metrics → episode-level training metrics (for reward curve UI)
12
  POST /reset_metrics → clear metrics history
 
13
  """
14
 
15
  from fastapi import FastAPI, HTTPException
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from models import FocusAction, FocusObservation, FocusState
18
- from environment import FocusFlowEnvironment, TASKS
19
  from typing import Optional, List, Dict
20
  from pydantic import BaseModel
21
  import uvicorn
@@ -38,20 +39,29 @@ app.add_middleware(
38
  )
39
 
40
  # ── Global state ──────────────────────────────────────────────────────────────
41
-
42
- # Map session IDs to their specific environment instances
43
  sessions: Dict[str, FocusFlowEnvironment] = {}
44
-
45
- # Track metrics and episode counts per session
46
  session_metrics: Dict[str, List[dict]] = {}
47
  session_episodes: Dict[str, int] = {}
48
 
49
 
50
- # ── Response model ────────────────────────────────────────────────────────────
51
  class StepResponse(FocusObservation):
52
- reward: float
53
- done: bool
54
- info: dict
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  # ── Endpoints ─────────────────────────────────────────────────────────────────
@@ -95,14 +105,14 @@ def reset(task_id: str = "task_1", seed: int = 42, session_id: str = "default"):
95
  status_code=400,
96
  detail=f"Unknown task_id '{task_id}'. Valid: {valid_ids}"
97
  )
98
-
99
  if session_id not in session_episodes:
100
  session_episodes[session_id] = 0
101
- session_metrics[session_id] = []
102
-
103
  sessions[session_id] = FocusFlowEnvironment(task_id=task_id, seed=seed)
104
  session_episodes[session_id] += 1
105
-
106
  return sessions[session_id].reset()
107
 
108
 
@@ -110,7 +120,6 @@ def reset(task_id: str = "task_1", seed: int = 42, session_id: str = "default"):
110
  def step(action: FocusAction, session_id: str = "default"):
111
  """
112
  Submit one action. Returns next observation + reward + done flag.
113
-
114
  The `reasoning` field in FocusAction is REQUIRED and graded.
115
  Empty or low-quality reasoning incurs a reward penalty.
116
  """
@@ -120,17 +129,16 @@ def step(action: FocusAction, session_id: str = "default"):
120
  status_code=400,
121
  detail=f"Session '{session_id}' not initialised. Call POST /reset first."
122
  )
123
-
124
  obs, reward, done, info = env.step(action)
125
 
126
- # Log for metrics endpoint
127
  session_metrics[session_id].append({
128
- "episode": session_episodes[session_id],
129
- "step": info["step"],
130
- "reward": reward,
131
- "cumulative": info["cumulative"],
132
  "reasoning_q": obs.reasoning_quality_score,
133
- "success": info.get("success", False),
134
  })
135
 
136
  return StepResponse(
@@ -160,11 +168,10 @@ def metrics(session_id: str = "default"):
160
  Use this in your Colab notebook to visualise training progress.
161
  """
162
  metrics_log = session_metrics.get(session_id, [])
163
-
164
  if not metrics_log:
165
  return {"message": "No data yet. Run some episodes first.", "data": []}
166
 
167
- # Aggregate by episode
168
  from collections import defaultdict
169
  ep_rewards = defaultdict(float)
170
  ep_steps = defaultdict(int)
@@ -189,17 +196,48 @@ def metrics(session_id: str = "default"):
189
  "total_steps": len(metrics_log),
190
  "total_episodes": len(episodes_summary),
191
  "episodes": episodes_summary,
192
- "raw_steps": metrics_log[-100:], # last 100 steps
193
  }
194
 
195
 
196
  @app.post("/reset_metrics")
197
  def reset_metrics(session_id: str = "default"):
198
  """Clear the metrics log. Call this between training runs."""
199
- session_metrics[session_id] = []
200
  session_episodes[session_id] = 0
201
  return {"message": f"Metrics cleared for session '{session_id}'."}
202
 
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  if __name__ == "__main__":
205
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
10
  GET /tasks → list of all tasks
11
  GET /metrics → episode-level training metrics (for reward curve UI)
12
  POST /reset_metrics → clear metrics history
13
+ POST /grader → direct reasoning quality grader (offline evaluation)
14
  """
15
 
16
  from fastapi import FastAPI, HTTPException
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from models import FocusAction, FocusObservation, FocusState
19
+ from environment import FocusFlowEnvironment, TASKS, grade_reasoning
20
  from typing import Optional, List, Dict
21
  from pydantic import BaseModel
22
  import uvicorn
 
39
  )
40
 
41
  # ── Global state ──────────────────────────────────────────────────────────────
 
 
42
  sessions: Dict[str, FocusFlowEnvironment] = {}
 
 
43
  session_metrics: Dict[str, List[dict]] = {}
44
  session_episodes: Dict[str, int] = {}
45
 
46
 
47
+ # ── Response models ───────────────────────────────────────────────────────────
48
  class StepResponse(FocusObservation):
49
+ reward: float
50
+ done: bool
51
+ info: dict
52
+
53
+
54
+ class GraderRequest(BaseModel):
55
+ reasoning: str
56
+ action_type: str
57
+
58
+
59
+ class GraderResponse(BaseModel):
60
+ reasoning: str
61
+ action_type: str
62
+ reasoning_quality_score: float
63
+ verdict: str
64
+ explanation: str
65
 
66
 
67
  # ── Endpoints ─────────────────────────────────────────────────────────────────
 
105
  status_code=400,
106
  detail=f"Unknown task_id '{task_id}'. Valid: {valid_ids}"
107
  )
108
+
109
  if session_id not in session_episodes:
110
  session_episodes[session_id] = 0
111
+ session_metrics[session_id] = []
112
+
113
  sessions[session_id] = FocusFlowEnvironment(task_id=task_id, seed=seed)
114
  session_episodes[session_id] += 1
115
+
116
  return sessions[session_id].reset()
117
 
118
 
 
120
  def step(action: FocusAction, session_id: str = "default"):
121
  """
122
  Submit one action. Returns next observation + reward + done flag.
 
123
  The `reasoning` field in FocusAction is REQUIRED and graded.
124
  Empty or low-quality reasoning incurs a reward penalty.
125
  """
 
129
  status_code=400,
130
  detail=f"Session '{session_id}' not initialised. Call POST /reset first."
131
  )
132
+
133
  obs, reward, done, info = env.step(action)
134
 
 
135
  session_metrics[session_id].append({
136
+ "episode": session_episodes[session_id],
137
+ "step": info["step"],
138
+ "reward": reward,
139
+ "cumulative": info["cumulative"],
140
  "reasoning_q": obs.reasoning_quality_score,
141
+ "success": info.get("success", False),
142
  })
143
 
144
  return StepResponse(
 
168
  Use this in your Colab notebook to visualise training progress.
169
  """
170
  metrics_log = session_metrics.get(session_id, [])
171
+
172
  if not metrics_log:
173
  return {"message": "No data yet. Run some episodes first.", "data": []}
174
 
 
175
  from collections import defaultdict
176
  ep_rewards = defaultdict(float)
177
  ep_steps = defaultdict(int)
 
196
  "total_steps": len(metrics_log),
197
  "total_episodes": len(episodes_summary),
198
  "episodes": episodes_summary,
199
+ "raw_steps": metrics_log[-100:],
200
  }
201
 
202
 
203
  @app.post("/reset_metrics")
204
  def reset_metrics(session_id: str = "default"):
205
  """Clear the metrics log. Call this between training runs."""
206
+ session_metrics[session_id] = []
207
  session_episodes[session_id] = 0
208
  return {"message": f"Metrics cleared for session '{session_id}'."}
209
 
210
 
211
+ @app.post("/grader", response_model=GraderResponse)
212
+ def grader(request: GraderRequest):
213
+ """
214
+ Direct grader invocation for offline evaluation.
215
+ Use this to test reasoning quality without running a full episode.
216
+ Judges can use this to verify the grading pipeline works correctly.
217
+ """
218
+ score = grade_reasoning(request.reasoning, request.action_type, None)
219
+
220
+ if score >= 0.7:
221
+ verdict = "excellent"
222
+ explanation = "Reasoning is clear, relevant, and uses proper justification."
223
+ elif score >= 0.5:
224
+ verdict = "good"
225
+ explanation = "Reasoning is adequate but could mention more context signals."
226
+ elif score >= 0.3:
227
+ verdict = "weak"
228
+ explanation = "Reasoning is too short or lacks relevant keywords."
229
+ else:
230
+ verdict = "poor"
231
+ explanation = "Reasoning is empty, spammy, or below minimum quality threshold."
232
+
233
+ return GraderResponse(
234
+ reasoning = request.reasoning,
235
+ action_type = request.action_type,
236
+ reasoning_quality_score = score,
237
+ verdict = verdict,
238
+ explanation = explanation,
239
+ )
240
+
241
+
242
  if __name__ == "__main__":
243
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)