# Upstream commit note (author: Bhavishya011, commit aedaafb):
#   fix: change fastapi root to /api so gradio can use /
"""
FastAPI application — OpenEnv-compliant HTTP interface.
Required endpoints (from spec):
POST /reset → Observation
POST /step → {observation, reward, done, info}
GET /state → current episode state dict
GET /tasks → list of tasks + action schemas
POST /grader → run grader on completed episode
POST /baseline → run baseline inference script, return scores
GET /health → 200 OK (used by Docker HEALTHCHECK and judge ping)
"""
from __future__ import annotations
import os
import sys
from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import json
# Ensure project root is on path regardless of working directory
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.environment import ResearchIntegrityEnv
from env.models import Action, ActionType
# FastAPI application object. API routes are registered below; the Gradio
# demo UI is mounted over "/" at the bottom of this module, which is why
# the JSON landing page lives at /api instead of /.
app = FastAPI(
    title="Research Integrity Gym",
    description=(
        "OpenEnv environment for training and evaluating AI agents on "
        "scientific research integrity tasks. Agents must audit methodology, "
        "replicate experiments, and verify statistical claims."
    ),
    version="1.0.0",
)

# Fully open CORS so browser-based clients (demo UI, judge harness) can
# call the API from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One global environment instance per server process
# (POST /reset may replace it when a seed is supplied).
_env = ResearchIntegrityEnv()
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Optional JSON body accepted by POST /reset."""

    # Task to start; defaults to the methodology-audit task.
    task_id: str = "task1_methodology_audit"
    # Optional RNG seed; when present, a freshly seeded env is created.
    seed: Optional[int] = None

    class Config:
        # Unknown fields in the request body are silently dropped.
        extra = "ignore"
class StepRequest(BaseModel):
    """Request body for POST /step: the single Action to execute."""
    action: Action
class GraderRequest(BaseModel):
    """Request body for POST /grader: which task to grade and the
    serialised state of the completed episode."""
    task_id: str
    episode_state: dict  # serialised state from a completed episode
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/api")
def root():
    """API landing page: service metadata plus an index of endpoints.

    Served at /api (not /) because the Gradio demo UI is mounted at the
    root path; interactive docs remain at /docs.
    """
    endpoint_index = {
        "health": "GET /health",
        "tasks": "GET /tasks",
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state",
    }
    return {
        "name": "Research Integrity Gym",
        "description": "OpenEnv environment for AI agents to evaluate scientific research integrity",
        "docs": "/docs",
        "endpoints": endpoint_index,
    }
@app.get("/health")
def health():
    """Liveness probe (used by Docker HEALTHCHECK and the judge ping)."""
    payload = {"status": "ok", "environment": "research-integrity-gym"}
    return payload
@app.post("/reset")
async def reset(request: Request):
    """Start a new episode. Returns the initial Observation.

    The request body is parsed leniently; all of these are accepted:
      - empty body
      - the literal text "null"
      - a JSON object optionally carrying "task_id" and/or "seed"
    Anything unparseable (or non-object JSON) falls back to defaults.
    Unknown task_id values surface as HTTP 400.
    """
    global _env
    raw = await request.body()
    text = raw.decode("utf-8").strip() if raw else ""
    payload: dict = {}
    if text and text != "null":
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            decoded = None  # malformed JSON → fall back to defaults
        if isinstance(decoded, dict):
            payload = decoded
        # None / non-dict JSON values are ignored on purpose
    task_id = payload.get("task_id", "task1_methodology_audit")
    seed = payload.get("seed")
    if seed is not None:
        # A seeded reset swaps in a brand-new environment instance.
        _env = ResearchIntegrityEnv(seed=seed)
    try:
        observation = _env.reset(task_id=task_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return observation.model_dump()
@app.post("/step")
def step(req: StepRequest):
    """Execute one action in the live episode.

    Returns the observation, reward, done flag and info dict. A
    RuntimeError from the environment (e.g. no active episode) is
    translated into HTTP 400.
    """
    try:
        observation, reward, done, info = _env.step(req.action)
        return {
            "observation": observation.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        }
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@app.get("/state")
def state():
    """Return the current episode state (ground truth is excluded by the env)."""
    snapshot = _env.state()
    return snapshot
@app.get("/tasks")
def tasks():
    """Return every available task together with its action schema."""
    # Imported lazily so task modules load only when this endpoint is hit.
    from tasks.task1_methodology_audit import MethodologyAuditTask
    from tasks.task2_replication import ReplicationTask
    from tasks.task3_claim_verify import ClaimVerifyTask
    from tasks.task4_citation_check import CitationCheckTask
    from tasks.task5_fda_approval import FDAApprovalTask

    task_classes = (
        MethodologyAuditTask,
        ReplicationTask,
        ClaimVerifyTask,
        CitationCheckTask,
        FDAApprovalTask,
    )
    return {"tasks": [cls().task_info() for cls in task_classes]}
@app.post("/grader")
def grader(req: GraderRequest):
    """
    Run the grader for a completed episode externally.

    Accepts a serialised terminal_action and ground_truth inside
    ``req.episode_state`` and dispatches to the task-specific grader.
    Used by the judge's automated evaluation pipeline.

    Raises:
        HTTPException 400: unknown ``task_id``.
        HTTPException 422: the episode payload could not be parsed or graded.
    """
    from graders.grader1 import grade_audit
    from graders.grader2 import grade_results
    from graders.grader3 import grade_verdict
    from graders.grader4 import grade_citation_report
    from graders.grader5 import grade_fda_verdict
    from env.models import (
        SubmitAuditPayload, SubmitResultsPayload, SubmitVerdictPayload,
        SubmitCitationReportPayload, SubmitFDAVerdictPayload, FlawReport,
    )
    task_id = req.task_id
    state_dict = req.episode_state
    gt = state_dict.get("ground_truth", {})
    terminal_act = state_dict.get("terminal_action", {})
    try:
        if task_id == "task1_methodology_audit":
            flaws = [FlawReport(**f) for f in terminal_act.get("flaws", [])]
            payload = SubmitAuditPayload(flaws=flaws)
            score = grade_audit(payload, gt)
        elif task_id == "task2_replication":
            payload = SubmitResultsPayload(**terminal_act)
            score = grade_results(payload, gt)
        elif task_id == "task3_claim_verify":
            payload = SubmitVerdictPayload(**terminal_act)
            score = grade_verdict(payload, gt)
        elif task_id == "task4_citation_check":
            payload = SubmitCitationReportPayload(**terminal_act)
            score = grade_citation_report(payload, gt)
        elif task_id == "task5_fda_approval":
            payload = SubmitFDAVerdictPayload(**terminal_act)
            # For external grader calls, we create a minimal EpisodeState
            from env.state import EpisodeState
            mock_state = EpisodeState(
                task_id=task_id,
                flags_raised=state_dict.get("flags_raised", []),
                code_calls=state_dict.get("code_calls", 0),
            )
            score = grade_fda_verdict(payload, gt, mock_state)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
        return {"task_id": task_id, "grader_score": score}
    except HTTPException:
        # BUG FIX: the generic handler below used to swallow the 400 raised
        # for an unknown task_id and re-emit it as a 422. Re-raise HTTP
        # errors untouched so clients see the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
@app.post("/baseline")
def baseline():
    """
    Trigger the baseline inference script and return its scores.

    Requires HF_TOKEN in the environment (503 when missing). Runs
    ``baseline.py --output-json`` with a 300-second cap; a timeout is
    reported as 504 and a non-zero exit as 500 with the captured stderr.
    """
    import subprocess

    api_key = os.environ.get("HF_TOKEN", "")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="HF_TOKEN not set. Add it to Space secrets.",
        )
    try:
        result = subprocess.run(
            [sys.executable, "baseline.py", "--output-json"],
            capture_output=True, text=True, timeout=300,
            env={**os.environ, "HF_TOKEN": api_key},
        )
    except subprocess.TimeoutExpired:
        # BUG FIX: the timeout used to escape as an unhandled 500; report
        # it explicitly so callers can distinguish "slow" from "broken".
        raise HTTPException(
            status_code=504,
            detail="Baseline script timed out after 300 seconds.",
        )
    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script failed:\n{result.stderr[:2000]}",
        )
    try:
        # Uses the module-level json import; the old shadowing local
        # `import json` was redundant and has been removed.
        return json.loads(result.stdout)
    except json.JSONDecodeError:
        # Script succeeded but printed non-JSON — return it raw, truncated.
        return {"raw_output": result.stdout[:3000]}
# ---------------------------------------------------------------------------
# Mount Gradio demo UI at root
# ---------------------------------------------------------------------------
# Imported at the bottom of the module on purpose: every API route above
# must already be registered on `app` before Gradio takes over "/".
import gradio as gr
from app import demo as gradio_demo  # NOTE(review): assumes a sibling app.py exposes `demo` — confirm
# Rebind `app` to the combined ASGI app returned by mount_gradio_app so
# the server entry point serves both the API routes and the UI.
app = gr.mount_gradio_app(app, gradio_demo, path="/")