File size: 8,770 Bytes
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b836ee9
62b6842
 
cc700fe
b836ee9
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29a2e8e
 
 
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
aedaafb
d2760e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62b6842
 
 
 
 
 
b836ee9
 
 
 
54e8b1b
 
b836ee9
 
62b6842
cc700fe
54e8b1b
b836ee9
54e8b1b
b836ee9
54e8b1b
 
 
b836ee9
54e8b1b
 
 
 
b836ee9
54e8b1b
b836ee9
 
 
 
 
 
cc700fe
62b6842
b836ee9
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8f87b
8122ba9
62b6842
 
 
 
 
da8f87b
8122ba9
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
da8f87b
8122ba9
62b6842
 
8122ba9
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8f87b
 
 
 
8122ba9
 
 
 
 
 
 
 
 
 
 
62b6842
 
 
 
 
 
 
 
 
 
 
 
da8f87b
c0f7fc8
62b6842
 
 
 
c0f7fc8
62b6842
 
 
c0f7fc8
62b6842
 
 
 
 
c0f7fc8
62b6842
 
 
 
 
 
 
 
 
 
 
 
 
c32739a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
"""
FastAPI application — OpenEnv-compliant HTTP interface.

Required endpoints (from spec):
  POST /reset      → Observation
  POST /step       → {observation, reward, done, info}
  GET  /state      → current episode state dict
  GET  /tasks      → list of tasks + action schemas
  POST /grader     → run grader on completed episode
  POST /baseline   → run baseline inference script, return scores
  GET  /health     → 200 OK (used by Docker HEALTHCHECK and judge ping)
"""
from __future__ import annotations

import json
import os
import sys

from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from typing import Optional

# Ensure project root is on path regardless of working directory
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env.environment import ResearchIntegrityEnv
from env.models import Action, ActionType

# FastAPI application object; also the OpenAPI metadata shown at /docs.
app = FastAPI(
    title="Research Integrity Gym",
    description=(
        "OpenEnv environment for training and evaluating AI agents on "
        "scientific research integrity tasks. Agents must audit methodology, "
        "replicate experiments, and verify statistical claims."
    ),
    version="1.0.0",
)

# CORS is wide open ("*") so browser clients (demo UI, judge) can call the
# API from any origin. NOTE(review): fine while no credentials are used;
# tighten allow_origins if auth is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One global environment instance per server process
# (replaced in POST /reset when a seed is supplied).
_env = ResearchIntegrityEnv()


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    """Request body for POST /reset.

    Both fields are optional in practice: the /reset handler also accepts an
    empty or ``null`` body and parses the payload manually.
    """
    task_id: str = "task1_methodology_audit"
    seed: int | None = None

    # Pydantic v2 config (this file already relies on v2 APIs such as
    # ``model_dump``); the old ``class Config`` inner class is deprecated in
    # v2 in favour of ``model_config``. Unknown fields are silently ignored.
    model_config = ConfigDict(extra="ignore")


class StepRequest(BaseModel):
    """Request body for POST /step: the single action to execute."""
    action: Action  # validated against env.models.Action; FastAPI returns 422 on mismatch


class GraderRequest(BaseModel):
    """Request body for POST /grader (external grading of a finished episode)."""
    task_id: str
    episode_state: dict   # serialised state from a completed episode


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/api")
def root():
    """Machine-readable API index (interactive docs live at /docs)."""
    endpoint_map = {
        "health": "GET /health",
        "tasks": "GET /tasks",
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state",
    }
    return {
        "name": "Research Integrity Gym",
        "description": "OpenEnv environment for AI agents to evaluate scientific research integrity",
        "docs": "/docs",
        "endpoints": endpoint_map,
    }


@app.get("/health")
def health():
    """Liveness probe — returns 200 with a small status payload."""
    payload = {"status": "ok", "environment": "research-integrity-gym"}
    return payload


@app.post("/reset")
async def reset(request: Request):
    """Begin a new episode and return the initial Observation.

    Tolerant of the body being absent, the literal string ``null``, ``{}``,
    or a JSON object carrying ``task_id`` and/or ``seed`` — which is why the
    body is read and decoded by hand instead of via a Pydantic parameter.
    """
    global _env

    raw = await request.body()
    text = raw.decode("utf-8").strip() if raw else ""

    # Anything that is not a JSON object collapses to an empty payload.
    payload: dict = {}
    if text and text != "null":
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            decoded = None
        if isinstance(decoded, dict):
            payload = decoded

    task_id = payload.get("task_id", "task1_methodology_audit")
    seed = payload.get("seed")

    # A seed swaps in a freshly-seeded environment for this process.
    if seed is not None:
        _env = ResearchIntegrityEnv(seed=seed)

    try:
        return _env.reset(task_id=task_id).model_dump()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))


@app.post("/step")
def step(req: StepRequest):
    """Advance the episode by one action; returns observation, reward, done, info."""
    try:
        observation, reward, done, info = _env.step(req.action)
        response = {
            "observation": observation.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        }
        return response
    except RuntimeError as exc:
        # e.g. stepping before reset, or after the episode has terminated
        raise HTTPException(status_code=400, detail=str(exc))


@app.get("/state")
def state():
    """Expose the environment's current episode state (ground truth is excluded)."""
    current = _env.state()
    return current


@app.get("/tasks")
def tasks():
    """List every available task together with its action schema."""
    # Imported lazily so a broken task module doesn't prevent server startup.
    from tasks.task1_methodology_audit import MethodologyAuditTask
    from tasks.task2_replication import ReplicationTask
    from tasks.task3_claim_verify import ClaimVerifyTask
    from tasks.task4_citation_check import CitationCheckTask
    from tasks.task5_fda_approval import FDAApprovalTask

    task_classes = (
        MethodologyAuditTask,
        ReplicationTask,
        ClaimVerifyTask,
        CitationCheckTask,
        FDAApprovalTask,
    )
    return {"tasks": [cls().task_info() for cls in task_classes]}


@app.post("/grader")
def grader(req: GraderRequest):
    """
    Run the grader for a completed episode externally.

    Accepts a serialised episode state whose ``terminal_action`` and
    ``ground_truth`` entries are fed to the task-specific grader.
    Used by the judge's automated evaluation pipeline.

    Returns:
        ``{"task_id": ..., "grader_score": ...}``

    Raises:
        HTTPException 400: unknown ``task_id``.
        HTTPException 422: payload failed validation or grading errored out.
    """
    # Imported lazily so a broken grader module doesn't prevent server startup.
    from graders.grader1 import grade_audit
    from graders.grader2 import grade_results
    from graders.grader3 import grade_verdict
    from graders.grader4 import grade_citation_report
    from graders.grader5 import grade_fda_verdict
    from env.models import (
        SubmitAuditPayload, SubmitResultsPayload, SubmitVerdictPayload,
        SubmitCitationReportPayload, SubmitFDAVerdictPayload, FlawReport,
    )

    task_id      = req.task_id
    state_dict   = req.episode_state
    gt           = state_dict.get("ground_truth", {})
    terminal_act = state_dict.get("terminal_action", {})

    try:
        if task_id == "task1_methodology_audit":
            flaws   = [FlawReport(**f) for f in terminal_act.get("flaws", [])]
            payload = SubmitAuditPayload(flaws=flaws)
            score   = grade_audit(payload, gt)

        elif task_id == "task2_replication":
            payload = SubmitResultsPayload(**terminal_act)
            score   = grade_results(payload, gt)

        elif task_id == "task3_claim_verify":
            payload = SubmitVerdictPayload(**terminal_act)
            score   = grade_verdict(payload, gt)

        elif task_id == "task4_citation_check":
            payload = SubmitCitationReportPayload(**terminal_act)
            score   = grade_citation_report(payload, gt)

        elif task_id == "task5_fda_approval":
            payload = SubmitFDAVerdictPayload(**terminal_act)
            # The real grader needs an EpisodeState; build a minimal one from
            # the serialised fields it actually reads.
            from env.state import EpisodeState
            mock_state = EpisodeState(
                task_id=task_id,
                flags_raised=state_dict.get("flags_raised", []),
                code_calls=state_dict.get("code_calls", 0),
            )
            score = grade_fda_verdict(payload, gt, mock_state)

        else:
            raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")

        return {"task_id": task_id, "grader_score": score}

    except HTTPException:
        # BUG FIX: the 400 raised for an unknown task_id was previously
        # swallowed by the broad handler below and re-surfaced as a 422.
        raise
    except Exception as e:
        # Validation / grading failures map to 422 for the judge pipeline.
        raise HTTPException(status_code=422, detail=str(e))


@app.post("/baseline")
def baseline():
    """
    Trigger the baseline inference script and return its scores for all tasks.

    Requires ``HF_TOKEN`` in the environment (e.g. a Space secret).

    Raises:
        HTTPException 503: ``HF_TOKEN`` is not set.
        HTTPException 504: the baseline script exceeded the 300 s timeout.
        HTTPException 500: the baseline script exited non-zero.
    """
    import subprocess

    api_key = os.environ.get("HF_TOKEN", "")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="HF_TOKEN not set. Add it to Space secrets.",
        )

    try:
        result = subprocess.run(
            [sys.executable, "baseline.py", "--output-json"],
            capture_output=True, text=True, timeout=300,
            # Pass the token explicitly so the child always sees it.
            env={**os.environ, "HF_TOKEN": api_key},
        )
    except subprocess.TimeoutExpired:
        # BUG FIX: previously an expired timeout propagated as an unhandled
        # exception and surfaced as an opaque 500 with no detail.
        raise HTTPException(
            status_code=504,
            detail="Baseline script timed out after 300 seconds.",
        )

    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script failed:\n{result.stderr[:2000]}",
        )

    try:
        # Uses the module-level `json` import; the local re-import was redundant.
        return json.loads(result.stdout)
    except json.JSONDecodeError:
        # Stdout wasn't valid JSON — return it truncated for debugging.
        return {"raw_output": result.stdout[:3000]}


# ---------------------------------------------------------------------------
# Mount Gradio demo UI at root
# ---------------------------------------------------------------------------
import gradio as gr
# NOTE(review): imports `demo` from a module named "app" — presumably a
# sibling app.py defining the Gradio Blocks; confirm no name clash with
# this module's `app` variable at import time.
from app import demo as gradio_demo

# Rebinds `app` to the combined ASGI app: API routes stay registered,
# the Gradio UI is served at "/".
app = gr.mount_gradio_app(app, gradio_demo, path="/")