# openenv / server/app.py
# Last commit by jeromerichard (7cf2ffd):
#   "Fix: add server/app.py, uv.lock, project.scripts entry point"
from __future__ import annotations

import json
import math
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from models import TrustAction, TrustObservation, TrustState, ContentSignals
from your_environment import TrustSafetyEnvironment
# ── Force manual FastAPI (openenv_core create_app causes 422 on /step) ────────
# The app is wired up by hand below instead of via openenv_core's create_app,
# which (per the note above) rejected /step payloads with HTTP 422.
# The startup banner is deliberately ASCII-only: the previous literal was a
# mis-encoded emoji, and printing non-ASCII text at import time can raise
# UnicodeEncodeError (and abort startup) on consoles without a UTF-8 encoding.
print("[app] Using manual FastAPI")

# One process-wide environment instance, seeded with 42. Every route below
# (/reset, /step, /state) reads and mutates this same object.
# NOTE(review): concurrent clients therefore share, and can interfere with,
# a single episode.
_env = TrustSafetyEnvironment(seed=42)

app = FastAPI(
    title="Trust & Safety RL Environment",
    description="Risk-aware content moderation environment for agent training.",
    version="1.0.0",
)

# Fully permissive CORS: any origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Serializers ───────────────────────────────────────────────────────────────
def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
return {
"ticket_id": obs.ticket_id,
"post_text": obs.post_text,
"image_description": obs.image_description,
"comments_found": obs.comments_found,
"user_history_found": obs.user_history_found,
"entity_status_found": obs.entity_status_found,
"policy_found": obs.policy_found,
"extracted_signals": obs.extracted_signals,
"validation_result": obs.validation_result,
"step_number": obs.step_number,
"info": obs.info,
"done": obs.done,
"reward": obs.reward,
}
def _state_to_dict(s: TrustState) -> Dict[str, Any]:
return {
"episode_id": s.episode_id,
"step_count": s.step_count,
"current_task_id": s.current_task_id,
"difficulty": s.difficulty,
"ambiguity_level": s.ambiguity_level,
"risk_level": s.risk_level,
"tools_used": s.tools_used,
"signals_extracted": s.signals_extracted,
"is_done": s.is_done,
}
# ── Request bodies ─────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    # Unknown keys in the JSON body are dropped rather than rejected.
    model_config = {"extra": "ignore"}

    # Both values are passed to ``_env.reset`` by the /reset route; any JSON
    # value (or omission) is accepted at this layer.
    seed: Any = None
    episode_id: Any = None
class ActionRequest(BaseModel):
    # Agents (LLMs) may emit extra keys; ignore them instead of failing.
    model_config = {"extra": "ignore"}

    action_type: str = ""
    tool_name: Optional[str] = None
    # Kept as a raw dict on purpose: /step normalises it with _parse_signals
    # before building ContentSignals, so malformed signals yield a 400.
    signals: Optional[Dict[str, Any]] = None
    final_decision: Optional[str] = None
# ── Helpers ────────────────────────────────────────────────────────────────────
# Lower-cased string spellings accepted as boolean "true" in LLM output.
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})


def _coerce_bool(value: Any, default: bool = False) -> bool:
    """Interpret a loosely-typed value as a boolean.

    ``None`` maps to *default*. Strings are matched case-insensitively against
    ``_TRUE_STRINGS`` (plain ``bool("false")`` would be ``True``). Any other
    value uses normal Python truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return bool(value)


def _coerce_float(value: Any, default: float, lo: float = 0.0, hi: float = 1.0) -> float:
    """Convert *value* to a float clamped into ``[lo, hi]``.

    ``None`` and NaN fall back to *default*. Unconvertible input raises
    ``ValueError``/``TypeError`` (``OverflowError`` for huge ints).
    """
    if value is None:
        return default
    number = float(value)
    if math.isnan(number):
        return default
    return min(max(number, lo), hi)


def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
    """Defensively normalise LLM signal output before Pydantic validation.

    Works on a shallow copy, so *raw* is never mutated.

    Args:
        raw: Signal dict as produced by the agent (arbitrary JSON values).

    Returns:
        A ``ContentSignals`` built from the normalised dict.

    Raises:
        Whatever float conversion or ``ContentSignals`` validation raises on
        unusable input; the /step route maps any such error to HTTP 400.
    """
    data = dict(raw)

    # Scores: coerce to float, treat null/NaN as missing, and clamp.
    # NOTE(review): assumes ContentSignals expects both scores in [0, 1]
    # (the 0.5 defaults suggest so) — confirm against models.ContentSignals.
    for key in ("toxicity_level", "confidence"):
        data[key] = _coerce_float(data.get(key), 0.5)

    # content_flags must be a list of strings: wrap a lone string, accept a
    # tuple, and replace anything else (None, numbers, dicts) with [].
    flags = data.get("content_flags")
    if isinstance(flags, str):
        flags = [flags]
    elif not isinstance(flags, (list, tuple)):
        flags = []
    data["content_flags"] = [str(f) for f in flags]

    # Boolean flags: string-aware coercion so "false"/"no"/"0" become False.
    for key in ("is_protected_class", "is_direct_attack", "abusive_language_present"):
        data[key] = _coerce_bool(data.get(key))

    # String fields: fill missing *or null* values with sensible defaults
    # (setdefault alone would leave an explicit null in place).
    for key, fallback in (
        ("target", "none"),
        ("intent", "ambiguous"),
        ("context_type", "statement"),
    ):
        if data.get(key) is None:
            data[key] = fallback

    return ContentSignals(**data)
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe returning a static payload that identifies this service."""
    payload = dict(status="ok", environment="trust-safety-env", version="1.0.0")
    return payload
@app.get("/")
async def root():
    """Landing route that points callers at the ``/docs`` page."""
    return dict(status="ok", docs="/docs")
@app.post("/reset")
async def reset(body: ResetRequest = ResetRequest()):
    """Start a new episode on the shared environment and return its first observation.

    The request body is optional; ``seed`` and ``episode_id`` are passed
    straight to ``_env.reset``.

    Raises:
        HTTPException: 400 if the environment rejects the arguments with
            ``RuntimeError``/``TypeError``/``ValueError``. This mirrors /step so
            that a bad client payload (``seed`` is untyped) is reported as a
            client error instead of an unhandled 500.
    """
    try:
        obs = _env.reset(seed=body.seed, episode_id=body.episode_id)
    except (RuntimeError, TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return JSONResponse(_obs_to_dict(obs))
@app.post("/step")
async def step(body: ActionRequest):
    """Apply one agent action to the shared environment and return the new observation.

    Raises:
        HTTPException: 400 if ``signals`` cannot be normalised/validated, if
            constructing the ``TrustAction`` fails validation, or if the
            environment rejects the step with ``RuntimeError``/``ValueError``.
    """
    # Parse + validate signals defensively; an empty dict counts as "no signals".
    signals: Optional[ContentSignals] = None
    if body.signals:
        try:
            signals = _parse_signals(dict(body.signals))  # copy so we don't mutate
        except Exception as e:  # any normalisation/validation failure is a client error
            raise HTTPException(
                status_code=400, detail=f"Invalid signals payload: {e}"
            ) from e

    # Building the action may itself fail validation (if TrustAction constrains
    # its fields, e.g. action_type). Pydantic's ValidationError subclasses
    # ValueError, so this reports such input as a 400 instead of a 500.
    try:
        action = TrustAction(
            action_type=body.action_type,
            tool_name=body.tool_name,
            signals=signals,
            final_decision=body.final_decision,
        )
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid action: {e}") from e

    try:
        obs = _env.step(action)
    except (RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return JSONResponse(_obs_to_dict(obs))
@app.get("/state")
async def state():
    """Return a JSON snapshot of the shared environment's bookkeeping state."""
    current = _env.state
    return JSONResponse(content=_state_to_dict(current))