Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """IRT (Incident Response Triage) API endpoints. | |
| Extracted from app.py - handles /reset, /step, /state, /tasks, /grader, /baseline. | |
| """ | |
| from __future__ import annotations | |
| import secrets | |
| import traceback | |
| from typing import Any, Dict | |
| from fastapi import APIRouter, Body, Header, HTTPException | |
| from pydantic import BaseModel | |
| from src.environment import IncidentResponseEnv | |
| from src.models import Action, StepResult | |
| from src.tasks import get_all_tasks | |
| from routers.deps import ( | |
| _SESSION_REGISTRY, | |
| _TELEMETRY, | |
| _log, | |
| get_or_create_session, | |
| record_leaderboard, | |
| ) | |
| router = APIRouter() | |
| # --------------------------------------------------------------------------- | |
| # Request / response schemas | |
| # --------------------------------------------------------------------------- | |
class ResetRequest(BaseModel):
    """Optional request body accepted by the ``reset`` handler."""

    # Task to load; passed straight to ``env.reset``.
    task_id: str = "severity_classification"
    # Existing session to reuse; forwarded to ``get_or_create_session``.
    session_id: str | None = None
    # Scenario variant; when None, ``reset`` draws one with ``secrets.randbelow(100)``.
    variant_seed: int | None = None
class BaselineResponse(BaseModel):
    """Payload built by the ``baseline`` handler before it is dumped to a dict."""

    # Per-task results as returned by ``run_all_tasks``; each entry is read
    # with ``r["score"]`` when the summary is computed.
    results: list
    # Aggregate values; ``baseline`` fills in "mean_score" and "tasks_evaluated".
    summary: Dict[str, Any]
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
async def reset(request: ResetRequest | None = Body(default=None)):
    """Reset the environment for a given task_id.

    Returns the initial observation plus a `session_id` that must be
    passed via the `X-Session-ID` header on all subsequent calls.

    Args:
        request: Optional body. When omitted, a default ``ResetRequest``
            is used.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while resolving
            the session or resetting the environment. The original error is
            chained as the cause.
    """
    if request is None:
        request = ResetRequest()
    try:
        session_id, env = get_or_create_session(request.session_id)
        # When no variant_seed is supplied randomise for anti-memorization;
        # explicit 0 keeps the primary (deterministic) scenario.
        seed = (
            request.variant_seed
            if request.variant_seed is not None
            else secrets.randbelow(100)
        )
        obs = env.reset(request.task_id, variant_seed=seed)
        _TELEMETRY["episodes_total"] += 1
        # Only a prefix of the session id is written to the log.
        _log.info(
            "episode reset task=%s session=%s variant=%d",
            request.task_id,
            session_id[:8],
            seed,
        )
        data = obs.model_dump(mode="json")
        data["session_id"] = session_id
        return data
    except ValueError as exc:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
async def step(
    action: Action,
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Execute one action and return observation, reward, done, info.

    Args:
        action: The action forwarded to ``env.step``.
        x_session_id: Value of the ``X-Session-ID`` header; must be a key of
            the session registry.

    Raises:
        HTTPException: 400 when the header is missing/unknown or when
            ``env.step`` raises ``RuntimeError``; 500 for any other
            exception (which is logged with its traceback).
    """
    if not x_session_id or x_session_id not in _SESSION_REGISTRY:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )
    env = _SESSION_REGISTRY[x_session_id]
    try:
        result: StepResult = env.step(action)
        _TELEMETRY["steps_total"] += 1
        return result.model_dump()
    except RuntimeError as exc:
        _TELEMETRY["errors_total"] += 1
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        _TELEMETRY["errors_total"] += 1
        # Record the traceback server-side; previously unexpected failures
        # were converted to a 500 without leaving any trace in the logs.
        _log.exception("unexpected error during step session=%s", x_session_id[:8])
        raise HTTPException(status_code=500, detail=f"Internal error: {exc}") from exc
async def state(
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Return the dumped state of the environment bound to the session.

    Responds with HTTP 400 when the ``X-Session-ID`` header is absent or not
    registered, or when ``env.state()`` raises ``RuntimeError``.
    """
    session_known = bool(x_session_id) and x_session_id in _SESSION_REGISTRY
    if not session_known:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )

    env = _SESSION_REGISTRY[x_session_id]
    try:
        snapshot = env.state()
        return snapshot.model_dump()
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
async def tasks():
    """Return every task from ``get_all_tasks()`` as a list of dumped models."""
    catalogue = []
    for task in get_all_tasks():
        catalogue.append(task.model_dump())
    return catalogue
async def grader(
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Grade the session's environment and return the dumped grade result.

    On success the grader-call counter is incremented and the score is passed
    to ``record_leaderboard`` together with the task id and step count.
    Responds with HTTP 400 when the ``X-Session-ID`` header is absent or not
    registered, or when grading raises ``RuntimeError``.
    """
    session_known = bool(x_session_id) and x_session_id in _SESSION_REGISTRY
    if not session_known:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )

    env = _SESSION_REGISTRY[x_session_id]
    try:
        verdict = env.grade()
        _TELEMETRY["grader_calls"] += 1

        # Named env_state so it does not shadow the module-level `state` handler.
        env_state = env.state()
        task_id = env_state.task_id
        steps_taken = env_state.total_steps_taken

        record_leaderboard(task_id, verdict.score, steps_taken)
        _log.info("graded task=%s score=%.4f steps=%d", task_id, verdict.score, steps_taken)
        return verdict.model_dump()
    except RuntimeError as exc:
        _TELEMETRY["errors_total"] += 1
        raise HTTPException(status_code=400, detail=str(exc))
async def baseline():
    """Run the rule-based baseline inference against all tasks (in-process).

    Creates a dedicated ephemeral env instance so it never interferes
    with any active session.

    Returns:
        The dumped ``BaselineResponse``: the per-task ``results`` plus a
        ``summary`` holding ``mean_score`` (rounded to 4 decimals, ``0.0``
        when no task was evaluated) and ``tasks_evaluated``.

    Raises:
        HTTPException: 500 when the baseline run fails for any reason; the
            failure is logged with its traceback and chained as the cause.
    """
    try:
        # Imported lazily, inside the try, so an import failure is reported
        # through the same 500 path as a runtime failure.
        from baseline.inference import run_all_tasks

        dedicated_env = IncidentResponseEnv()
        results = run_all_tasks(base_url=None, env_instance=dedicated_env)
        _TELEMETRY["baseline_runs"] += 1

        task_count = len(results)
        # Guard the mean: an empty result list used to raise ZeroDivisionError
        # and surface as a misleading "Baseline execution failed" 500.
        if task_count:
            mean_score = sum(r["score"] for r in results) / task_count
        else:
            mean_score = 0.0

        summary = {
            "mean_score": round(mean_score, 4),
            "tasks_evaluated": task_count,
        }
        return BaselineResponse(results=results, summary=summary).model_dump()
    except Exception as exc:
        # Route the traceback through the module logger instead of printing
        # it to stderr, so it lands with the rest of the service logs.
        _log.exception("baseline execution failed")
        raise HTTPException(
            status_code=500,
            detail=f"Baseline execution failed: {exc}",
        ) from exc