aac-chatbot / api /main.py
akashkolte's picture
added basic pipeline
e06dc15
raw
history blame
5.39 kB
"""
FastAPI backend — exposes the LangGraph pipeline as a REST API.
Endpoints:
  POST /chat          — single-turn inference (non-streaming)
  POST /chat/stream   — streaming token delivery via SSE
  GET  /users         — list available personas
  POST /session/reset — reset session state for a user
  GET  /health        — liveness check
"""
from __future__ import annotations
import json
import time
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from config.settings import settings
from guardrails.checks import check_input
from pipeline.graph import aac_graph
from pipeline.state import PipelineState
from retrieval.bucket_priors import uniform_priors
# ── In-memory session store (replace with Redis for multi-worker deployments) ──
# Maps user_id -> session dict (persona_profile, session_history, bucket_priors,
# turn_id). Lives only in this process, so each worker has its own copy.
_sessions: dict[str, dict] = {}

app = FastAPI(
    version="2.0.0",
    title="Multimodal AAC Chatbot API",
    description="Agentic RAG pipeline for AAC persona communication",
)

# Wide-open CORS: every origin, HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Request / response schemas ─────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``.

    Only ``user_id`` and ``query`` are required. ``gesture_tag`` and
    ``gaze_bucket`` are copied into the initial pipeline state as-is; a truthy
    ``affect_override`` becomes the ``emotion`` of the pipeline's affect state.
    """
    user_id: str  # must match an "id" in the users JSON file, otherwise /chat returns 404
    query: str  # raw user text; screened by check_input before the pipeline runs
    affect_override: str | None = None  # expected "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED" — NOTE(review): not validated here
    gesture_tag: str | None = None
    gaze_bucket: str | None = None
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``.

    When the input guardrail blocks a query, the same shape is returned with
    ``guardrail_passed=False``, ``affect="NEUTRAL"``, ``llm_tier`` and
    ``retrieval_mode`` set to ``"none"``, and an empty ``latency`` dict.
    """
    user_id: str
    query: str  # echoed back from the request
    response: str  # pipeline's selected response, or the guardrail fallback text
    affect: str  # emotion label from the pipeline's affect state; "NEUTRAL" if absent
    llm_tier: str
    retrieval_mode: str
    latency: dict  # the pipeline's latency_log (seeded with t_sensing … t_total); units not stated here
    guardrail_passed: bool
# ── Helpers ────────────────────────────────────────────────────────────────────
def _get_or_init_session(user_id: str) -> dict:
    """Return the in-memory session for ``user_id``, creating it on first use.

    A new session is seeded from the matching persona record in
    ``settings.users_json`` with an empty history, uniform bucket priors and
    ``turn_id`` 0. The users file is re-read every time a new session is created.

    Args:
        user_id: Persona id; must match an ``"id"`` under the file's ``"users"`` list.

    Returns:
        The session dict stored in ``_sessions`` (callers mutate it in place).

    Raises:
        HTTPException: 404 if ``user_id`` is not present in the users file.
    """
    session = _sessions.get(user_id)
    if session is not None:
        return session

    # JSON is UTF-8 by spec; without an explicit encoding, open() uses the
    # platform locale (e.g. cp1252 on Windows), which mangles non-ASCII personas.
    with open(settings.users_json, encoding="utf-8") as f:
        users = {u["id"]: u for u in json.load(f)["users"]}
    if user_id not in users:
        raise HTTPException(status_code=404, detail=f"User '{user_id}' not found")

    session = {
        "persona_profile": users[user_id],
        "session_history": [],
        "bucket_priors": uniform_priors(),
        "turn_id": 0,
    }
    _sessions[user_id] = session
    return session
def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState:
    """Advance the session's turn counter and assemble the pipeline input state.

    Side effect: increments ``session["turn_id"]`` before building the state.
    The session's history list and bucket priors are passed by reference, not
    copied. Every per-turn output slot starts empty/None, and all latency
    timers start at 0.0.
    """
    session["turn_id"] += 1

    override = req.affect_override
    # Only a non-empty override produces an affect state; its vectors start empty.
    affect = {"emotion": override, "vector": {}, "smoothed": {}} if override else None

    stage_timers = dict.fromkeys(
        ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total"),
        0.0,
    )

    return PipelineState(
        user_id=req.user_id,
        persona_profile=session["persona_profile"],
        session_history=session["session_history"],
        turn_id=session["turn_id"],
        affect=affect,
        gesture_tag=req.gesture_tag,
        gaze_bucket=req.gaze_bucket,
        air_written_text=None,
        raw_query=req.query,
        intent_route=None,
        generation_config=None,
        retrieved_chunks=[],
        bucket_priors=session["bucket_priors"],
        retrieval_mode_used="",
        augmented_prompt=None,
        candidates=[],
        selected_response=None,
        llm_tier_used="",
        latency_log=stage_timers,
        mlflow_run_id=None,
        guardrail_passed=True,
    )
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe; always answers ``{"status": "ok"}`` without touching state."""
    return dict(status="ok")
@app.get("/users")
def list_users():
    """Return the users JSON file verbatim (personas live under its ``"users"`` key).

    The file is read on every request, so edits are visible without a restart.
    """
    # Decode explicitly as UTF-8 (the JSON standard encoding) rather than the
    # platform-dependent locale default.
    with open(settings.users_json, encoding="utf-8") as f:
        return json.load(f)
@app.post("/session/reset")
def reset_session(user_id: str):
    """Drop the in-memory session for ``user_id`` (received as a query parameter).

    Resetting a user with no session is a no-op that still reports success. The
    next /chat call for this user builds a fresh session.
    """
    _sessions.pop(user_id, None)  # default avoids KeyError for unknown users
    payload = {"status": "reset", "user_id": user_id}
    return payload
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Run one non-streaming turn of the AAC pipeline for ``req.user_id``.

    The query is screened by ``check_input`` first. A blocked query returns the
    guardrail's fallback text immediately, without loading or modifying any
    session. NOTE: this means an unknown ``user_id`` is not rejected in that case.
    Otherwise the user's session is loaded (or created), the LangGraph pipeline
    is invoked, and the resulting history and bucket priors are written back to
    the session.

    Raises:
        HTTPException: 404 (from ``_get_or_init_session``) if ``user_id`` is unknown.
    """
    guard = check_input(req.query)
    if not guard["allowed"]:
        return ChatResponse(
            user_id=req.user_id,
            query=req.query,
            response=guard["fallback"],
            affect="NEUTRAL",
            llm_tier="none",
            retrieval_mode="none",
            latency={},
            guardrail_passed=False,
        )

    session = _get_or_init_session(req.user_id)
    initial_state = _build_initial_state(req, session)
    result: PipelineState = aac_graph.invoke(initial_state)

    # Persist updated session state so the next turn sees this turn's output.
    session["session_history"] = result["session_history"]
    session["bucket_priors"] = result["bucket_priors"]

    # The initial state seeds llm_tier_used / retrieval_mode_used with "". If no
    # node overwrites them, the key is present but empty, so .get(key, default)
    # never applies its default. Use `or` so empty/None values fall back, the same
    # way selected_response and latency_log are handled.
    affect_state = result.get("affect") or {}
    return ChatResponse(
        user_id=req.user_id,
        query=req.query,
        response=result["selected_response"] or "",
        affect=affect_state.get("emotion") or "NEUTRAL",
        llm_tier=result.get("llm_tier_used") or "unknown",
        retrieval_mode=result.get("retrieval_mode_used") or "unknown",
        latency=result.get("latency_log") or {},
        guardrail_passed=result.get("guardrail_passed", True),
    )