"""
OpenEnv baseline inference script.
Environment variables (set by judge's evaluation system):
API_BASE_URL - LLM API endpoint (default: HF router)
MODEL_NAME - Model identifier (default: Llama 3.3 70B)
HF_TOKEN - API token (optional, tries multiple env vars)
Usage:
python inference.py # human-readable scores
python inference.py --output-json # JSON output
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import textwrap
from openai import OpenAI
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from env.environment import ResearchIntegrityEnv
from env.models import (
Action, ActionType,
FlawReport, SubmitAuditPayload,
SubmitResultsPayload, SubmitVerdictPayload, SubmitCitationReportPayload,
Verdict,
)
# ---------------------------------------------------------------------------
# Environment variables with defaults (per OpenEnv spec)
# ---------------------------------------------------------------------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
# Try multiple possible API key environment variables
API_KEY = (
os.getenv("HF_TOKEN") or
os.getenv("API_KEY") or
os.getenv("GROQ_API_KEY") or
os.getenv("OPENAI_API_KEY")
)
MAX_STEPS = 15
SEED = 42
# ---------------------------------------------------------------------------
# System prompts per task
# ---------------------------------------------------------------------------
SYSTEM_PROMPTS = {
"task1_methodology_audit": textwrap.dedent("""
You are a scientific peer reviewer. Read the paper stub and identify
methodological flaws. There are exactly 4 planted flaws.
Available actions (respond with JSON only, no prose):
{"action_type": "read_section", "section": "<name>"}
{"action_type": "flag_flaw", "flaw_type": "<type>", "location": "<section>", "description": "<text>"}
{"action_type": "submit_audit", "audit_payload": {"flaws": [{"flaw_type":"...","location":"...","description":"..."}]}}
Flaw types: wrong_statistical_test, underpowered_sample,
undisclosed_exclusion, p_value_manipulation
Submit when you have all 4 flaws. Respond ONLY with valid JSON.
""").strip(),
"task2_replication": textwrap.dedent("""
You are a data scientist replicating an ML experiment.
Follow the methods section exactly and report AUC-ROC and F1-score.
Available actions (respond with JSON only, no prose):
{"action_type": "read_dataset"}
{"action_type": "execute_code", "code": "<python code>"}
{"action_type": "submit_results", "results_payload": {"auc": 0.0, "f1": 0.0, "interpretation": "..."}}
Use stratified train_test_split, StandardScaler, LogisticRegression
with class_weight='balanced', random_state=42. Dataset path: DATASET_PATH.
Respond ONLY with valid JSON.
""").strip(),
"task3_claim_verify": textwrap.dedent("""
You are a statistical auditor verifying a paper's claim using raw data.
Available actions (respond with JSON only, no prose):
{"action_type": "read_dataset"}
{"action_type": "execute_code", "code": "<python code>"}
{"action_type": "flag_concern", "concern_type": "<type>", "evidence": "<text>"}
{"action_type": "submit_verdict", "verdict_payload": {
"verdict": "valid|partially_valid|invalid",
"effect_size": 0.0,
"p_value": 0.0,
"justification": "<text with at least 50 words explaining your reasoning>"
}}
Run your own t-test. Check if claimed n matches dataset rows.
Look for undisclosed exclusions. Respond ONLY with valid JSON.
""").strip(),
"task4_citation_check": textwrap.dedent("""
        You are a citation integrity checker. A paper cites 3 sources, but exactly
        ONE citation is fabricated: the claim in the paper does not match what the
        cited source says.
STRATEGY:
1. Read the paper carefully - note what each citation claims
2. Compare each claim to the citation excerpts provided
3. Find the ONE citation where the claim contradicts the excerpt
4. Submit your report identifying the fabricated citation
Available actions (respond with JSON only):
{"action_type": "read_section", "section": "citations"}
{"action_type": "submit_report", "report_payload": {
"fabricated_citation_id": 2,
"fabrication_type": "directional - paper claims X increased but source says decreased",
"verified_correct_citations": [1, 3],
"evidence": "Quote the specific text showing the mismatch"
}}
Fabrication types to look for:
- directional: paper says "increased" but source says "decreased" (or vice versa)
- magnitude: paper claims 25% but source shows 2.5% (wrong numbers)
- population: paper generalizes to children but source studied adults only
- significance: paper claims p<0.05 but source shows p>0.05
- absent: paper claims a finding the source never mentions
Respond ONLY with valid JSON. No prose.
""").strip(),
}
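# ---------------------------------------------------------------------------
# Illustrative sketches (never called by the inference loop): roughly the code
# the task2 and task3 prompts ask the agent to write. Both assume pandas,
# scikit-learn, and scipy are available in the execution sandbox and assume
# hypothetical column names ("target", "group", "outcome"); the real schema
# comes from the environment's read_dataset action.
# ---------------------------------------------------------------------------
def _example_replication_code(dataset_path: str) -> tuple[float, float]:
    """Hypothetical task2-style pipeline: returns (AUC-ROC, F1) per the prompt."""
    import pandas as pd
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import f1_score, roc_auc_score
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    df = pd.read_csv(dataset_path)
    X, y = df.drop(columns=["target"]), df["target"]
    # Stratified split preserves the class ratio in both partitions.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    model = LogisticRegression(class_weight="balanced", random_state=42)
    model.fit(X_train, y_train)
    auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
    f1 = f1_score(y_test, model.predict(X_test))
    return auc, f1


def _example_claim_check_code(dataset_path: str) -> tuple[float, float]:
    """Hypothetical task3-style check: Welch t-test plus a Cohen's d estimate."""
    import pandas as pd
    from scipy import stats

    df = pd.read_csv(dataset_path)
    levels = df["group"].unique()
    a = df.loc[df["group"] == levels[0], "outcome"]
    b = df.loc[df["group"] == levels[1], "outcome"]
    # Welch's t-test does not assume equal variances between groups.
    t_stat, p_value = stats.ttest_ind(a, b, equal_var=False)
    # Cohen's d with a simple two-group pooled SD as the effect-size estimate.
    pooled_sd = ((a.var(ddof=1) + b.var(ddof=1)) / 2) ** 0.5
    effect_size = (a.mean() - b.mean()) / pooled_sd
    return effect_size, p_value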
# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------
def run_task(client: OpenAI, task_id: str, env: ResearchIntegrityEnv) -> dict:
"""Run a single task and return the result dict."""
obs = env.reset(task_id=task_id)
# Structured logging: [START]
print(f"[START] task={task_id}", flush=True)
messages = [
{"role": "system", "content": SYSTEM_PROMPTS[task_id]},
{"role": "user", "content": f"PAPER:\n{obs.paper_text}\n\nBegin your analysis."},
]
final_reward = 0.0
grader_score = 0.0
steps_taken = 0
for step_num in range(MAX_STEPS):
try:
response = client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
temperature=0.0,
max_tokens=800,
)
raw = response.choices[0].message.content.strip()
        except Exception as e:
            print(f"LLM call failed: {e}", file=sys.stderr)
            # Fall back to an empty submission so the episode can still terminate.
            # Note: this is the task1 action shape; other tasks may reject it.
            raw = '{"action_type": "submit_audit", "audit_payload": {"flaws": []}}'
messages.append({"role": "assistant", "content": raw})
action = _parse_action(raw, task_id)
        if action is None:
            # Malformed JSON: ask the model to retry. This consumes a loop
            # iteration (bounded by MAX_STEPS) but not an environment step.
            messages.append({
                "role": "user",
                "content": "Invalid JSON. Respond with a valid action JSON only.",
            })
            continue
obs, reward, done, info = env.step(action)
steps_taken += 1
final_reward = reward.total
if reward.grader_score is not None:
grader_score = reward.grader_score
# Structured logging: [STEP]
print(f"[STEP] task={task_id} step={steps_taken} reward={reward.step_reward:.4f}", flush=True)
if done:
break
# Feed observation back to agent
feedback_parts = [f"Step {step_num + 1} result:"]
if obs.code_result:
feedback_parts.append(f"Code output:\n{obs.code_result[:1000]}")
if obs.flags_raised:
feedback_parts.append(f"Flags raised so far: {obs.flags_raised}")
feedback_parts.append(f"Step reward: {reward.step_reward}")
feedback_parts.append(f"Steps remaining: {MAX_STEPS - step_num - 1}")
feedback_parts.append("Continue your analysis or submit when ready.")
messages.append({"role": "user", "content": "\n".join(feedback_parts)})
# Structured logging: [END]
print(f"[END] task={task_id} score={grader_score:.4f} steps={steps_taken}", flush=True)
return {
"task_id": task_id,
"grader_score": round(grader_score, 4),
"final_reward": round(final_reward, 4),
"steps_taken": steps_taken,
}
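# Example structured log stream for one task (reward/score values illustrative;
# the line formats match the [START]/[STEP]/[END] prints above):
#   [START] task=task1_methodology_audit
#   [STEP] task=task1_methodology_audit step=1 reward=0.1000
#   [STEP] task=task1_methodology_audit step=2 reward=0.2500
#   [END] task=task1_methodology_audit score=0.7500 steps=2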
def _parse_action(raw: str, task_id: str) -> Action | None:
"""Parse agent JSON into an Action object. Returns None on failure."""
    try:
        text = raw.strip()
        # Strip a Markdown code fence (``` or ```json) if the model added one;
        # split("```")[1] keeps the content between the first pair of fences.
        if text.startswith("```"):
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]
        data = json.loads(text.strip())
atype = data.get("action_type")
if atype == "read_section":
return Action(action_type=ActionType.read_section, section=data.get("section"))
if atype == "read_dataset":
return Action(action_type=ActionType.read_dataset)
if atype == "execute_code":
return Action(action_type=ActionType.execute_code, code=data.get("code", ""))
if atype == "flag_flaw":
return Action(
action_type=ActionType.flag_flaw,
flaw_type=data.get("flaw_type", ""),
location=data.get("location", ""),
description=data.get("description", ""),
)
if atype == "flag_concern":
return Action(
action_type=ActionType.flag_concern,
concern_type=data.get("concern_type", ""),
evidence=data.get("evidence", ""),
)
if atype == "check_citation":
return Action(
action_type=ActionType.check_citation,
citation_id=int(data.get("citation_id", 0)),
)
if atype == "flag_fabrication":
return Action(
action_type=ActionType.flag_fabrication,
citation_id=int(data.get("citation_id", 0)),
flaw_type=data.get("fabrication_type", ""),
description=data.get("evidence", ""),
)
if atype == "submit_audit":
ap = data.get("audit_payload", {})
flaws = [FlawReport(**f) for f in ap.get("flaws", [])]
return Action(
action_type=ActionType.submit_audit,
audit_payload=SubmitAuditPayload(flaws=flaws),
)
if atype == "submit_results":
rp = data.get("results_payload", {})
return Action(
action_type=ActionType.submit_results,
results_payload=SubmitResultsPayload(
auc=float(rp.get("auc", 0)),
f1=float(rp.get("f1", 0)),
interpretation=rp.get("interpretation", ""),
),
)
if atype == "submit_verdict":
vp = data.get("verdict_payload", {})
return Action(
action_type=ActionType.submit_verdict,
verdict_payload=SubmitVerdictPayload(
verdict=Verdict(vp.get("verdict", "invalid")),
effect_size=float(vp.get("effect_size", 0)),
p_value=float(vp.get("p_value", 0.5)),
justification=vp.get("justification", ""),
),
)
if atype == "submit_report":
rp = data.get("report_payload", {})
return Action(
action_type=ActionType.submit_report,
report_payload=SubmitCitationReportPayload(
fabricated_citation_id=rp.get("fabricated_citation_id"),
fabrication_type=rp.get("fabrication_type", ""),
verified_correct_citations=rp.get("verified_correct_citations", []),
evidence=rp.get("evidence", ""),
),
)
except Exception:
pass
return None
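# _parse_action is deliberately forgiving: a bare JSON object and a fenced one
# parse to the same Action. For example (illustrative), both
#   '{"action_type": "read_dataset"}'
#   '```json\n{"action_type": "read_dataset"}\n```'
# yield Action(action_type=ActionType.read_dataset); anything else returns None.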
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="Research Integrity Gym Inference Script")
parser.add_argument("--output-json", action="store_true",
help="Print JSON output (for automated evaluation)")
args = parser.parse_args()
if not args.output_json:
print(f"Using API: {API_BASE_URL}")
print(f"Model: {MODEL_NAME}")
print(f"API Key: {'Set' if API_KEY else 'Missing'}")
# Validate API key
if not API_KEY:
print("ERROR: No API key found. Set HF_TOKEN, API_KEY, GROQ_API_KEY, or OPENAI_API_KEY.", file=sys.stderr)
sys.exit(1)
# Create OpenAI client with error handling
# Use minimal parameters to avoid version/environment conflicts
try:
import httpx
        # Short timeouts so a bad endpoint fails fast instead of hanging.
        http_client = httpx.Client(
            timeout=httpx.Timeout(10.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
        )
client = OpenAI(
api_key=API_KEY,
base_url=API_BASE_URL,
http_client=http_client,
max_retries=0, # No retries for faster failure
)
except ImportError:
# Fallback if httpx import fails
try:
client = OpenAI(
api_key=API_KEY,
base_url=API_BASE_URL,
max_retries=0,
timeout=10.0,
)
except Exception as e:
print(f"ERROR: Failed to create OpenAI client: {e}", file=sys.stderr)
sys.exit(1)
except Exception as e:
# Final fallback - try absolute minimal client
try:
client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL, timeout=10.0, max_retries=0)
except Exception as e2:
print(f"ERROR: Failed to create OpenAI client: {e2}", file=sys.stderr)
print(f"Original error: {e}", file=sys.stderr)
sys.exit(1)
env = ResearchIntegrityEnv(seed=SEED)
task_ids = [
"task1_methodology_audit",
"task2_replication",
"task3_claim_verify",
"task4_citation_check",
]
results = []
for task_id in task_ids:
result = run_task(client, task_id, env)
results.append(result)
avg = round(sum(r["grader_score"] for r in results) / len(results), 4)
output = {
"model": MODEL_NAME,
"seed": SEED,
"tasks": results,
"avg_grader_score": avg,
}
if args.output_json:
print(json.dumps(output))
else:
print(f"\nAverage grader score: {avg:.4f}")
if __name__ == "__main__":
main()
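# Example --output-json payload (all scores illustrative; the keys match the
# `output` dict and run_task's return value above):
#   {"model": "meta-llama/Llama-3.3-70B-Instruct", "seed": 42,
#    "tasks": [{"task_id": "task1_methodology_audit", "grader_score": 0.75,
#               "final_reward": 0.75, "steps_taken": 6}, ...],
#    "avg_grader_score": 0.62}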