"""
FastAPI application — OpenEnv-compliant HTTP interface.
Required endpoints (from spec):
POST /reset → Observation
POST /step → {observation, reward, done, info}
GET /state → current episode state dict
GET /tasks → list of tasks + action schemas
POST /grader → run grader on completed episode
POST /baseline → run baseline inference script, return scores
GET /health → 200 OK (used by Docker HEALTHCHECK and judge ping)
"""
from __future__ import annotations
import os
import sys
from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import json
# Ensure project root is on path regardless of working directory
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.environment import ResearchIntegrityEnv
from env.models import Action, ActionType
# FastAPI app object; metadata below feeds the auto-generated /docs UI.
app = FastAPI(
    title="Research Integrity Gym",
    description=(
        "OpenEnv environment for training and evaluating AI agents on "
        "scientific research integrity tasks. Agents must audit methodology, "
        "replicate experiments, and verify statistical claims."
    ),
    version="1.0.0",
)
# Wide-open CORS: any origin/method/header may call this public demo API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# One global environment instance per server process
# (replaced in /reset when a seed is supplied).
_env = ResearchIntegrityEnv()
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Optional body for POST /reset: task selection plus an RNG seed."""

    task_id: str = "task1_methodology_audit"
    seed: Optional[int] = None

    class Config:
        # Unknown fields in the request body are silently dropped.
        extra = "ignore"
class StepRequest(BaseModel):
    """Body for POST /step: the single action to execute."""
    action: Action
class GraderRequest(BaseModel):
    """Body for POST /grader: which task to grade and the episode to grade."""
    task_id: str
    episode_state: dict  # serialised state from a completed episode
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/api")
def root():
    """Describe the service and point callers at the interactive docs."""
    endpoint_map = {
        "health": "GET /health",
        "tasks": "GET /tasks",
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state",
    }
    return {
        "name": "Research Integrity Gym",
        "description": (
            "OpenEnv environment for AI agents to evaluate "
            "scientific research integrity"
        ),
        "docs": "/docs",
        "endpoints": endpoint_map,
    }
@app.get("/health")
def health():
    """Liveness probe used by Docker HEALTHCHECK and the judge's ping."""
    probe_response = {"status": "ok", "environment": "research-integrity-gym"}
    return probe_response
@app.post("/reset")
async def reset(request: Request):
    """Start a new episode and return the initial Observation.

    Tolerates an empty body, a literal "null" body, or a JSON object
    with optional "task_id" and/or "seed" fields; anything unparsable
    falls back to the defaults.
    """
    global _env
    # Read the raw body ourselves so missing/empty/"null" payloads do not
    # trip pydantic validation before we get a chance to apply defaults.
    raw = (await request.body()).decode("utf-8").strip()
    payload = {}
    if raw and raw != "null":
        try:
            decoded = json.loads(raw)
        except json.JSONDecodeError:
            decoded = None  # malformed JSON -> keep defaults
        if isinstance(decoded, dict):
            payload = decoded
    task_id = payload.get("task_id", "task1_methodology_audit")
    seed = payload.get("seed")
    if seed is not None:
        # An explicit seed means a fresh, reproducible environment instance.
        _env = ResearchIntegrityEnv(seed=seed)
    try:
        return _env.reset(task_id=task_id).model_dump()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/step")
def step(req: StepRequest):
    """Apply one action to the environment.

    Returns the new observation, the reward, the done flag, and auxiliary
    info. Stepping outside an active episode yields HTTP 400.
    """
    try:
        observation, reward, done, info = _env.step(req.action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {
        "observation": observation.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }
@app.get("/state")
def state():
    """Expose the current episode state (ground truth is excluded)."""
    current = _env.state()
    return current
@app.get("/tasks")
def tasks():
    """List every available task together with its action schema."""
    from tasks.task1_methodology_audit import MethodologyAuditTask
    from tasks.task2_replication import ReplicationTask
    from tasks.task3_claim_verify import ClaimVerifyTask
    from tasks.task4_citation_check import CitationCheckTask
    from tasks.task5_fda_approval import FDAApprovalTask

    task_classes = (
        MethodologyAuditTask,
        ReplicationTask,
        ClaimVerifyTask,
        CitationCheckTask,
        FDAApprovalTask,
    )
    return {"tasks": [cls().task_info() for cls in task_classes]}
@app.post("/grader")
def grader(req: GraderRequest):
    """
    Run the grader for a completed episode externally.

    Accepts a serialised terminal_action and ground_truth inside
    ``episode_state``. Used by the judge's automated evaluation pipeline.

    Raises:
        HTTPException 400: unknown ``task_id``.
        HTTPException 422: the serialised payload could not be parsed or graded.
    """
    from graders.grader1 import grade_audit
    from graders.grader2 import grade_results
    from graders.grader3 import grade_verdict
    from graders.grader4 import grade_citation_report
    from graders.grader5 import grade_fda_verdict
    from env.models import (
        SubmitAuditPayload, SubmitResultsPayload, SubmitVerdictPayload,
        SubmitCitationReportPayload, SubmitFDAVerdictPayload, FlawReport,
    )
    task_id = req.task_id
    state_dict = req.episode_state
    gt = state_dict.get("ground_truth", {})
    terminal_act = state_dict.get("terminal_action", {})
    try:
        if task_id == "task1_methodology_audit":
            flaws = [FlawReport(**f) for f in terminal_act.get("flaws", [])]
            payload = SubmitAuditPayload(flaws=flaws)
            score = grade_audit(payload, gt)
        elif task_id == "task2_replication":
            payload = SubmitResultsPayload(**terminal_act)
            score = grade_results(payload, gt)
        elif task_id == "task3_claim_verify":
            payload = SubmitVerdictPayload(**terminal_act)
            score = grade_verdict(payload, gt)
        elif task_id == "task4_citation_check":
            payload = SubmitCitationReportPayload(**terminal_act)
            score = grade_citation_report(payload, gt)
        elif task_id == "task5_fda_approval":
            payload = SubmitFDAVerdictPayload(**terminal_act)
            # External grader calls have no live episode, so build a minimal
            # EpisodeState carrying only the fields grade_fda_verdict reads.
            from env.state import EpisodeState
            mock_state = EpisodeState(
                task_id=task_id,
                flags_raised=state_dict.get("flags_raised", []),
                code_calls=state_dict.get("code_calls", 0),
            )
            score = grade_fda_verdict(payload, gt, mock_state)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
        return {"task_id": task_id, "grader_score": score}
    except HTTPException:
        # BUG FIX: the blanket handler below used to swallow the 400 raised
        # for an unknown task_id and re-emit it as a 422 with a mangled
        # detail string. Re-raise HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e))
@app.post("/baseline")
def baseline():
    """
    Trigger the baseline inference script and return scores for all tasks.

    Requires HF_TOKEN in the environment. Returns the script's parsed JSON
    output, or the raw stdout when that output is not valid JSON.

    Raises:
        HTTPException 503: HF_TOKEN is not configured.
        HTTPException 504: the baseline script exceeded the time limit.
        HTTPException 500: the baseline script exited non-zero.
    """
    import subprocess

    api_key = os.environ.get("HF_TOKEN", "")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="HF_TOKEN not set. Add it to Space secrets.",
        )
    try:
        result = subprocess.run(
            [sys.executable, "baseline.py", "--output-json"],
            capture_output=True, text=True, timeout=300,
            env={**os.environ, "HF_TOKEN": api_key},
        )
    except subprocess.TimeoutExpired:
        # BUG FIX: previously an uncaught TimeoutExpired surfaced as a raw
        # 500 with a traceback; report the timeout explicitly instead.
        raise HTTPException(
            status_code=504,
            detail="Baseline script timed out after 300 seconds.",
        )
    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script failed:\n{result.stderr[:2000]}",
        )
    try:
        # json is imported at module level; the old local re-import shadowed it.
        return json.loads(result.stdout)
    except json.JSONDecodeError:
        return {"raw_output": result.stdout[:3000]}
# ---------------------------------------------------------------------------
# Mount Gradio demo UI at root
# ---------------------------------------------------------------------------
import gradio as gr
# NOTE(review): pulls `demo` from a sibling app.py module; presumably this
# file is not itself named app.py — confirm there is no module name clash.
from app import demo as gradio_demo
# mount_gradio_app wraps the FastAPI app so "/" serves the Gradio UI while
# the API endpoints above remain reachable at their own paths.
app = gr.mount_gradio_app(app, gradio_demo, path="/")
|