Spaces:
Sleeping
Sleeping
| """SecureCodeEnv - Routes v2 (production-complete)""" | |
| from fastapi import APIRouter, HTTPException, Query | |
| from typing import List, Optional | |
| from app.models import ( | |
| StepAction, StepObservation, ScoreDetails, | |
| ResetRequest, ResetObservation, | |
| StateResponse, TaskSummary, | |
| ) | |
| from app.state import EpisodeState | |
| from graders.reward_aggregator import grade_submission | |
| from tasks.task_registry import sample_task, get_task, TASK_REGISTRY, list_tasks | |
| from codegraph.graph import CodeGraph | |
| from codegraph.serializer import serialize_graph | |
| import uuid, threading | |
# Router instance for the endpoint handlers defined in this module.
# NOTE(review): no @router decorators are visible on the handlers in this
# view — confirm how they get registered.
router = APIRouter()
# Live episodes keyed by session id. The handlers in this module only read or
# write the dict itself while holding _lock.
_sessions: dict[str, EpisodeState] = {}
_lock = threading.Lock()
# step() marks an episode done once this many submissions have been graded.
MAX_STEPS = 5
# NOTE(review): not referenced anywhere in the visible code — presumably a
# reward cutoff for early termination; confirm or remove.
DONE_THRESHOLD = 0.90
def _cleanup():
    """Remove every stored session whose state reports itself as expired."""
    with _lock:
        # Collect the ids first so the dict is never mutated mid-iteration.
        stale_ids = []
        for session_id, episode in _sessions.items():
            if episode.is_expired():
                stale_ids.append(session_id)
        for session_id in stale_ids:
            del _sessions[session_id]
# ── POST /reset ──────────────────────────────────────────────────────────────
def reset(body: Optional[ResetRequest] = None):
    """Start a new episode and return the task plus its initial CodeGraph.

    Args:
        body: Optional request payload. A truthy ``task_id`` selects that
            task directly; otherwise a task is sampled for ``difficulty``
            (``"medium"`` when unset). A truthy ``session_id`` is used as the
            session key instead of a freshly generated UUID.

    Returns:
        A ResetObservation carrying the session id, the task description and
        the serialized graph.

    Raises:
        HTTPException: 404 when ``task_id`` is not known, 400 when
            ``difficulty`` is not one of easy/medium/hard.
    """
    import hashlib  # local import: only needed for the seed derivation below

    _cleanup()
    if body is None:
        body = ResetRequest()

    if body.task_id:
        # An explicit task_id takes precedence over any requested difficulty.
        try:
            task = get_task(body.task_id)
        except KeyError:
            raise HTTPException(
                404,
                f"task_id {body.task_id!r} not found. "
                f"Available: {list(TASK_REGISTRY.keys())}",
            ) from None
        difficulty = task["difficulty"]
    else:
        difficulty = (body.difficulty or "medium").lower()
        if difficulty not in ("easy", "medium", "hard"):
            raise HTTPException(
                400, f"difficulty must be easy/medium/hard. Got: {difficulty!r}"
            )
        task = sample_task(difficulty)

    sid = body.session_id or str(uuid.uuid4())

    # The built-in hash() of a str is salted per interpreter process, so it
    # would map the same session id to a different seed after a restart or on
    # another worker. A digest keeps the seed stable for a given session id.
    digest = hashlib.sha256(str(sid).encode("utf-8")).digest()
    episode_seed = int.from_bytes(digest[:8], "big") % 999_999

    graph = CodeGraph(episode_seed=episode_seed)
    state = EpisodeState(task=task, graph=graph, step=0, done=False)
    with _lock:
        # Re-using an existing session id replaces that session's state.
        _sessions[sid] = state

    return ResetObservation(
        session_id=sid,
        task_id=task["id"],
        problem_statement=task["problem_statement"],
        difficulty=difficulty,
        cwe_targets=task["cwe_targets"],
        codegraph=serialize_graph(graph),
        starter_code=task.get("starter_code", ""),
        naive_baseline={"code": task.get("naive_code", "")},
    )
# ── POST /step ───────────────────────────────────────────────────────────────
def step(action: StepAction):
    """Grade one code submission and advance the episode by a step.

    Args:
        action: Submission carrying ``session_id``, ``code`` and an optional
            ``filename`` (``"solution.py"`` is used when it is empty).

    Returns:
        A StepObservation with the per-dimension scores, total reward,
        feedback, the serialized graph after the update, the done flag, the
        new step count and a structured ScoreDetails object.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when the episode
            is already done, 422 when ``code`` is empty or whitespace only.
    """
    with _lock:
        state = _sessions.get(action.session_id)
    if state is None:
        raise HTTPException(404, "Session not found — call POST /reset first.")
    if state.done:
        raise HTTPException(400, "Episode done — call POST /reset to start a new one.")
    if not action.code or not action.code.strip():
        raise HTTPException(422, "code must be a non-empty Python string.")

    # Resolve the default once so grading and the graph update always agree.
    filename = action.filename or "solution.py"

    result = grade_submission(
        code=action.code,
        filename=filename,
        task=state.task,
        graph=state.graph,
        step=state.step,
        # Offset the episode seed by the step index so each step differs.
        seed=state.graph.episode_seed + state.step,
    )

    # NOTE(review): the episode state is mutated after _lock has been
    # released; concurrent /step calls for one session could interleave here.
    state.graph.update(filename, result["new_metadata"])
    state.step += 1
    state.scores_history.append(result["total_reward"])
    state.done = result.get("done_eligible", False) or state.step >= MAX_STEPS

    # Build the structured details object. `or {}` / `or []` also covers
    # sections that are present but None, which `.get(key, {})` does not.
    raw = result.get("details") or {}
    correctness = raw.get("correctness") or {}
    attacks = raw.get("attacks") or {}
    static = raw.get("static") or {}
    details = ScoreDetails(
        correctness_passed=correctness.get("passed"),
        correctness_total=correctness.get("total"),
        attacks_blocked=attacks.get("blocked"),
        attacks_total=attacks.get("total"),
        attack_type=attacks.get("type"),
        bandit_score=static.get("bandit_score"),
        static_issues_count=len(static.get("issues") or []),
        agent_ms=result.get("agent_ms"),
        naive_ms=result.get("naive_ms"),
        optimal_ms=result.get("optimal_ms"),
    )

    return StepObservation(
        scores=result["scores"],
        total_reward=result["total_reward"],
        feedback=result["feedback"],
        codegraph=serialize_graph(state.graph),
        done=state.done,
        step_count=state.step,
        details=details,
    )
# ── GET /state ───────────────────────────────────────────────────────────────
def get_state(session_id: str):
    """Return a snapshot of a session's episode without advancing it.

    Raises:
        HTTPException: 404 when no session is stored under ``session_id``.
    """
    with _lock:
        episode = _sessions.get(session_id)
    if episode is None:
        raise HTTPException(404, "Session not found.")

    snapshot = {
        "session_id": session_id,
        "task_id": episode.task["id"],
        "step": episode.step,
        "done": episode.done,
        "codegraph": serialize_graph(episode.graph),
        "difficulty": episode.task.get("difficulty", "medium"),
        "scores_history": episode.scores_history,
    }
    return StateResponse(**snapshot)
# ── GET /tasks ───────────────────────────────────────────────────────────────
def get_tasks(difficulty: Optional[str] = Query(None)):
    """List the available tasks, optionally filtered by difficulty.

    Args:
        difficulty: Filter value handed to ``list_tasks``; ``None`` applies
            no filter.

    Returns:
        A list of TaskSummary objects. ``description`` is the task's problem
        statement cut to 120 characters, with an ellipsis appended only when
        text was actually cut off.
    """
    max_description_len = 120
    summaries = []
    for entry in list_tasks(difficulty):
        statement = TASK_REGISTRY[entry["id"]].get("problem_statement") or ""
        # Mark truncation only when it happened; a short statement is shown
        # as-is instead of getting a misleading trailing ellipsis.
        if len(statement) > max_description_len:
            statement = statement[:max_description_len] + "…"
        summaries.append(
            TaskSummary(
                id=entry["id"],
                difficulty=entry["difficulty"],
                cwe_targets=entry["cwe_targets"],
                description=statement,
            )
        )
    return summaries
# ── GET /tasks/{task_id} ─────────────────────────────────────────────────────
def get_task_detail(task_id: str):
    """Return the detail dictionary for a single task.

    Raises:
        HTTPException: 404 when ``get_task`` does not know ``task_id``.
    """
    try:
        task = get_task(task_id)
    except KeyError:
        raise HTTPException(404, f"Task {task_id!r} not found.")

    # Required fields are copied verbatim, in the order they are returned.
    required_keys = ("id", "difficulty", "cwe_targets", "problem_statement")
    detail = {key: task[key] for key in required_keys}

    # Optional fields fall back to a neutral default when the task omits them.
    detail["starter_code"] = task.get("starter_code", "")
    detail["attack_type"] = task.get("attack_type", "none")
    detail["security_checks"] = task.get("security_checks", [])
    return detail