Spaces:
Sleeping
Sleeping
| """ | |
FastAPI backend — exposes the LangGraph pipeline as a REST API.
Endpoints:
    POST /chat           — single-turn inference (non-streaming)
    POST /chat/stream    — streaming token delivery via SSE
    GET  /users          — list available personas
    POST /session/reset  — reset session state for a user
    GET  /health         — liveness check
| """ | |
| from __future__ import annotations | |
| import json | |
| import time | |
| from typing import AsyncGenerator | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from config.settings import settings | |
| from guardrails.checks import check_input | |
| from pipeline.graph import aac_graph | |
| from pipeline.state import PipelineState | |
| from retrieval.bucket_priors import uniform_priors | |
# Application metadata, surfaced in the generated OpenAPI schema / docs UI.
_API_METADATA = {
    "title": "Multimodal AAC Chatbot API",
    "description": "Agentic RAG pipeline for AAC persona communication",
    "version": "2.0.0",
}
app = FastAPI(**_API_METADATA)

# CORS is fully permissive: any origin, HTTP method and request header is accepted.
# NOTE(review): consider restricting origins for a non-demo deployment.
_ALLOW_ANY = "*"
app.add_middleware(
    CORSMiddleware,
    allow_origins=[_ALLOW_ANY],
    allow_methods=[_ALLOW_ANY],
    allow_headers=[_ALLOW_ANY],
)
# ── In-memory session store (replace with Redis for multi-worker deployments) ──
# Maps user_id -> session dict (persona_profile, session_history, bucket_priors,
# turn_id). The data lives only in this process: it is lost on restart and is
# not shared between workers.
_sessions: dict[str, dict] = dict()
# ── Request / response schemas ──────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """Request body for one chat turn."""

    # Persona id; must match an "id" entry in settings.users_json, otherwise
    # session lookup responds with HTTP 404.
    user_id: str
    # Raw user text; screened by check_input() before it reaches the pipeline.
    query: str
    # Optional emotion label. When set (non-empty), it pre-populates the pipeline's
    # affect state as {"emotion": <value>, "vector": {}, "smoothed": {}}.
    # Expected values: "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED" (not validated here).
    affect_override: str | None = None
    # Optional multimodal signals, copied verbatim into the pipeline state.
    gesture_tag: str | None = None
    gaze_bucket: str | None = None
class ChatResponse(BaseModel):
    """Response body produced by the chat() handler."""

    # Echo of the request's user id and query text.
    user_id: str
    query: str
    # The pipeline's selected response, or the guardrail fallback text when the
    # input was rejected.
    response: str
    # Emotion label for the turn ("NEUTRAL" when none was determined).
    affect: str
    # LLM tier that produced the response ("none" when the input was rejected).
    llm_tier: str
    # Retrieval mode used by the pipeline ("none" when the input was rejected).
    retrieval_mode: str
    # Per-stage timings copied from the pipeline's latency_log (empty when rejected).
    # Keys seeded by the initial state: t_sensing, t_intent, t_retrieval,
    # t_generation, t_total. Units are not established here — TODO confirm.
    latency: dict
    # False when the input guardrail rejected the query; otherwise the pipeline's
    # own guardrail_passed flag.
    guardrail_passed: bool
# ── Helpers ─────────────────────────────────────────────────────────────────────
def _get_or_init_session(user_id: str) -> dict:
    """Return the in-memory session for *user_id*, creating it on first use.

    A new session is seeded from the persona entry in ``settings.users_json``
    with an empty history, uniform bucket priors and a turn counter of 0.

    Args:
        user_id: Persona id; must match an ``"id"`` entry in the users file.

    Returns:
        The mutable session dict stored in ``_sessions``. Callers update it in
        place to carry state across turns.

    Raises:
        HTTPException: 404 if *user_id* is not listed in the users file.
    """
    session = _sessions.get(user_id)
    if session is not None:
        return session

    # JSON text is UTF-8 by spec (RFC 8259); without an explicit encoding, open()
    # uses the platform locale encoding and can mis-decode non-ASCII persona data.
    with open(settings.users_json, encoding="utf-8") as f:
        users = {u["id"]: u for u in json.load(f)["users"]}

    profile = users.get(user_id)
    if profile is None:
        raise HTTPException(status_code=404, detail=f"User '{user_id}' not found")

    session = {
        "persona_profile": profile,
        "session_history": [],
        "bucket_priors": uniform_priors(),
        "turn_id": 0,
    }
    _sessions[user_id] = session
    return session
def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState:
    """Build the starting PipelineState for one chat turn.

    Side effect: increments ``session["turn_id"]`` before building the state.
    Every pipeline output slot starts empty (``None``, ``[]`` or ``""``), and
    every latency timer starts at 0.0.

    Args:
        req: The incoming chat request.
        session: Mutable session dict, as created by ``_get_or_init_session``.

    Returns:
        A fresh PipelineState for ``aac_graph.invoke``. ``session_history`` and
        ``bucket_priors`` are the session's own objects, not copies.
    """
    # A caller-supplied emotion label stands in for the affect slot; its vector
    # fields are left empty. With no (or an empty) override the slot stays None.
    affect_state = (
        {"emotion": req.affect_override, "vector": {}, "smoothed": {}}
        if req.affect_override
        else None
    )

    session["turn_id"] += 1
    turn = session["turn_id"]

    stage_timers = ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total")
    zeroed_latency = dict.fromkeys(stage_timers, 0.0)

    return PipelineState(
        # Inputs taken from the request and the session.
        user_id=req.user_id,
        persona_profile=session["persona_profile"],
        session_history=session["session_history"],
        turn_id=turn,
        affect=affect_state,
        gesture_tag=req.gesture_tag,
        gaze_bucket=req.gaze_bucket,
        air_written_text=None,
        raw_query=req.query,
        # Slots the pipeline is expected to fill in.
        intent_route=None,
        generation_config=None,
        retrieved_chunks=[],
        bucket_priors=session["bucket_priors"],
        retrieval_mode_used="",
        augmented_prompt=None,
        candidates=[],
        selected_response=None,
        llm_tier_used="",
        latency_log=zeroed_latency,
        mlflow_run_id=None,
        guardrail_passed=True,
    )
# ── Routes ──────────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness check: always responds with ``{"status": "ok"}``.

    Registered as GET /health, the endpoint listed in the module docstring.
    Without the decorator, FastAPI never exposed this function.
    """
    return {"status": "ok"}
@app.get("/users")
def list_users():
    """Return the parsed contents of ``settings.users_json`` (the available personas).

    Registered as GET /users. The document is returned as-is, without filtering.
    Presumably it has a top-level ``"users"`` list, which is what
    ``_get_or_init_session`` reads — verify against the data file.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(settings.users_json, encoding="utf-8") as f:
        return json.load(f)
@app.post("/session/reset")
def reset_session(user_id: str):
    """Discard the in-memory session for *user_id*.

    Registered as POST /session/reset. FastAPI reads the scalar *user_id* from
    the query string. The call succeeds even when no session exists, so
    repeating it is harmless. The next chat turn for this user rebuilds the
    session from the users file.

    Returns:
        ``{"status": "reset", "user_id": user_id}``.
    """
    _sessions.pop(user_id, None)
    return {"status": "reset", "user_id": user_id}
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Run one non-streaming turn of the AAC pipeline (POST /chat).

    The query is screened by the input guardrail first. A rejected query is
    answered with the guardrail's fallback text and never touches the session
    or the pipeline. Otherwise the user's session is loaded (or created), the
    graph is invoked, and the history and bucket priors it returns are stored
    back into the session.

    Raises:
        HTTPException: 404 (from ``_get_or_init_session``) for an unknown user
        whose query passed the guardrail.
    """
    guard = check_input(req.query)
    if not guard["allowed"]:
        return ChatResponse(
            user_id=req.user_id,
            query=req.query,
            response=guard["fallback"],
            affect="NEUTRAL",
            llm_tier="none",
            retrieval_mode="none",
            latency={},
            guardrail_passed=False,
        )

    session = _get_or_init_session(req.user_id)
    initial_state = _build_initial_state(req, session)
    result: PipelineState = aac_graph.invoke(initial_state)

    # Persist the cross-turn state. Keep the current values if the graph's
    # output omits either key, rather than failing after a completed turn.
    session["session_history"] = result.get("session_history", session["session_history"])
    session["bucket_priors"] = result.get("bucket_priors", session["bucket_priors"])

    # The initial state seeds llm_tier_used / retrieval_mode_used with "" and
    # selected_response with None. A .get() default therefore never fires;
    # use `or` so the fallbacks apply when the pipeline leaves a field unset.
    # The same applies to a None emotion, which would otherwise fail
    # ChatResponse validation (affect: str).
    emotion = (result.get("affect") or {}).get("emotion") or "NEUTRAL"
    return ChatResponse(
        user_id=req.user_id,
        query=req.query,
        response=result.get("selected_response") or "",
        affect=emotion,
        llm_tier=result.get("llm_tier_used") or "unknown",
        retrieval_mode=result.get("retrieval_mode_used") or "unknown",
        latency=result.get("latency_log") or {},
        guardrail_passed=result.get("guardrail_passed", True),
    )