# NOTE(review): the three lines that were here ("Spaces: / Sleeping / Sleeping")
# are Hugging Face Spaces page-header residue from scraping, not source code.
"""
FastAPI application — OpenEnv-compliant HTTP interface.

Required endpoints (from spec):
    POST /reset    → Observation
    POST /step     → {observation, reward, done, info}
    GET  /state    → current episode state dict
    GET  /tasks    → list of tasks + action schemas
    POST /grader   → run grader on completed episode
    POST /baseline → run baseline inference script, return scores
    GET  /health   → 200 OK (used by Docker HEALTHCHECK and judge ping)
"""
from __future__ import annotations

# --- standard library ---
import json
import os
import sys
from typing import Optional

# --- third-party ---
from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Ensure project root is on path regardless of working directory.
# This must run BEFORE the local `env.*` imports below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# --- local ---
from env.environment import ResearchIntegrityEnv
from env.models import Action, ActionType
app = FastAPI(
    title="Research Integrity Gym",
    description=(
        "OpenEnv environment for training and evaluating AI agents on "
        "scientific research integrity tasks. Agents must audit methodology, "
        "replicate experiments, and verify statistical claims."
    ),
    version="1.0.0",
)

# Wide-open CORS: any origin/method/header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One global environment instance per server process.
# NOTE(review): this is shared mutable state — concurrent /reset and /step
# requests from different clients will interleave on the same episode.
_env = ResearchIntegrityEnv()
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Body for POST /reset; both fields are optional.

    NOTE(review): /reset parses its body manually and does not actually use
    this model — kept for documentation/schema purposes.
    """

    task_id: str = "task1_methodology_audit"
    seed: int | None = None

    class Config:
        extra = "ignore"  # Ignore extra fields sent by clients
class StepRequest(BaseModel):
    """Body for POST /step: a single environment Action."""

    action: Action
class GraderRequest(BaseModel):
    """Body for POST /grader: task identifier plus the serialised episode state."""

    task_id: str
    # Serialised state from a completed episode; expected to contain
    # "ground_truth" and "terminal_action" keys (see the /grader handler).
    episode_state: dict
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/api")
def root():
    """Service metadata and endpoint directory.

    NOTE(review): the route decorator was missing from the scraped source.
    "/" is taken over by the Gradio UI mounted at the bottom of this file,
    so the metadata endpoint is registered at /api — confirm the originally
    intended path.
    """
    return {
        "name": "Research Integrity Gym",
        "description": "OpenEnv environment for AI agents to evaluate scientific research integrity",
        "docs": "/docs",
        "endpoints": {
            "health": "GET /health",
            "tasks": "GET /tasks",
            "reset": "POST /reset",
            "step": "POST /step",
            "state": "GET /state",
        },
    }
@app.get("/health")
def health():
    """Liveness probe (Docker HEALTHCHECK / judge ping): always 200 with a small JSON body."""
    return {"status": "ok", "environment": "research-integrity-gym"}
@app.post("/reset")
async def reset(request: Request):
    """Start a new episode. Returns the initial Observation as JSON.

    The body is parsed manually (not via ResetRequest) so the endpoint
    tolerates:
      - an empty body,
      - a body containing just "null" or "{}",
      - a JSON object with optional "task_id" and/or "seed".
    Raises 400 if the environment rejects the task_id.
    """
    global _env

    raw = await request.body()
    text = raw.decode("utf-8").strip() if raw else ""

    # Anything that is not a JSON object is treated as "no options given".
    options: dict = {}
    if text and text != "null":
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                options = parsed
            # A parsed None / non-dict value falls through with defaults.
        except json.JSONDecodeError:
            pass  # malformed body -> fall back to defaults

    task_id = options.get("task_id", "task1_methodology_audit")
    seed = options.get("seed")

    # An explicit seed requests a fresh, deterministically-seeded environment;
    # the new instance replaces the module-level one for subsequent /step calls.
    if seed is not None:
        _env = ResearchIntegrityEnv(seed=seed)

    try:
        obs = _env.reset(task_id=task_id)
        return obs.model_dump()
    except ValueError as e:
        # e.g. unknown task_id — surface as a client error
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/step")
def step(req: StepRequest):
    """Execute one action. Returns {observation, reward, done, info}.

    Raises 400 if the environment refuses the step (RuntimeError, e.g.
    stepping before /reset or after the episode has ended).
    """
    try:
        obs, reward, done, info = _env.step(req.action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }
@app.get("/state")
def state():
    """Return the current episode state (excludes ground truth, per the env's state())."""
    return _env.state()
@app.get("/tasks")
def tasks():
    """Return all available tasks with their action schemas.

    Task classes are imported lazily inside the handler (kept as in the
    original; presumably to avoid import cost or cycles at server startup —
    confirm before hoisting).
    """
    from tasks.task1_methodology_audit import MethodologyAuditTask
    from tasks.task2_replication import ReplicationTask
    from tasks.task3_claim_verify import ClaimVerifyTask
    from tasks.task4_citation_check import CitationCheckTask
    from tasks.task5_fda_approval import FDAApprovalTask

    return {
        "tasks": [
            MethodologyAuditTask().task_info(),
            ReplicationTask().task_info(),
            ClaimVerifyTask().task_info(),
            CitationCheckTask().task_info(),
            FDAApprovalTask().task_info(),
        ]
    }
@app.post("/grader")
def grader(req: GraderRequest):
    """
    Run the grader for a completed episode externally.

    Reads "ground_truth" and "terminal_action" out of the serialised episode
    state. Used by the judge's automated evaluation pipeline.

    Returns {"task_id": ..., "grader_score": ...}.
    Raises 400 for an unknown task_id and 422 when the payload cannot be
    parsed or graded.
    """
    from graders.grader1 import grade_audit
    from graders.grader2 import grade_results
    from graders.grader3 import grade_verdict
    from graders.grader4 import grade_citation_report
    from graders.grader5 import grade_fda_verdict
    from env.models import (
        SubmitAuditPayload, SubmitResultsPayload, SubmitVerdictPayload,
        SubmitCitationReportPayload, SubmitFDAVerdictPayload, FlawReport,
    )

    task_id = req.task_id
    state_dict = req.episode_state
    gt = state_dict.get("ground_truth", {})
    terminal_act = state_dict.get("terminal_action", {})

    try:
        if task_id == "task1_methodology_audit":
            flaws = [FlawReport(**f) for f in terminal_act.get("flaws", [])]
            payload = SubmitAuditPayload(flaws=flaws)
            score = grade_audit(payload, gt)
        elif task_id == "task2_replication":
            payload = SubmitResultsPayload(**terminal_act)
            score = grade_results(payload, gt)
        elif task_id == "task3_claim_verify":
            payload = SubmitVerdictPayload(**terminal_act)
            score = grade_verdict(payload, gt)
        elif task_id == "task4_citation_check":
            payload = SubmitCitationReportPayload(**terminal_act)
            score = grade_citation_report(payload, gt)
        elif task_id == "task5_fda_approval":
            payload = SubmitFDAVerdictPayload(**terminal_act)
            # grade_fda_verdict also inspects episode behaviour, so build a
            # minimal EpisodeState from the serialised fields it reads.
            from env.state import EpisodeState
            mock_state = EpisodeState(
                task_id=task_id,
                flags_raised=state_dict.get("flags_raised", []),
                code_calls=state_dict.get("code_calls", 0),
            )
            score = grade_fda_verdict(payload, gt, mock_state)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
        return {"task_id": task_id, "grader_score": score}
    except HTTPException:
        # BUGFIX: previously the 400 for an unknown task_id was swallowed by
        # the blanket handler below and re-raised as a 422. Let it through.
        raise
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e))
@app.post("/baseline")
def baseline():
    """
    Trigger the baseline inference script and return its scores.

    Requires HF_TOKEN in the environment.
    Raises 503 if HF_TOKEN is missing, 504 on timeout (300s), 500 if the
    script exits non-zero. If the script's stdout is not valid JSON, the
    raw output is returned instead for debugging.
    """
    import subprocess

    api_key = os.environ.get("HF_TOKEN", "")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="HF_TOKEN not set. Add it to Space secrets.",
        )

    try:
        result = subprocess.run(
            [sys.executable, "baseline.py", "--output-json"],
            capture_output=True, text=True, timeout=300,
            env={**os.environ, "HF_TOKEN": api_key},
        )
    except subprocess.TimeoutExpired:
        # BUGFIX: previously TimeoutExpired escaped as an unhandled 500.
        raise HTTPException(
            status_code=504,
            detail="Baseline script timed out after 300 seconds.",
        )

    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script failed:\n{result.stderr[:2000]}",
        )

    try:
        return json.loads(result.stdout)
    except json.JSONDecodeError:
        # Script succeeded but printed non-JSON; hand back (truncated) output.
        return {"raw_output": result.stdout[:3000]}
# ---------------------------------------------------------------------------
# Mount Gradio demo UI at root
# ---------------------------------------------------------------------------
import gradio as gr
# NOTE(review): imports `demo` from a sibling app.py module; confirm there is
# no collision with this module's `app` variable at import time.
from app import demo as gradio_demo

app = gr.mount_gradio_app(app, gradio_demo, path="/")