Spaces:
Configuration error
Configuration error
| from __future__ import annotations | |
| import json | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from models import TrustAction, TrustObservation, TrustState, ContentSignals | |
| from your_environment import TrustSafetyEnvironment | |
# -- Force manual FastAPI (openenv_core create_app causes 422 on /step) --------
print("[app] Using manual FastAPI β ")

# One environment instance is created at import time and shared by every
# request handler in this module.
_env = TrustSafetyEnvironment(seed=42)

# Metadata shown in the generated OpenAPI schema / docs page.
_app_metadata = {
    "title": "Trust & Safety RL Environment",
    "description": "Risk-aware content moderation environment for agent training.",
    "version": "1.0.0",
}
app = FastAPI(**_app_metadata)

# CORS is left fully open: any origin, method and header is accepted.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# -- Serializers ---------------------------------------------------------------
def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
    """Copy the serialised attributes of *obs* into a plain dictionary.

    Each key in the result has the same name as the attribute it was read
    from, and keys appear in the order listed below. A missing attribute
    raises ``AttributeError``.
    """
    field_names = (
        "ticket_id",
        "post_text",
        "image_description",
        "comments_found",
        "user_history_found",
        "entity_status_found",
        "policy_found",
        "extracted_signals",
        "validation_result",
        "step_number",
        "info",
        "done",
        "reward",
    )
    return {name: getattr(obs, name) for name in field_names}
def _state_to_dict(s: TrustState) -> Dict[str, Any]:
    """Copy the serialised attributes of state *s* into a plain dictionary.

    Each key in the result has the same name as the attribute it was read
    from, and keys appear in the order listed below. A missing attribute
    raises ``AttributeError``.
    """
    field_names = (
        "episode_id",
        "step_count",
        "current_task_id",
        "difficulty",
        "ambiguity_level",
        "risk_level",
        "tools_used",
        "signals_extracted",
        "is_done",
    )
    return {name: getattr(s, name) for name in field_names}
# -- Request bodies ------------------------------------------------------------
class ResetRequest(BaseModel):
    """Request body accepted when resetting the environment.

    Both fields are optional and untyped (``Any``); keys the model does not
    declare are ignored rather than rejected.
    """

    # Model configuration: drop undeclared keys instead of raising.
    model_config = {"extra": "ignore"}

    # Passed through as the ``seed`` keyword of the environment's reset().
    seed: Any = None
    # Passed through as the ``episode_id`` keyword of the environment's reset().
    episode_id: Any = None
class ActionRequest(BaseModel):
    """Request body describing one agent action.

    Every field has a default, so a partial payload is accepted; keys the
    model does not declare are ignored rather than rejected.
    """

    # Model configuration: ignore unknown keys coming from the LLM.
    model_config = {"extra": "ignore"}

    action_type: str = ""
    tool_name: Optional[str] = None
    # Kept as a raw dict here; it is normalised by _parse_signals before use.
    signals: Optional[Dict[str, Any]] = None
    final_decision: Optional[str] = None
# -- Helpers -------------------------------------------------------------------
# Strings that an LLM may emit for a false boolean; compared case-insensitively.
_FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off", "none", "null"})


def _coerce_bool(value: Any) -> bool:
    """Interpret *value* as a boolean, reading strings such as "false" as False."""
    if isinstance(value, str):
        # bool("false") is True in Python, so strings need explicit handling.
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


def _coerce_unit_float(value: Any, default: float = 0.5) -> float:
    """Convert *value* to a float clamped to [0.0, 1.0].

    ``None`` and NaN fall back to *default*. Values that ``float()`` cannot
    convert raise ``TypeError``/``ValueError`` for the caller to handle.
    """
    if value is None:
        return default
    number = float(value)
    if number != number:  # NaN is the only float unequal to itself
        return default
    # NOTE(review): the [0, 1] range is assumed from the 0.5 defaults --
    # confirm against the ContentSignals field constraints.
    return min(1.0, max(0.0, number))


def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
    """Defensively normalise LLM signal output before Pydantic validation.

    Args:
        raw: Signal payload as received from the client. It is not modified;
            normalisation happens on a shallow copy.

    Returns:
        A ``ContentSignals`` built from the normalised payload. Keys that are
        not normalised here are passed through unchanged.

    Raises:
        TypeError, ValueError: if a float field cannot be converted, plus
            whatever ``ContentSignals`` itself raises on invalid data.
    """
    data = dict(raw)

    # Clamp floats (missing / null / NaN -> 0.5).
    data["toxicity_level"] = _coerce_unit_float(data.get("toxicity_level"))
    data["confidence"] = _coerce_unit_float(data.get("confidence"))

    # content_flags must be a list of strings: a bare string becomes a
    # one-element list, any other non-list value becomes an empty list.
    flags = data.get("content_flags", [])
    if not isinstance(flags, list):
        flags = [flags] if isinstance(flags, str) else []
    data["content_flags"] = [str(flag) for flag in flags]

    # Boolean coercion that understands stringified booleans.
    for key in ("is_protected_class", "is_direct_attack", "abusive_language_present"):
        data[key] = _coerce_bool(data.get(key, False))

    # String fields: fall back to defaults when missing or explicitly null.
    string_defaults = (
        ("target", "none"),
        ("intent", "ambiguous"),
        ("context_type", "statement"),
    )
    for key, fallback in string_defaults:
        if data.get(key) is None:
            data[key] = fallback

    return ContentSignals(**data)
# -- Routes --------------------------------------------------------------------
# NOTE(review): route path inferred from the handler name -- confirm.
@app.get("/health")
async def health():
    """Return a static payload with status, environment name and version."""
    return {"status": "ok", "environment": "trust-safety-env", "version": "1.0.0"}
# NOTE(review): route path inferred from the handler name -- confirm.
@app.get("/")
async def root():
    """Return a static status payload that points at the docs path."""
    return {"status": "ok", "docs": "/docs"}
# NOTE(review): route path and method inferred from the handler -- confirm.
# The default ResetRequest() instance lets clients omit the body entirely.
@app.post("/reset")
async def reset(body: ResetRequest = ResetRequest()):
    """Reset the shared environment and return the first observation.

    ``seed`` and ``episode_id`` from the body are forwarded as keyword
    arguments to the environment's ``reset``; the resulting observation is
    serialised with ``_obs_to_dict``.
    """
    obs = _env.reset(seed=body.seed, episode_id=body.episode_id)
    return JSONResponse(_obs_to_dict(obs))
# NOTE(review): route path inferred from the "/step" mention in this module's
# header comment; the POST method is inferred from the JSON body -- confirm.
@app.post("/step")
async def step(body: ActionRequest):
    """Apply one agent action to the shared environment.

    Args:
        body: The action payload. A non-empty ``signals`` dict is normalised
            with ``_parse_signals``; an empty or missing one is sent as None.

    Returns:
        A JSON response holding the observation produced by the step.

    Raises:
        HTTPException: status 400 when the signals payload cannot be parsed,
            or when building the action or stepping the environment raises
            ``RuntimeError`` or ``ValueError``.
    """
    # Parse + validate signals defensively.
    signals: Optional[ContentSignals] = None
    if body.signals:
        try:
            # Pass a copy so the request model's own dict is never mutated.
            signals = _parse_signals(dict(body.signals))
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid signals payload: {e}"
            ) from e

    try:
        # Building the action sits inside the try so that a ValueError raised
        # while constructing it is reported as a 400, not an unhandled 500.
        action = TrustAction(
            action_type=body.action_type,
            tool_name=body.tool_name,
            signals=signals,
            final_decision=body.final_decision,
        )
        obs = _env.step(action)
    except (RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return JSONResponse(_obs_to_dict(obs))
# NOTE(review): route path inferred from the handler name -- confirm.
@app.get("/state")
async def state():
    """Return the shared environment's ``state`` serialised by ``_state_to_dict``."""
    return JSONResponse(_state_to_dict(_env.state))