AniketAsla committed on
Commit
c815569
·
verified ·
1 Parent(s): be81cf6

deploy: update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +185 -185
app/main.py CHANGED
@@ -1,185 +1,185 @@
1
- from __future__ import annotations
2
-
3
- import time
4
- from threading import Lock
5
- from typing import Any, Dict, Optional
6
- from uuid import uuid4
7
-
8
- from fastapi import FastAPI, HTTPException, Query
9
- from fastapi.background import BackgroundTasks
10
- from fastapi.responses import FileResponse
11
- from fastapi.staticfiles import StaticFiles
12
- from pydantic import BaseModel, Field, ValidationError
13
-
14
- from .environment import InsuranceClaimEnvironment
15
- from .models import InsuranceClaimAction, InsuranceClaimObservation
16
- from .tasks import list_tasks_summary
17
- from .session_store import get_confidence_distribution
18
-
19
- SESSION_TTL_SECONDS = 1800 # 30 minutes
20
-
21
-
22
- class SessionEntry:
23
- def __init__(self, env: InsuranceClaimEnvironment):
24
- self.env = env
25
- self.last_used = time.time()
26
-
27
-
28
- _sessions: Dict[str, SessionEntry] = {}
29
- _sessions_lock = Lock()
30
-
31
-
32
- def _get_or_create_session(session_id: str) -> InsuranceClaimEnvironment:
33
- with _sessions_lock:
34
- if session_id not in _sessions:
35
- _sessions[session_id] = SessionEntry(InsuranceClaimEnvironment())
36
- entry = _sessions[session_id]
37
- entry.last_used = time.time()
38
- return entry.env
39
-
40
-
41
- def _cleanup_sessions() -> None:
42
- now = time.time()
43
- with _sessions_lock:
44
- expired = [k for k, v in _sessions.items() if now - v.last_used > SESSION_TTL_SECONDS]
45
- for k in expired:
46
- del _sessions[k]
47
-
48
-
49
- class ResetBody(BaseModel):
50
- task_id: str | None = None
51
- seed: int | None = None
52
- session_id: str | None = None
53
- episode_id: str | None = None
54
-
55
-
56
- class StepBody(BaseModel):
57
- action: Dict[str, Any] = Field(default_factory=dict)
58
- session_id: str | None = None
59
-
60
-
61
- app = FastAPI(title="DebateFloor — Insurance Calibration RL Environment")
62
-
63
- import os
64
- _frontend_dist = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "frontend", "dist")
65
- _react_mounted = False
66
-
67
- if os.path.isdir(_frontend_dist):
68
- app.mount("/assets", StaticFiles(directory=os.path.join(_frontend_dist, "assets")), name="assets")
69
- _react_mounted = True
70
- print("React UI mounted from frontend/dist")
71
- else:
72
- print(f"WARNING: React UI not mounted. Missing directory: {_frontend_dist}")
73
-
74
- @app.get("/")
75
- def index():
76
- if _react_mounted:
77
- return FileResponse(os.path.join(_frontend_dist, "index.html"))
78
- return {
79
- "name": "DebateFloor — Insurance Calibration RL Environment",
80
- "status": "running",
81
- "endpoints": ["/health", "/tasks", "/schema", "/reset", "/step", "/state"],
82
- "docs": "/docs",
83
- }
84
-
85
-
86
- @app.post("/reset")
87
- def reset(body: ResetBody = ResetBody(), background_tasks: BackgroundTasks = BackgroundTasks()) -> dict:
88
- background_tasks.add_task(_cleanup_sessions)
89
- session_id = body.session_id or body.episode_id or str(uuid4())
90
- env = _get_or_create_session(session_id)
91
- obs = env.reset(task_id=body.task_id, seed=body.seed, episode_id=session_id)
92
- return {
93
- "observation": obs.model_dump(),
94
- "reward": float(obs.reward or 0.0),
95
- "done": bool(obs.done),
96
- "session_id": session_id,
97
- }
98
-
99
-
100
- @app.post("/step")
101
- def step(body: StepBody) -> dict:
102
- session_id = body.session_id or "default"
103
- env = _get_or_create_session(session_id)
104
- try:
105
- action = InsuranceClaimAction(**body.action)
106
- except (ValidationError, ValueError) as exc:
107
- errors = exc.errors() if hasattr(exc, "errors") else [{"msg": str(exc)}]
108
- # Ensure errors are JSON-serialisable (strip non-serialisable ctx values)
109
- safe = [
110
- {k: str(v) if not isinstance(v, (str, int, float, bool, list)) else v
111
- for k, v in e.items() if k != "ctx"}
112
- for e in errors
113
- ]
114
- raise HTTPException(status_code=422, detail=safe)
115
- obs = env.step(action)
116
- return {
117
- "observation": obs.model_dump(),
118
- "reward": float(obs.reward or 0.0),
119
- "done": bool(obs.done),
120
- "session_id": session_id,
121
- }
122
-
123
-
124
- @app.get("/state")
125
- def state(session_id: str = Query(default="default")) -> dict:
126
- env = _get_or_create_session(session_id)
127
- return env.state.model_dump()
128
-
129
-
130
- @app.get("/schema")
131
- def schema() -> dict:
132
- env = InsuranceClaimEnvironment()
133
- return {
134
- "action": InsuranceClaimAction.model_json_schema(),
135
- "observation": InsuranceClaimObservation.model_json_schema(),
136
- "state": env.state.model_json_schema(),
137
- }
138
-
139
-
140
- @app.get("/tasks")
141
- def tasks() -> dict:
142
- return {"tasks": list_tasks_summary()}
143
-
144
-
145
- @app.get("/health")
146
- def health() -> dict:
147
- return {
148
- "status": "healthy",
149
- "environment": "debatefloor_insurance_calibration_env",
150
- "active_sessions": len(_sessions),
151
- }
152
-
153
-
154
- @app.get("/stats")
155
- def stats() -> dict:
156
- """Confidence distribution across all sessions — proves anti-gaming is active."""
157
- return get_confidence_distribution()
158
-
159
-
160
- @app.post("/rollout")
161
- def rollout(task_id: str = "contradictory_claim", seed: int = 42) -> dict:
162
- """Run a scripted demo episode and return the full step-by-step trace for judges."""
163
- import requests as _req
164
- session_id = f"rollout-{seed}-{task_id}"
165
- base = "http://localhost:7860"
166
- trace = []
167
-
168
- reset_r = _req.post(f"{base}/reset", json={"task_id": task_id, "seed": seed, "session_id": session_id})
169
- trace.append({"action": "reset", "response": reset_r.json()})
170
-
171
- scripted_steps = [
172
- {"action_type": "validate_document", "parameters": {"doc_id": "DOC-001"}, "reasoning": "Checking primary document for fraud signals."},
173
- {"action_type": "flag_fraud_signal", "parameters": {"flag_id": "date_mismatch", "evidence": "Incident date on claim form contradicts hospital admission date."}, "reasoning": "Date inconsistency is a strong fraud indicator."},
174
- {"action_type": "convene_debate_panel", "parameters": {}, "reasoning": "Evidence is contradictory — convening adversarial debate before terminal decision."},
175
- {"action_type": "deny_claim", "confidence": "MED", "reason": "Date mismatch confirmed by debate panel.", "reasoning": "MED confidence — debate panel supports denial but evidence is not conclusive."},
176
- ]
177
-
178
- for action in scripted_steps:
179
- step_r = _req.post(f"{base}/step", json={"action": action, "session_id": session_id})
180
- step_data = step_r.json()
181
- trace.append({"action": action["action_type"], "reward": step_data.get("reward"), "done": step_data.get("done"), "response": step_data})
182
- if step_data.get("done"):
183
- break
184
-
185
- return {"task_id": task_id, "seed": seed, "session_id": session_id, "trace": trace}
 
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from threading import Lock
5
+ from typing import Any, Dict, Optional
6
+ from uuid import uuid4
7
+
8
+ from fastapi import FastAPI, HTTPException, Query
9
+ from fastapi.background import BackgroundTasks
10
+ from fastapi.responses import FileResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from pydantic import BaseModel, Field, ValidationError
13
+
14
+ from .environment import InsuranceClaimEnvironment
15
+ from .models import InsuranceClaimAction, InsuranceClaimObservation
16
+ from .tasks import list_tasks_summary
17
+ from .session_store import get_confidence_distribution
18
+
19
SESSION_TTL_SECONDS = 1800  # 30 minutes


class SessionEntry:
    """An environment paired with the monotonic timestamp of its last use."""

    def __init__(self, env: "InsuranceClaimEnvironment"):
        self.env = env
        # Idle time is an elapsed interval, so it is measured on the monotonic
        # clock: wall-clock adjustments (NTP step, manual change) can neither
        # expire a live session early nor keep a dead one alive.
        self.last_used = time.monotonic()


# session_id -> SessionEntry. The helpers below only touch this mapping while
# holding _sessions_lock.
_sessions: Dict[str, SessionEntry] = {}
_sessions_lock = Lock()


def _get_or_create_session(session_id: str) -> "InsuranceClaimEnvironment":
    """Return the environment registered under *session_id*, creating it if absent.

    Every call refreshes the entry's ``last_used`` stamp, which postpones its
    expiry in ``_cleanup_sessions``.
    """
    with _sessions_lock:
        entry = _sessions.get(session_id)
        if entry is None:
            entry = SessionEntry(InsuranceClaimEnvironment())
            _sessions[session_id] = entry
        entry.last_used = time.monotonic()
        return entry.env


def _cleanup_sessions() -> None:
    """Drop every session that has been idle longer than ``SESSION_TTL_SECONDS``."""
    now = time.monotonic()
    with _sessions_lock:
        # Collect the keys first: deleting while iterating the dict would raise.
        expired = [
            key
            for key, entry in _sessions.items()
            if now - entry.last_used > SESSION_TTL_SECONDS
        ]
        for key in expired:
            del _sessions[key]
47
+
48
+
49
class ResetBody(BaseModel):
    # Request payload for POST /reset. Every field is optional, so an empty
    # JSON object is accepted.
    task_id: Optional[str] = None
    seed: Optional[int] = None
    # The /reset handler prefers session_id and falls back to episode_id.
    session_id: Optional[str] = None
    episode_id: Optional[str] = None
54
+
55
+
56
class StepBody(BaseModel):
    # Request payload for POST /step.
    # `action` is kept as a raw mapping here; the /step handler turns it into
    # an InsuranceClaimAction and reports validation problems itself.
    action: Dict[str, Any] = Field(default_factory=dict)
    # When omitted, the /step handler uses the shared "default" session.
    session_id: Optional[str] = None
59
+
60
+
61
app = FastAPI(title="DebateFloor — Insurance Calibration RL Environment")

import os

# The built frontend is looked up at <parent of this file's directory>/frontend/dist.
_package_dir = os.path.dirname(os.path.abspath(__file__))
_frontend_dist = os.path.join(os.path.dirname(_package_dir), "frontend", "dist")
_react_mounted = os.path.isdir(_frontend_dist)

if not _react_mounted:
    print(f"WARNING: React UI not mounted. Missing directory: {_frontend_dist}")
else:
    # Only the bundle's assets directory is served statically; index.html is
    # returned by the "/" route.
    _assets_dir = os.path.join(_frontend_dist, "assets")
    app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
    print("React UI mounted from frontend/dist")
73
+
74
@app.get("/")
def index():
    # Without a mounted React build, answer with a small JSON description of
    # the service; otherwise hand back the bundle's index.html.
    if not _react_mounted:
        return {
            "name": "DebateFloor — Insurance Calibration RL Environment",
            "status": "running",
            "endpoints": ["/health", "/tasks", "/schema", "/reset", "/step", "/state"],
            "docs": "/docs",
        }
    index_html = os.path.join(_frontend_dist, "index.html")
    return FileResponse(index_html)
84
+
85
+
86
@app.post("/reset")
def reset(body: ResetBody = ResetBody(), background_tasks: BackgroundTasks = BackgroundTasks()) -> dict:
    # Expired sessions are swept as a background task queued on each reset.
    background_tasks.add_task(_cleanup_sessions)

    # Session key precedence: explicit session_id, then episode_id, then a
    # freshly generated UUID.
    if body.session_id:
        session_id = body.session_id
    elif body.episode_id:
        session_id = body.episode_id
    else:
        session_id = str(uuid4())

    environment = _get_or_create_session(session_id)
    observation = environment.reset(
        task_id=body.task_id,
        seed=body.seed,
        episode_id=session_id,
    )

    payload = {"observation": observation.model_dump()}
    payload["reward"] = float(observation.reward or 0.0)
    payload["done"] = bool(observation.done)
    payload["session_id"] = session_id
    return payload
98
+
99
+
100
def _json_safe(value: Any) -> Any:
    """Recursively coerce *value* into something ``json.dumps`` can encode.

    ``None`` and JSON primitives pass through; lists and tuples become lists;
    dict keys are stringified; any other object is replaced by ``str(obj)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    return str(value)


@app.post("/step")
def step(body: StepBody) -> dict:
    # Validate the action BEFORE touching the session store, so a malformed
    # request cannot allocate a new environment under an arbitrary session_id.
    try:
        action = InsuranceClaimAction(**body.action)
    except (ValidationError, ValueError) as exc:
        errors = exc.errors() if hasattr(exc, "errors") else [{"msg": str(exc)}]
        # Drop "ctx" (it may hold exception objects) and make the remaining
        # values JSON-encodable. Tuples such as "loc" become JSON arrays
        # instead of Python repr strings.
        safe = [
            {key: _json_safe(value) for key, value in error.items() if key != "ctx"}
            for error in errors
        ]
        raise HTTPException(status_code=422, detail=safe) from exc

    # Requests without a session_id share the "default" session.
    session_id = body.session_id or "default"
    env = _get_or_create_session(session_id)
    obs = env.step(action)
    return {
        "observation": obs.model_dump(),
        "reward": float(obs.reward or 0.0),
        "done": bool(obs.done),
        "session_id": session_id,
    }
122
+
123
+
124
@app.get("/state")
def state(session_id: str = Query(default="default")) -> dict:
    # The lookup goes through _get_or_create_session, so querying an unknown
    # session_id creates that session.
    current_state = _get_or_create_session(session_id).state
    return current_state.model_dump()
128
+
129
+
130
@app.get("/schema")
def schema() -> dict:
    # A throwaway environment is built only to reach its state object; it is
    # never registered in the session store.
    probe = InsuranceClaimEnvironment()
    schemas = {}
    schemas["action"] = InsuranceClaimAction.model_json_schema()
    schemas["observation"] = InsuranceClaimObservation.model_json_schema()
    schemas["state"] = probe.state.model_json_schema()
    return schemas
138
+
139
+
140
@app.get("/tasks")
def tasks() -> dict:
    # Wrap the summary from list_tasks_summary() under a single "tasks" key.
    summary = list_tasks_summary()
    return {"tasks": summary}
143
+
144
+
145
@app.get("/health")
def health() -> dict:
    # Liveness probe; also reports how many sessions are currently stored.
    session_count = len(_sessions)
    report = {"status": "healthy"}
    report["environment"] = "debatefloor_insurance_calibration_env"
    report["active_sessions"] = session_count
    return report
152
+
153
+
154
@app.get("/stats")
def stats() -> dict:
    """Confidence distribution across all sessions — proves anti-gaming is active."""
    # Pass-through: the figures come straight from the session store helper.
    distribution = get_confidence_distribution()
    return distribution
158
+
159
+
160
@app.post("/rollout")
def rollout(task_id: str = "contradictory_claim", seed: int = 42) -> dict:
    """Run a scripted demo episode and return the full step-by-step trace for judges."""
    import os

    import requests as _req

    session_id = f"rollout-{seed}-{task_id}"
    # The episode is driven through this service's own HTTP API. The target
    # can be overridden with ROLLOUT_BASE_URL; the default is the original
    # hard-coded address.
    base = os.environ.get("ROLLOUT_BASE_URL", "http://localhost:7860").rstrip("/")
    # Without a timeout a stalled loopback call would block this worker
    # thread indefinitely.
    timeout_seconds = 30
    trace = []

    scripted_steps = [
        {"action_type": "validate_document", "parameters": {"doc_id": "DOC-001"}, "reasoning": "Checking primary document for fraud signals."},
        {"action_type": "flag_fraud_signal", "parameters": {"flag_id": "date_mismatch", "evidence": "Incident date on claim form contradicts hospital admission date."}, "reasoning": "Date inconsistency is a strong fraud indicator."},
        {"action_type": "convene_debate_panel", "parameters": {}, "reasoning": "Evidence is contradictory — convening adversarial debate before terminal decision."},
        {"action_type": "deny_claim", "confidence": "MED", "reason": "Date mismatch confirmed by debate panel.", "reasoning": "MED confidence — debate panel supports denial but evidence is not conclusive."},
    ]

    try:
        # One Session reuses the connection across the reset and step calls
        # and is closed on exit.
        with _req.Session() as http:
            reset_r = http.post(
                f"{base}/reset",
                json={"task_id": task_id, "seed": seed, "session_id": session_id},
                timeout=timeout_seconds,
            )
            trace.append({"action": "reset", "response": reset_r.json()})

            for action in scripted_steps:
                step_r = http.post(
                    f"{base}/step",
                    json={"action": action, "session_id": session_id},
                    timeout=timeout_seconds,
                )
                step_data = step_r.json()
                trace.append({"action": action["action_type"], "reward": step_data.get("reward"), "done": step_data.get("done"), "response": step_data})
                # Stop as soon as the environment reports the episode finished.
                if step_data.get("done"):
                    break
    except (_req.RequestException, ValueError) as exc:
        # Transport failures, timeouts and non-JSON bodies are reported as a
        # 502 with a readable message rather than an unhandled 500.
        raise HTTPException(
            status_code=502,
            detail=f"Rollout request to {base} failed: {exc}",
        ) from exc

    return {"task_id": task_id, "seed": seed, "session_id": session_id, "trace": trace}