# openenv/routers/irt.py
# sentinel-space-publisher
# space: publish latest Sentinel app snapshot
# c452421
# -*- coding: utf-8 -*-
"""IRT (Incident Response Triage) API endpoints.
Extracted from app.py - handles /reset, /step, /state, /tasks, /grader, /baseline.
"""
from __future__ import annotations
import secrets
import traceback
from typing import Any, Dict
from fastapi import APIRouter, Body, Header, HTTPException
from pydantic import BaseModel
from src.environment import IncidentResponseEnv
from src.models import Action, StepResult
from src.tasks import get_all_tasks
from routers.deps import (
_SESSION_REGISTRY,
_TELEMETRY,
_log,
get_or_create_session,
record_leaderboard,
)
router = APIRouter()
# ---------------------------------------------------------------------------
# Request / response schemas
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    # Optional JSON body accepted by POST /reset. Every field has a default,
    # and the endpoint substitutes a default instance when no body is sent.

    # Task to load; handed unchanged to env.reset().
    task_id: str = "severity_classification"
    # Forwarded to get_or_create_session(); None presumably creates a new
    # session -- confirm against routers.deps.
    session_id: str | None = None
    # Scenario variant; when None the /reset handler draws one at random.
    variant_seed: int | None = None
class BaselineResponse(BaseModel):
    # Response payload assembled by the /baseline endpoint.

    # Per-task results as returned by run_all_tasks(); /baseline reads a
    # "score" key from each entry.
    results: list
    # Aggregate figures; /baseline fills in "mean_score" and "tasks_evaluated".
    summary: Dict[str, Any]
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/reset")
async def reset(request: ResetRequest | None = Body(default=None)):
    """Reset the environment for a given task_id.

    Returns the initial observation plus a `session_id` that must be
    passed via the `X-Session-ID` header on all subsequent calls.

    The request body is optional; when it is omitted the defaults of
    `ResetRequest` are used. Responds with 400 when session lookup or
    the environment reset raises `ValueError`.
    """
    if request is None:
        request = ResetRequest()
    try:
        session_id, env = get_or_create_session(request.session_id)
        # When no variant_seed is supplied randomise for anti-memorization;
        # explicit 0 keeps the primary (deterministic) scenario.
        if request.variant_seed is not None:
            seed = request.variant_seed
        else:
            seed = secrets.randbelow(100)
        obs = env.reset(request.task_id, variant_seed=seed)
        _TELEMETRY["episodes_total"] += 1
        # Only a short prefix of the session id is written to the log.
        _log.info(
            "episode reset task=%s session=%s variant=%d",
            request.task_id,
            session_id[:8],
            seed,
        )
        data = obs.model_dump(mode="json")
        data["session_id"] = session_id
        return data
    except ValueError as exc:
        # Chain the cause so the original error stays attached to the 400.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post("/step")
async def step(
    action: Action,
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Execute one action and return observation, reward, done, info.

    The environment is looked up by the `X-Session-ID` header. Responds
    with 400 when the header is missing or unknown, or when the
    environment raises `RuntimeError`; any other failure is logged and
    reported as 500.
    """
    if not x_session_id or x_session_id not in _SESSION_REGISTRY:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )
    env = _SESSION_REGISTRY[x_session_id]
    try:
        result: StepResult = env.step(action)
        _TELEMETRY["steps_total"] += 1
        return result.model_dump()
    except RuntimeError as exc:
        _TELEMETRY["errors_total"] += 1
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        _TELEMETRY["errors_total"] += 1
        # Record the traceback server-side; without this an unexpected
        # failure is only visible to the client as a one-line 500 detail.
        _log.exception("step failed session=%s", x_session_id[:8])
        raise HTTPException(
            status_code=500, detail=f"Internal error: {exc}"
        ) from exc
@router.get("/state")
async def state(
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Dump the complete state of the caller's environment.

    The environment is selected by the `X-Session-ID` header. A missing
    or unregistered id, or a `RuntimeError` from the environment, is
    answered with 400.
    """
    session_known = bool(x_session_id) and x_session_id in _SESSION_REGISTRY
    if not session_known:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )

    session_env = _SESSION_REGISTRY[x_session_id]
    try:
        snapshot = session_env.state()
        return snapshot.model_dump()
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@router.get("/tasks")
async def tasks():
    """Return every task from `get_all_tasks()`, each dumped to a plain dict."""
    catalogue = get_all_tasks()
    return [entry.model_dump() for entry in catalogue]
@router.post("/grader")
async def grader(
    x_session_id: str | None = Header(default=None, alias="X-Session-ID"),
):
    """Return grader score for the current or most recent episode.

    The environment is looked up by the `X-Session-ID` header. On
    success the score is also recorded on the leaderboard together with
    the task id and the number of steps taken. Responds with 400 when
    the header is missing or unknown, or when the environment raises
    `RuntimeError`.
    """
    if not x_session_id or x_session_id not in _SESSION_REGISTRY:
        raise HTTPException(
            status_code=400,
            detail="Missing or unknown X-Session-ID header. Call /reset first.",
        )
    env = _SESSION_REGISTRY[x_session_id]
    try:
        result = env.grade()
        _TELEMETRY["grader_calls"] += 1
        # Named env_state so it does not shadow the module-level `state`
        # endpoint function.
        env_state = env.state()
        record_leaderboard(
            env_state.task_id, result.score, env_state.total_steps_taken
        )
        _log.info(
            "graded task=%s score=%.4f steps=%d",
            env_state.task_id,
            result.score,
            env_state.total_steps_taken,
        )
        return result.model_dump()
    except RuntimeError as exc:
        _TELEMETRY["errors_total"] += 1
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post("/baseline")
async def baseline():
    """Run the rule-based baseline inference against all tasks (in-process).

    Creates a dedicated ephemeral env instance so it never interferes
    with any active session.

    The response carries the per-task results plus a summary with the
    mean score (rounded to 4 decimals) and the number of tasks
    evaluated. Any failure is printed as a traceback and reported as 500.
    """
    try:
        # Imported lazily, inside the handler, as in the original code.
        from baseline.inference import run_all_tasks

        dedicated_env = IncidentResponseEnv()
        # NOTE(review): this call is synchronous inside an async handler; if
        # it is long-running, consider offloading it to a worker thread.
        results = run_all_tasks(base_url=None, env_instance=dedicated_env)
        _TELEMETRY["baseline_runs"] += 1

        task_count = len(results)
        if task_count:
            mean_score = round(
                sum(r["score"] for r in results) / task_count, 4
            )
        else:
            # No tasks evaluated: report 0.0 rather than dividing by zero,
            # which would surface as a misleading 500.
            mean_score = 0.0

        summary = {
            "mean_score": mean_score,
            "tasks_evaluated": task_count,
        }
        return BaselineResponse(results=results, summary=summary).model_dump()
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Baseline execution failed: {exc}",
        ) from exc