"""
inference.py: Baseline inference script for CodeReview-Env.
Emits the required [START] / [STEP] / [END] plain-text log format.
Environment variables:
API_BASE_URL LLM API base URL
MODEL_NAME model identifier
HF_TOKEN       API key (falls back to OPENAI_API_KEY)
SPACE_URL deployed HF Space URL (default: http://localhost:7860)
"""
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional
import httpx
from openai import OpenAI
# ── Config ────────────────────────────────────────────────────
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
API_KEY: str = os.environ.get("HF_TOKEN", os.environ.get("OPENAI_API_KEY", "dummy"))
SPACE_URL: str = os.environ.get("SPACE_URL", "http://localhost:7860").rstrip("/")
BENCHMARK = "CodeReview-Env"
MAX_TOKENS = 1024
SUCCESS_SCORE_THRESHOLD = 0.6
TASKS = ["easy_syntax", "medium_logic", "hard_security"]
# ── MANDATORY log format: plain text [START]/[STEP]/[END] ─────
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: Any, reward: float, done: bool, error: Optional[str] = None) -> None:
action_str = str(action)[:100].replace("\n", " ")
error_str = error if error else "null"
print(f"[STEP] step={step} action={action_str} reward={reward} done={done} error={error_str}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_str = str([round(r, 4) for r in rewards])
print(f"[END] success={success} steps={steps} score={score} rewards={rewards_str}", flush=True)
# ── HTTP client for OpenEnv server ───────────────────────────
class EnvClient:
def __init__(self, base_url: str) -> None:
self.base_url = base_url
self._http = httpx.Client(timeout=60.0)
def wait_until_ready(self, retries: int = 15, delay: float = 4.0) -> bool:
for i in range(retries):
try:
r = self._http.get(f"{self.base_url}/health")
if r.status_code == 200:
print(f"[DEBUG] Server ready", flush=True)
return True
except Exception as e:
print(f"[DEBUG] Waiting for server ({i+1}/{retries}): {e}", flush=True)
time.sleep(delay)
return False
def reset(self) -> Dict:
try:
r = self._http.post(f"{self.base_url}/reset")
r.raise_for_status()
return r.json()
except Exception as e:
print(f"[DEBUG] reset() failed: {e}", flush=True)
return {"observation": {}, "reward": 0.0, "done": False}
def step(self, action: Dict) -> Dict:
try:
# OpenEnv create_app requires: {"action": {...}}
r = self._http.post(f"{self.base_url}/step", json={"action": action})
r.raise_for_status()
return r.json()
except Exception as e:
print(f"[DEBUG] step() failed: {e}", flush=True)
return {"observation": {}, "reward": 0.0, "done": True}
def close(self) -> None:
try:
self._http.close()
except Exception:
pass
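# Minimal usage sketch for EnvClient (assumes an OpenEnv-compatible server
# exposing /health, /reset, and /step, as wired up in main() below):
#   client = EnvClient("http://localhost:7860")
#   if client.wait_until_ready():
#       obs = client.reset().get("observation", {})
#       result = client.step({"identified_issues": [], "submit": True})
#   client.close()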
# ── LLM agent ────────────────────────────────────────────────
SYSTEM_PROMPT = """\
You are an expert software engineer specialising in code review, debugging, \
and security auditing. Analyse the code and return ONLY a JSON object \
(no prose, no markdown fences):
{
"identified_issues": [
{
"line_number": <int or null>,
"issue_type": "<syntax_error|logic_bug|security_vulnerability>",
"description": "<clear description>",
"severity": "<low|medium|high|critical>"
}
],
"suggested_fix": "<complete corrected code as string, or null>",
"explanation": "<brief summary of all findings>",
"submit": true
}
"""
def build_prompt(obs: Dict, step: int, prev_feedback: Optional[str]) -> str:
parts = [
f"Task: {obs.get('task_name', '')} ({obs.get('difficulty', '')})",
f"Language: {obs.get('language', 'python')}",
f"Context: {obs.get('context', '')}",
"",
"Code to review:",
"```",
obs.get("code_snippet", ""),
"```",
f"(Step {step}/{obs.get('max_steps', 5)})",
]
if prev_feedback:
parts += ["", "Previous grader feedback:", prev_feedback]
return "\n".join(parts)
def call_llm(llm_client: OpenAI, prompt: str) -> str:
try:
resp = llm_client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
max_tokens=MAX_TOKENS,
temperature=0.2,
)
return (resp.choices[0].message.content or "{}").strip()
except Exception as e:
print(f"[DEBUG] LLM call failed: {e}", flush=True)
return json.dumps({
"identified_issues": [],
"suggested_fix": None,
"explanation": f"LLM unavailable: {e}",
"submit": True,
})
def parse_llm_output(raw: str) -> Dict:
raw = raw.strip()
if raw.startswith("```"):
parts = raw.split("```")
raw = parts[1] if len(parts) > 1 else raw
if raw.startswith("json"):
raw = raw[4:].strip()
try:
return json.loads(raw)
except json.JSONDecodeError:
return {
"identified_issues": [],
"suggested_fix": None,
"explanation": raw[:200],
"submit": True,
}
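# parse_llm_output tolerates common LLM wrappers, for example:
#   parse_llm_output('```json\n{"submit": true}\n```')   # -> {"submit": True}
#   parse_llm_output('Sorry, I cannot help with that.')  # -> fallback dict; "explanation" holds the raw text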
# ── Task runner ───────────────────────────────────────────────
def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
rewards: List[float] = []
steps_taken = 0
score = 0.0
success = False
prev_feedback: Optional[str] = None
try:
result = env.reset()
obs = result.get("observation", {})
max_steps = int(obs.get("max_steps", 5))
for step in range(1, max_steps + 1):
if result.get("done", False):
break
prompt = build_prompt(obs, step, prev_feedback)
raw = call_llm(llm, prompt)
action = parse_llm_output(raw)
result = env.step(action)
reward = float(result.get("reward", 0.0))
done = bool(result.get("done", False))
obs = result.get("observation", {})
prev_feedback = obs.get("previous_feedback")
rewards.append(reward)
steps_taken = step
log_step(
step=step,
                action=(action.get("explanation") or "")[:100],
reward=reward,
done=done,
error=None,
)
if done:
break
score = max(rewards) if rewards else 0.0
score = round(min(max(score, 0.0), 1.0), 4)
success = score >= SUCCESS_SCORE_THRESHOLD
except Exception as e:
print(f"[DEBUG] run_task error: {e}", flush=True)
log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
score = 0.0
success = False
finally:
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
return score
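# Worked scoring example (illustrative): rewards [0.2, 0.7, 0.5] give
# score = 0.7 (the best single step, clamped to [0, 1] and rounded),
# which clears the 0.6 SUCCESS_SCORE_THRESHOLD, so success=True.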
# ── Main ──────────────────────────────────────────────────────
def main() -> None:
print(f"[DEBUG] Starting. SPACE_URL={SPACE_URL} MODEL={MODEL_NAME}", flush=True)
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
env = EnvClient(SPACE_URL)
if not env.wait_until_ready():
print("[ERROR] Server not reachable. Exiting.", flush=True)
for task_id in TASKS:
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
log_end(success=False, steps=0, score=0.0, rewards=[])
sys.exit(1)
task_scores: Dict[str, float] = {}
for task_id in TASKS:
print(f"\n[DEBUG] ===== Running task: {task_id} =====", flush=True)
try:
task_scores[task_id] = run_task(task_id, env, llm)
except Exception as e:
print(f"[DEBUG] Task {task_id} crashed: {e}", flush=True)
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
log_end(success=False, steps=0, score=0.0, rewards=[])
task_scores[task_id] = 0.0
time.sleep(1)
env.close()
overall = sum(task_scores.values()) / len(task_scores)
print(f"\n[DEBUG] Overall average: {overall:.4f}", flush=True)
if __name__ == "__main__":
main()