"""HTTP server exposing the Trust & Safety RL environment over FastAPI.

Routes: ``GET /health``, ``GET /``, ``POST /reset``, ``POST /step`` and
``GET /state``. A single module-level ``TrustSafetyEnvironment`` instance is
shared by every request.
"""

from __future__ import annotations

import json
import math
from typing import Any, Dict, Optional, Union

from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from models import TrustAction, TrustObservation, TrustState, ContentSignals
from your_environment import TrustSafetyEnvironment

# ── Force manual FastAPI (openenv_core create_app causes 422 on /step) ────────
print("[app] Using manual FastAPI ✅")

# One environment shared by all requests. Every route below is ``async def``
# and never awaits, so each request runs to completion on the event loop
# before the next one starts, which keeps calls into ``_env`` serialised.
# Turning the routes into plain ``def`` would move them onto FastAPI's thread
# pool and allow concurrent, unsynchronised calls into ``_env``.
_env = TrustSafetyEnvironment(seed=42)

app = FastAPI(
    title="Trust & Safety RL Environment",
    description="Risk-aware content moderation environment for agent training.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── Serializers ───────────────────────────────────────────────────────────────

def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
    """Flatten a ``TrustObservation`` into the dict returned by the API."""
    return {
        "ticket_id": obs.ticket_id,
        "post_text": obs.post_text,
        "image_description": obs.image_description,
        "comments_found": obs.comments_found,
        "user_history_found": obs.user_history_found,
        "entity_status_found": obs.entity_status_found,
        "policy_found": obs.policy_found,
        "extracted_signals": obs.extracted_signals,
        "validation_result": obs.validation_result,
        "step_number": obs.step_number,
        "info": obs.info,
        "done": obs.done,
        "reward": obs.reward,
    }


def _state_to_dict(s: TrustState) -> Dict[str, Any]:
    """Flatten a ``TrustState`` into the dict returned by ``GET /state``."""
    return {
        "episode_id": s.episode_id,
        "step_count": s.step_count,
        "current_task_id": s.current_task_id,
        "difficulty": s.difficulty,
        "ambiguity_level": s.ambiguity_level,
        "risk_level": s.risk_level,
        "tools_used": s.tools_used,
        "signals_extracted": s.signals_extracted,
        "is_done": s.is_done,
    }


def _json_response(payload: Dict[str, Any]) -> JSONResponse:
    """Wrap ``payload`` in a JSONResponse.

    The payload is passed through FastAPI's ``jsonable_encoder`` first so that
    values plain ``json.dumps`` cannot handle (pydantic models, enums, sets,
    datetimes) are converted instead of raising.
    """
    return JSONResponse(jsonable_encoder(payload))


# ── Request bodies ─────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Body of ``POST /reset``; both fields are forwarded to ``_env.reset``."""

    seed: Any = None
    episode_id: Any = None

    model_config = {"extra": "ignore"}


class ActionRequest(BaseModel):
    """Body of ``POST /step`` as sent by the agent.

    ``signals`` may be a JSON object or a string containing a JSON object, so
    a stringified object emitted by an LLM does not fail request validation.
    It is normalised by ``_load_signals`` before ``ContentSignals`` is built.
    """

    action_type: str = ""
    tool_name: Optional[str] = None
    signals: Optional[Union[Dict[str, Any], str]] = None  # raw — validated below
    final_decision: Optional[str] = None

    model_config = {"extra": "ignore"}  # ← ignore unknown keys from LLM


# ── Helpers ────────────────────────────────────────────────────────────────────

# Inclusive range that ``toxicity_level`` and ``confidence`` are clamped to.
# NOTE(review): assumes ContentSignals expects these scores in [0, 1] —
# confirm against models.ContentSignals.
_SCORE_MIN = 0.0
_SCORE_MAX = 1.0

# Case-insensitive string spellings accepted for boolean signal fields.
_TRUE_WORDS = frozenset({"true", "t", "yes", "y", "1"})
_FALSE_WORDS = frozenset({"false", "f", "no", "n", "0", "", "none", "null"})

# Per-field fallbacks used when the LLM omits a field or sends null.
_SCORE_DEFAULTS = {"toxicity_level": 0.5, "confidence": 0.5}
_BOOL_FIELDS = ("is_protected_class", "is_direct_attack", "abusive_language_present")
_STRING_DEFAULTS = {"target": "none", "intent": "ambiguous", "context_type": "statement"}


def _coerce_score(name: str, value: Any, default: float) -> float:
    """Convert ``value`` to a float clamped to ``[_SCORE_MIN, _SCORE_MAX]``.

    Args:
        name: Field name, used only in error messages.
        value: Raw value from the LLM; ``None`` means "use the default".
        default: Value returned when ``value`` is ``None``.

    Raises:
        ValueError: if ``value`` is not convertible to a number, or is NaN.
    """
    if value is None:
        return default
    try:
        score = float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"{name} must be a number, got {value!r}") from err
    if math.isnan(score):
        raise ValueError(f"{name} must be a number, got NaN")
    return min(_SCORE_MAX, max(_SCORE_MIN, score))


def _coerce_bool(name: str, value: Any) -> bool:
    """Convert ``value`` to a bool, understanding common string spellings.

    ``bool("false")`` is ``True`` in Python, so strings are matched against
    explicit true/false words instead of using their truthiness. Non-string
    values (including ``None``) use ordinary truthiness.

    Raises:
        ValueError: if ``value`` is a string that is not a recognised word.
    """
    if isinstance(value, str):
        word = value.strip().lower()
        if word in _TRUE_WORDS:
            return True
        if word in _FALSE_WORDS:
            return False
        raise ValueError(f"{name}: cannot interpret {value!r} as a boolean")
    return bool(value)


def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
    """Defensively normalise LLM signal output before Pydantic validation.

    ``raw`` itself is not modified; a normalised copy is passed to
    ``ContentSignals``. Normalisation:

    * ``toxicity_level`` / ``confidence``: coerced to float and clamped; a
      missing or null value becomes 0.5.
    * ``content_flags``: a bare string becomes a one-element list; any other
      non-list becomes ``[]``; ``None`` entries are dropped, the rest are
      converted with ``str``.
    * boolean fields: coerced with ``_coerce_bool`` (missing → ``False``).
    * ``target`` / ``intent`` / ``context_type``: a missing or null value is
      replaced with a default.

    Raises:
        TypeError, ValueError: if a field cannot be coerced or ``ContentSignals``
            validation fails (pydantic's ValidationError subclasses ValueError).
    """
    fields = dict(raw)

    for key, default in _SCORE_DEFAULTS.items():
        fields[key] = _coerce_score(key, fields.get(key), default)

    flags = fields.get("content_flags")
    if isinstance(flags, str):
        flags = [flags]
    elif not isinstance(flags, list):
        flags = []
    fields["content_flags"] = [str(flag) for flag in flags if flag is not None]

    for key in _BOOL_FIELDS:
        fields[key] = _coerce_bool(key, fields.get(key, False))

    for key, default in _STRING_DEFAULTS.items():
        if fields.get(key) is None:
            fields[key] = default

    return ContentSignals(**fields)


def _load_signals(
    payload: Union[Dict[str, Any], str, None],
) -> Optional[ContentSignals]:
    """Turn the ``signals`` field of a step request into ``ContentSignals``.

    A string payload is decoded as JSON first. ``None``, a blank string, JSON
    ``null`` and an empty object all mean "no signals" and yield ``None``.

    Raises:
        TypeError, ValueError: for malformed JSON (``JSONDecodeError`` is a
            ValueError), a payload that is not a JSON object, or any failure
            from ``_parse_signals``.
    """
    if isinstance(payload, str):
        payload = json.loads(payload) if payload.strip() else None
    if payload is None:
        return None
    if not isinstance(payload, dict):
        raise TypeError(
            f"signals must be a JSON object, got {type(payload).__name__}"
        )
    if not payload:
        return None  # empty object behaves like omitting signals
    return _parse_signals(payload)


# ── Routes ─────────────────────────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok", "environment": "trust-safety-env", "version": "1.0.0"}


@app.get("/")
async def root():
    """Landing route pointing at the interactive API docs."""
    return {"status": "ok", "docs": "/docs"}


@app.post("/reset")
async def reset(body: Optional[ResetRequest] = None):
    """Start a new episode and return its first observation.

    The request body is optional; when it is absent, ``seed`` and
    ``episode_id`` are both passed to ``_env.reset`` as ``None``.
    """
    request = body if body is not None else ResetRequest()
    obs = _env.reset(seed=request.seed, episode_id=request.episode_id)
    return _json_response(_obs_to_dict(obs))


@app.post("/step")
async def step(body: ActionRequest):
    """Validate the agent's action, advance the environment, return the obs.

    Responds with HTTP 400 when ``signals`` is malformed, when ``TrustAction``
    rejects the action, or when ``_env.step`` raises RuntimeError/ValueError.
    """
    try:
        signals = _load_signals(body.signals)
    except (TypeError, ValueError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid signals payload: {e}"
        ) from e

    try:
        action = TrustAction(
            action_type=body.action_type,
            tool_name=body.tool_name,
            signals=signals,
            final_decision=body.final_decision,
        )
    except (TypeError, ValueError) as e:
        # Without this, an action the model rejects (e.g. an empty or unknown
        # action_type, if TrustAction validates it) would surface as HTTP 500.
        raise HTTPException(status_code=400, detail=f"Invalid action: {e}") from e

    try:
        obs = _env.step(action)
    except (RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return _json_response(_obs_to_dict(obs))


@app.get("/state")
async def state():
    """Return a snapshot of the environment's current episode state."""
    return _json_response(_state_to_dict(_env.state))