SecureCodeEnv / app /routes.py
vishaldhakad's picture
changing logics for the result
9aed967
"""SecureCodeEnv - Routes v2 (production-complete)"""
from fastapi import APIRouter, HTTPException, Query
from typing import List, Optional
from app.models import (
StepAction, StepObservation, ScoreDetails,
ResetRequest, ResetObservation,
StateResponse, TaskSummary,
)
from app.state import EpisodeState
from graders.reward_aggregator import grade_submission
from tasks.task_registry import sample_task, get_task, TASK_REGISTRY, list_tasks
from codegraph.graph import CodeGraph
from codegraph.serializer import serialize_graph
import uuid, threading
# Router on which every endpoint in this module is registered.
router = APIRouter()
# In-memory episode store keyed by session id. Entries are written by
# POST /reset and removed by _cleanup() once is_expired() reports true.
_sessions: dict[str, EpisodeState] = {}
# Held around every read/write of the _sessions dict itself.
_lock = threading.Lock()
# POST /step marks an episode as done once its step counter reaches this value.
MAX_STEPS = 5
# NOTE(review): not referenced by any route visible in this module; presumably
# the reward level at which an episode counts as solved — confirm against the grader.
DONE_THRESHOLD = 0.90
def _cleanup() -> None:
    """Remove every session whose state reports itself as expired.

    The expired ids are collected first and deleted afterwards, so the
    session dict is never resized while it is being iterated.
    """
    with _lock:
        stale_ids = []
        for session_id, episode in _sessions.items():
            if episode.is_expired():
                stale_ids.append(session_id)
        for session_id in stale_ids:
            del _sessions[session_id]
# ── POST /reset ──────────────────────────────────────────────────────────────
@router.post("/reset", response_model=ResetObservation, tags=["OpenEnv"])
def reset(body: Optional[ResetRequest] = None):
    """Start a new episode and return the task plus the initial CodeGraph.

    Args:
        body: Optional request payload. When omitted, a default
            ``ResetRequest()`` is used. If ``task_id`` is set, that task is
            loaded and its own difficulty is reported; otherwise a task is
            sampled for ``difficulty`` (defaulting to ``"medium"``). If
            ``session_id`` is set it is reused as the session key, otherwise
            a fresh UUID4 string is generated.

    Returns:
        A ``ResetObservation`` describing the session, the chosen task and
        the serialized initial graph.

    Raises:
        HTTPException: 404 when ``task_id`` is unknown; 400 when
            ``difficulty`` is not one of easy/medium/hard.
    """
    # Local import keeps this block self-contained (hashlib is not imported
    # at module level).
    import hashlib

    _cleanup()
    if body is None:
        body = ResetRequest()

    # An explicit task_id takes precedence over the difficulty setting.
    if body.task_id:
        try:
            task = get_task(body.task_id)
        except KeyError:
            raise HTTPException(404, f"task_id {body.task_id!r} not found. "
                                     f"Available: {list(TASK_REGISTRY.keys())}") from None
        difficulty = task["difficulty"]
    else:
        difficulty = (body.difficulty or "medium").lower()
        if difficulty not in ("easy", "medium", "hard"):
            raise HTTPException(400, f"difficulty must be easy/medium/hard. Got: {difficulty!r}")
        task = sample_task(difficulty)

    sid = body.session_id or str(uuid.uuid4())

    # Derive the episode seed from a stable digest of the session id.
    # The built-in hash() of a str is salted per interpreter process, so it
    # would yield a different seed for the same session id after a restart
    # or in another worker. The value range (0..999_998) is unchanged.
    digest = hashlib.sha256(str(sid).encode("utf-8")).digest()
    episode_seed = int.from_bytes(digest[:8], "big") % 999_999

    graph = CodeGraph(episode_seed=episode_seed)
    state = EpisodeState(task=task, graph=graph, step=0, done=False)
    with _lock:
        _sessions[sid] = state

    return ResetObservation(
        session_id=sid,
        task_id=task["id"],
        problem_statement=task["problem_statement"],
        difficulty=difficulty,
        cwe_targets=task["cwe_targets"],
        codegraph=serialize_graph(graph),
        starter_code=task.get("starter_code", ""),
        naive_baseline={"code": task.get("naive_code", "")},
    )
# ── POST /step ───────────────────────────────────────────────────────────────
@router.post("/step", response_model=StepObservation, tags=["OpenEnv"])
def step(action: StepAction):
    """Grade one code submission and advance the episode by one step.

    Args:
        action: Submission carrying ``session_id``, ``code`` and an optional
            ``filename`` (defaults to ``"solution.py"``).

    Returns:
        A ``StepObservation`` with the per-dimension scores, total reward,
        feedback, the updated serialized graph, the done flag, the new step
        count and a structured ``ScoreDetails`` breakdown.

    Raises:
        HTTPException: 404 when the session is unknown; 400 when the episode
            is already done; 422 when ``code`` is empty or whitespace-only.
    """
    with _lock:
        state = _sessions.get(action.session_id)
    if state is None:
        raise HTTPException(404, "Session not found — call POST /reset first.")
    if state.done:
        raise HTTPException(400, "Episode done — call POST /reset to start a new one.")
    if not action.code or not action.code.strip():
        raise HTTPException(422, "code must be a non-empty Python string.")

    # Resolve the default once so grading and the graph update always agree.
    filename = action.filename or "solution.py"

    result = grade_submission(
        code=action.code,
        filename=filename,
        task=state.task,
        graph=state.graph,
        step=state.step,
        # Offset by the step so each submission is graded with its own seed.
        seed=state.graph.episode_seed + state.step,
    )

    state.graph.update(filename, result["new_metadata"])
    state.step += 1
    state.scores_history.append(result["total_reward"])
    # The episode ends when the grader flags it as eligible or the step
    # budget is exhausted; bool() keeps the flag a real boolean.
    state.done = bool(result.get("done_eligible", False) or state.step >= MAX_STEPS)

    # Build the structured details object. Each section is guarded against
    # being missing *or* None, matching how the top-level "details" is handled.
    raw = result.get("details") or {}
    correctness = raw.get("correctness") or {}
    attacks = raw.get("attacks") or {}
    static = raw.get("static") or {}
    details = ScoreDetails(
        correctness_passed=correctness.get("passed"),
        correctness_total=correctness.get("total"),
        attacks_blocked=attacks.get("blocked"),
        attacks_total=attacks.get("total"),
        attack_type=attacks.get("type"),
        bandit_score=static.get("bandit_score"),
        static_issues_count=len(static.get("issues") or []),
        agent_ms=result.get("agent_ms"),
        naive_ms=result.get("naive_ms"),
        optimal_ms=result.get("optimal_ms"),
    )

    return StepObservation(
        scores=result["scores"],
        total_reward=result["total_reward"],
        feedback=result["feedback"],
        codegraph=serialize_graph(state.graph),
        done=state.done,
        step_count=state.step,
        details=details,
    )
# ── GET /state ───────────────────────────────────────────────────────────────
@router.get("/state", response_model=StateResponse, tags=["OpenEnv"])
def get_state(session_id: str):
    """Report the current state of an episode without advancing it.

    Args:
        session_id: Key of the session to look up.

    Returns:
        A ``StateResponse`` with the task id, step counter, done flag,
        serialized graph, difficulty and score history of the session.

    Raises:
        HTTPException: 404 when no session exists under ``session_id``.
    """
    with _lock:
        episode = _sessions.get(session_id)

    if episode is None:
        raise HTTPException(404, "Session not found.")

    # Keyword order mirrors the response fields; dict order is preserved
    # when the mapping is unpacked below.
    payload = {
        "session_id": session_id,
        "task_id": episode.task["id"],
        "step": episode.step,
        "done": episode.done,
        "codegraph": serialize_graph(episode.graph),
        "difficulty": episode.task.get("difficulty", "medium"),
        "scores_history": episode.scores_history,
    }
    return StateResponse(**payload)
# ── GET /tasks ───────────────────────────────────────────────────────────────
@router.get("/tasks", response_model=List[TaskSummary], tags=["Discovery"])
def get_tasks(difficulty: Optional[str] = Query(None)):
    """List all available tasks, optionally filtered by difficulty.

    Args:
        difficulty: Optional filter that is passed through to ``list_tasks``.

    Returns:
        One ``TaskSummary`` per task. ``description`` is the task's problem
        statement cut to at most 120 characters; an ellipsis is appended
        only when the statement was actually shortened.
    """
    max_description_len = 120
    summaries = []
    for entry in list_tasks(difficulty):
        statement = TASK_REGISTRY[entry["id"]].get("problem_statement", "") or ""
        # Only mark truncation when text was really dropped, so short or
        # empty statements are not rendered with a misleading trailing "…".
        if len(statement) > max_description_len:
            description = statement[:max_description_len] + "…"
        else:
            description = statement
        summaries.append(
            TaskSummary(
                id=entry["id"],
                difficulty=entry["difficulty"],
                cwe_targets=entry["cwe_targets"],
                description=description,
            )
        )
    return summaries
# ── GET /tasks/{task_id} ─────────────────────────────────────────────────────
@router.get("/tasks/{task_id}", tags=["Discovery"])
def get_task_detail(task_id: str):
    """Return the full description of a single task.

    Args:
        task_id: Identifier passed to ``get_task``.

    Returns:
        A dict with the task's id, difficulty, CWE targets and problem
        statement, plus starter code, attack type and security checks
        (falling back to ``""``, ``"none"`` and ``[]`` when absent).

    Raises:
        HTTPException: 404 when ``get_task`` raises ``KeyError``.
    """
    try:
        task = get_task(task_id)
    except KeyError:
        raise HTTPException(404, f"Task {task_id!r} not found.")

    # Mandatory fields are copied as-is; a missing one raises KeyError.
    detail = {
        field: task[field]
        for field in ("id", "difficulty", "cwe_targets", "problem_statement")
    }
    # Optional fields fall back to neutral defaults.
    detail["starter_code"] = task.get("starter_code", "")
    detail["attack_type"] = task.get("attack_type", "none")
    detail["security_checks"] = task.get("security_checks", [])
    return detail