"""
FastAPI backend — exposes the LangGraph pipeline as a REST API.

Endpoints:
    POST /chat          — single-turn inference (non-streaming)
    POST /chat/stream   — streaming token delivery via SSE
    GET  /users         — list available personas
    POST /session/reset — reset session state for a user
    GET  /health        — liveness check
"""

from __future__ import annotations

import json
import time
from typing import AsyncGenerator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from config.settings import settings
from guardrails.checks import check_input
from pipeline.graph import aac_graph
from pipeline.state import PipelineState
from retrieval.bucket_priors import uniform_priors

app = FastAPI(
    title="Multimodal AAC Chatbot API",
    description="Agentic RAG pipeline for AAC persona communication",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── In-memory session store (replace with Redis for multi-worker deployments) ──
# Keyed by user_id. Each value holds "persona_profile", "session_history",
# "bucket_priors" and "turn_id" (created in _get_or_init_session).
_sessions: dict[str, dict] = {}


# ── Request / response schemas ─────────────────────────────────────────────────

# Request body for POST /chat. The optional fields carry sensing signals that
# are passed straight into the pipeline state.
class ChatRequest(BaseModel):
    user_id: str
    query: str
    affect_override: str | None = None  # "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED"
    gesture_tag: str | None = None
    gaze_bucket: str | None = None


# Response body for POST /chat.
class ChatResponse(BaseModel):
    user_id: str
    query: str
    response: str
    affect: str
    llm_tier: str
    retrieval_mode: str
    latency: dict
    guardrail_passed: bool


# ── Helpers ────────────────────────────────────────────────────────────────────

def _load_users() -> dict:
    """Read and parse the personas file at ``settings.users_json``.

    The file is decoded explicitly as UTF-8, the encoding JSON text is
    required to use, so parsing does not depend on the platform's default
    locale encoding.

    Returns:
        The parsed JSON document. Callers expect a top-level ``"users"``
        list whose entries each carry an ``"id"`` key.
    """
    with open(settings.users_json, encoding="utf-8") as f:
        return json.load(f)


def _get_or_init_session(user_id: str) -> dict:
    """Return the session for *user_id*, creating it on first use.

    A new session is seeded with the user's persona entry, an empty history,
    ``uniform_priors()`` and a turn counter of 0. Sessions persist in
    ``_sessions`` until ``/session/reset`` is called or the process exits.

    Raises:
        HTTPException: 404 if *user_id* is not present in the personas file.
    """
    if user_id not in _sessions:
        users = {u["id"]: u for u in _load_users()["users"]}
        if user_id not in users:
            raise HTTPException(status_code=404, detail=f"User '{user_id}' not found")
        _sessions[user_id] = {
            "persona_profile": users[user_id],
            "session_history": [],
            "bucket_priors": uniform_priors(),
            "turn_id": 0,
        }
    return _sessions[user_id]


def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState:
    """Advance the session's turn counter and build the graph's input state.

    Request-supplied signals (affect override, gesture tag, gaze bucket) and
    the session's persona, history and bucket priors are copied in. Every
    pipeline-output field starts empty, zeroed, or at its neutral value.

    Side effect: ``session["turn_id"]`` is incremented before the graph runs,
    so the counter advances even if the subsequent invocation fails.
    """
    affect_state = None
    if req.affect_override:
        # Only the emotion label is supplied by the client; the vector and
        # smoothed fields are left empty.
        affect_state = {"emotion": req.affect_override, "vector": {}, "smoothed": {}}

    session["turn_id"] += 1

    return PipelineState(
        user_id=req.user_id,
        persona_profile=session["persona_profile"],
        session_history=session["session_history"],
        turn_id=session["turn_id"],
        affect=affect_state,
        gesture_tag=req.gesture_tag,
        gaze_bucket=req.gaze_bucket,
        air_written_text=None,
        raw_query=req.query,
        intent_route=None,
        generation_config=None,
        retrieved_chunks=[],
        bucket_priors=session["bucket_priors"],
        retrieval_mode_used="",
        augmented_prompt=None,
        candidates=[],
        selected_response=None,
        llm_tier_used="",
        latency_log={
            "t_sensing": 0.0,
            "t_intent": 0.0,
            "t_retrieval": 0.0,
            "t_generation": 0.0,
            "t_total": 0.0,
        },
        mlflow_run_id=None,
        guardrail_passed=True,
    )


# ── Routes ─────────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness check."""
    return {"status": "ok"}


@app.get("/users")
def list_users():
    """Return the personas file contents as-is."""
    return _load_users()


@app.post("/session/reset")
def reset_session(user_id: str):
    """Drop any stored session for the user (no error if none exists)."""
    _sessions.pop(user_id, None)
    return {"status": "reset", "user_id": user_id}


@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Run one non-streaming turn of the pipeline for a user."""
    guard = check_input(req.query)
    if not guard["allowed"]:
        # Blocked input never reaches the graph; no session is created or
        # advanced, and the guardrail's fallback text is returned instead.
        return ChatResponse(
            user_id=req.user_id,
            query=req.query,
            response=guard["fallback"],
            affect="NEUTRAL",
            llm_tier="none",
            retrieval_mode="none",
            latency={},
            guardrail_passed=False,
        )

    session = _get_or_init_session(req.user_id)
    initial_state = _build_initial_state(req, session)
    result: PipelineState = aac_graph.invoke(initial_state)

    # Persist updated session state
    session["session_history"] = result["session_history"]
    session["bucket_priors"] = result["bucket_priors"]

    return ChatResponse(
        user_id=req.user_id,
        query=req.query,
        response=result["selected_response"] or "",
        # `or` (not a .get default) also covers an explicit None emotion,
        # which would otherwise fail str validation.
        affect=(result.get("affect") or {}).get("emotion") or "NEUTRAL",
        # The initial state seeds these keys with "", so a .get() default
        # would never apply. Treat empty/None as "unknown" as intended.
        llm_tier=result.get("llm_tier_used") or "unknown",
        retrieval_mode=result.get("retrieval_mode_used") or "unknown",
        latency=result.get("latency_log") or {},
        guardrail_passed=result.get("guardrail_passed", True),
    )