"""
FocusFlow RL Environment - app.py

FastAPI server exposing the OpenEnv HTTP API.

Endpoints:
    POST /reset         -> FocusObservation
    POST /step          -> FocusObservation + reward + done
    GET  /state         -> FocusState (full internal debug state)
    GET  /health        -> {"status": "ok"}
    GET  /tasks         -> list of all tasks
    GET  /metrics       -> episode-level training metrics (for reward curve UI)
    POST /reset_metrics -> clear metrics history
    POST /grader        -> direct reasoning quality grader (offline evaluation)
"""
from collections import defaultdict
from typing import Dict, List

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from environment import FocusFlowEnvironment, TASKS, grade_reasoning
from models import FocusAction, FocusObservation, FocusState

app = FastAPI(
    title="FocusFlow RL Environment",
    description=(
        "OpenEnv-compatible RL environment for student focus & distraction management. "
        "LLM-hard: requires natural language reasoning, multi-day planning, "
        "and urgency-aware event handling."
    ),
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Global state ──────────────────────────────────────────────────────────────
sessions: Dict[str, FocusFlowEnvironment] = {}  # session_id -> live environment
session_metrics: Dict[str, List[dict]] = {}     # session_id -> per-step metric records
session_episodes: Dict[str, int] = {}           # session_id -> episode counter

# ── Response models ───────────────────────────────────────────────────────────
class StepResponse(FocusObservation):
    reward: float
    done: bool
    info: dict


class GraderRequest(BaseModel):
    reasoning: str
    action_type: str


class GraderResponse(BaseModel):
    reasoning: str
    action_type: str
    reasoning_quality_score: float
    verdict: str
    explanation: str

# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {
        "name": "FocusFlow RL Environment",
        "version": "2.0.0",
        "author": "Abdul Hannan",
        "hackathon": "Meta x Scaler OpenEnv Hackathon 2026",
        "description": "LLM-hard RL environment for student focus and distraction management",
        "theme": "Theme 3.2 - Personalized Tasks",
        "endpoints": {
            "health": "/health",
            "docs": "/docs",
            "tasks": "/tasks",
            "reset": "POST /reset",
            "step": "POST /step",
            "grader": "POST /grader",
            "metrics": "/metrics",
        },
        "live_demo": "https://hannan2859r-focusflow-env.hf.space/docs",
        "github": "https://github.com/abdulhannan-18/Focus_Flow_env",
    }
@app.get("/health")
def health():
return {
"status": "ok",
"environment": "FocusFlow",
"version": "2.0.0",
"sessions_active": len(sessions),
}
@app.get("/tasks")
def list_tasks():
"""List all available tasks with descriptions."""
return {
"tasks": [
{
"id": t["id"],
"description": t["description"],
"max_steps": t["max_steps"],
"bonus_desc": t.get("bonus_desc", ""),
"days": t.get("days", 1),
}
for t in TASKS
]
}
@app.post("/reset", response_model=FocusObservation)
def reset(task_id: str = "task_1", seed: int = 42, session_id: str = "default"):
"""
Reset the environment for a new episode.
Call this before the first /step and at the start of each new episode.
"""
valid_ids = [t["id"] for t in TASKS]
if task_id not in valid_ids:
raise HTTPException(
status_code=400,
detail=f"Unknown task_id '{task_id}'. Valid: {valid_ids}"
)
if session_id not in session_episodes:
session_episodes[session_id] = 0
session_metrics[session_id] = []
sessions[session_id] = FocusFlowEnvironment(task_id=task_id, seed=seed)
session_episodes[session_id] += 1
return sessions[session_id].reset()
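
# Sessions are isolated by `session_id`, so parallel rollout workers can share
# one server instance. A minimal sketch (BASE as in the module-level example):
#
#   requests.post(f"{BASE}/reset", params={"task_id": "task_2", "session_id": "worker_1"})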
@app.post("/step", response_model=StepResponse)
def step(action: FocusAction, session_id: str = "default"):
"""
Submit one action. Returns next observation + reward + done flag.
The `reasoning` field in FocusAction is REQUIRED and graded.
Empty or low-quality reasoning incurs a reward penalty.
"""
env = sessions.get(session_id)
if env is None:
raise HTTPException(
status_code=400,
detail=f"Session '{session_id}' not initialised. Call POST /reset first."
)
obs, reward, done, info = env.step(action)
session_metrics[session_id].append({
"episode": session_episodes[session_id],
"step": info["step"],
"reward": reward,
"cumulative": info["cumulative"],
"reasoning_q": obs.reasoning_quality_score,
"success": info.get("success", False),
})
return StepResponse(
**obs.model_dump(),
reward=reward,
done=done,
info=info,
)
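
# Note: reasoning is graded inside env.step() (recorded as `reasoning_q` above);
# the POST /grader endpoint below exposes the grading pipeline directly, without
# advancing an episode.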
@app.get("/state", response_model=FocusState)
def state(session_id: str = "default"):
"""Return full internal environment state (for debugging and logging)."""
env = sessions.get(session_id)
if env is None:
raise HTTPException(
status_code=400,
detail=f"Session '{session_id}' not initialised. Call POST /reset first."
)
return env.state()
@app.get("/metrics")
def metrics(session_id: str = "default"):
"""
Returns per-step training metrics for reward curve plotting.
Use this in your Colab notebook to visualise training progress.
"""
metrics_log = session_metrics.get(session_id, [])
if not metrics_log:
return {"message": "No data yet. Run some episodes first.", "data": []}
from collections import defaultdict
ep_rewards = defaultdict(float)
ep_steps = defaultdict(int)
ep_success = defaultdict(bool)
for m in metrics_log:
ep = m["episode"]
ep_rewards[ep] += m["reward"]
ep_steps[ep] += 1
ep_success[ep] = ep_success[ep] or m["success"]
episodes_summary = [
{
"episode": ep,
"total_reward": round(ep_rewards[ep], 4),
"steps": ep_steps[ep],
"success": ep_success[ep],
}
for ep in sorted(ep_rewards.keys())
]
return {
"total_steps": len(metrics_log),
"total_episodes": len(episodes_summary),
"episodes": episodes_summary,
"raw_steps": metrics_log[-100:],
}
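
# Reward-curve sketch for a notebook client (assumes `requests` and
# `matplotlib` are available client-side; neither is a server dependency):
#
#   import requests
#   import matplotlib.pyplot as plt
#   data = requests.get("http://localhost:7860/metrics").json()
#   eps = data["episodes"]
#   plt.plot([e["episode"] for e in eps], [e["total_reward"] for e in eps])
#   plt.xlabel("Episode")
#   plt.ylabel("Total reward")
#   plt.show()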
@app.post("/reset_metrics")
def reset_metrics(session_id: str = "default"):
"""Clear the metrics log. Call this between training runs."""
session_metrics[session_id] = []
session_episodes[session_id] = 0
return {"message": f"Metrics cleared for session '{session_id}'."}
@app.post("/grader", response_model=GraderResponse)
def grader(request: GraderRequest):
"""
Direct grader invocation for offline evaluation.
Use this to test reasoning quality without running a full episode.
Judges can use this to verify the grading pipeline works correctly.
"""
score = grade_reasoning(request.reasoning, request.action_type, None)
if score >= 0.7:
verdict = "excellent"
explanation = "Reasoning is clear, relevant, and uses proper justification."
elif score >= 0.5:
verdict = "good"
explanation = "Reasoning is adequate but could mention more context signals."
elif score >= 0.3:
verdict = "weak"
explanation = "Reasoning is too short or lacks relevant keywords."
else:
verdict = "poor"
explanation = "Reasoning is empty, spammy, or below minimum quality threshold."
return GraderResponse(
reasoning = request.reasoning,
action_type = request.action_type,
reasoning_quality_score = score,
verdict = verdict,
explanation = explanation,
)
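
# Offline grader sketch ("resist_distraction" is a hypothetical action type,
# used purely for illustration):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/grader",
#       json={
#           "reasoning": "Phone stays off until the essay draft is done.",
#           "action_type": "resist_distraction",
#       },
#   ).json()
#   print(resp["verdict"], resp["reasoning_quality_score"])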

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)