# FastAPI backend — REST API for the AAC pipeline. import json import logging import re import threading import time from collections import OrderedDict from functools import lru_cache from pathlib import Path from fastapi import BackgroundTasks, FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field from backend.config.settings import settings from backend.evals import compute_evals from backend.generation.llm_client import ( # active_model used by /debug/config active_model, get_client, ) from backend.guardrails.checks import check_input from backend.pipeline.graph import choose_planner_tier, run_pipeline, run_until_planner from backend.pipeline.intent_kind import classify_intent_kind from backend.pipeline.nodes import feedback as feedback_node from backend.pipeline.nodes import planner as planner_node from backend.pipeline.state import PipelineState from backend.retrieval import pick_index from backend.retrieval.priors import BUCKETS, CHUNK_TYPES, uniform from backend.retrieval.vector_store import _get_embedder, retrieve app = FastAPI( title="Multimodal AAC Chatbot API", description="Agentic RAG pipeline for AAC persona communication", version="2.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) _log = logging.getLogger(__name__) _models_ready = False _RUN_ID_RE = re.compile(r"^[0-9a-f]{32}$") _ID_PATTERN = r"^[a-zA-Z0-9_-]+$" @app.on_event("startup") def _warmup(): global _models_ready import logging import os os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") logging.getLogger("sentence_transformers").setLevel(logging.WARNING) print("Loading models...", end=" ", flush=True) _get_embedder() get_client() _models_ready = True print("ready.") # ── In-memory session store (replace with Redis for multi-worker deployments) ── _sessions: dict[str, dict] = {} # Eval scores keyed by run_id, filled by a BackgroundTask after /chat returns # so the UI can render the response immediately and poll GET /evals/{run_id}. # Multi-worker deploys should swap this (and _sessions) for Redis. _EVAL_FAILED: dict = {"_failed": True} _eval_results: OrderedDict[str, dict] = OrderedDict() _eval_lock = threading.Lock() _EVAL_RESULTS_MAX = 200 def _remember_eval(run_id: str, scores: dict | None) -> None: value = scores if scores else _EVAL_FAILED with _eval_lock: _eval_results[run_id] = value _eval_results.move_to_end(run_id) while len(_eval_results) > _EVAL_RESULTS_MAX: _eval_results.popitem(last=False) def _reserve_eval_slot(run_id: str) -> None: """Mark a run_id as in-flight so /evals can report 'pending' vs 'unknown'.""" with _eval_lock: if run_id not in _eval_results: _eval_results[run_id] = {} # empty dict = pending _eval_results.move_to_end(run_id) while len(_eval_results) > _EVAL_RESULTS_MAX: _eval_results.popitem(last=False) # ── Request / response schemas ───────────────────────────────────────────────── class ResolvedIntent(BaseModel): text: str source: str # voice_only | air_only | agree | conflict_air | conflict_voice | none voice_text: str | None = None air_text: str | None = None class ChatRequest(BaseModel): user_id: str query: str affect_override: str | None = None # "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED" gesture_tag: str | None = None gaze_bucket: str | None = None air_written_text: str | None = None head_signal: str | None = None # "HEAD_SHAKE"|"HEAD_NOD_DISSATISFIED" voice_text: str | None = None resolved_intent: ResolvedIntent | None = None class TurnaroundRequest(BaseModel): user_id: str turn_id: int | None = None # optional guard against stale turnaround calls head_signal: str | None = None class CandidateOut(BaseModel): text: str strategy: str grounded_buckets: list[str] = [] class ChatResponse(BaseModel): user_id: str query: str response: str candidates: list[CandidateOut] = [] affect: str llm_tier: str llm_model: str retrieval_mode: str latency: dict guardrail_passed: bool run_id: str | None = None turn_id: int eval_scores: dict | None = None class PickRequest(BaseModel): run_id: str = Field(min_length=1, max_length=64, pattern=_ID_PATTERN) user_id: str = Field(min_length=1, max_length=64, pattern=_ID_PATTERN) picked_idx: int = Field(ge=0, le=10) class RegenerateRequest(BaseModel): user_id: str turn_id: int | None = None rejected_texts: list[str] = Field(default_factory=list, max_length=20) class RatingRequest(BaseModel): run_id: str = Field(min_length=1, max_length=64, pattern=_ID_PATTERN) user_id: str = Field(min_length=1, max_length=64, pattern=_ID_PATTERN) authenticity: int = Field(ge=1, le=5) rater_id: str = Field(default="anonymous", max_length=64, pattern=_ID_PATTERN) notes: str | None = Field(default=None, max_length=500) # ── Helpers ──────────────────────────────────────────────────────────────────── def _candidate_dicts(state) -> list[dict]: return [dict(c) for c in (state.get("candidates") or [])] def _load_persona_profile(user_id: str) -> dict: memories_path = settings.memories_dir / f"{user_id}.json" try: with open(memories_path) as f: persona = json.load(f) except FileNotFoundError as e: raise HTTPException( status_code=404, detail=f"Persona file not found: {memories_path}", ) from e return persona["profile"] def _get_or_init_session(user_id: str) -> dict: if user_id not in _sessions: try: with open(settings.users_json) as f: users = {u["id"]: u for u in json.load(f)["users"]} except FileNotFoundError as e: raise HTTPException( status_code=503, detail="users.json not found — run setup.sh" ) from e if user_id not in users: raise HTTPException(status_code=404, detail=f"User '{user_id}' not found") _sessions[user_id] = { "persona_profile": _load_persona_profile(user_id), "session_history": [], "bucket_priors": uniform(BUCKETS), "type_priors": uniform(CHUNK_TYPES), "turn_id": 0, } return _sessions[user_id] def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState: affect_state = None if req.affect_override: affect_state = {"emotion": req.affect_override, "vector": {}, "smoothed": {}} session["turn_id"] += 1 return PipelineState( user_id=req.user_id, persona_profile=session["persona_profile"], session_history=session["session_history"], turn_id=session["turn_id"], affect=affect_state, gesture_tag=req.gesture_tag, gaze_bucket=req.gaze_bucket, air_written_text=req.air_written_text, head_signal=req.head_signal, voice_text=req.voice_text, resolved_intent=( req.resolved_intent.model_dump() if req.resolved_intent else None ), turnaround_triggered=False, raw_query=req.query, intent_route=None, generation_config=None, retrieved_chunks=[], bucket_priors=session["bucket_priors"], type_priors=session["type_priors"], retrieval_mode_used="", augmented_prompt=None, candidates=[], rejected_candidates=[], selected_response=None, llm_tier_used="", llm_model_used="", latency_log={ "t_sensing": 0.0, "t_intent": 0.0, "t_retrieval": 0.0, "t_generation": 0.0, "t_total": 0.0, }, run_id=None, guardrail_passed=True, ) def _re_retrieve_excluding( query: str, user_id: str, rejected_chunks: list[dict], ) -> list[dict] | None: """Pull fresh chunks for a turnaround, excluding the bucket and exact texts of the rejected chunks. Returns: - list of chunks (passing min-score floor) when re-retrieval improved on the rejected set - None when re-retrieval should not be used (no signal, all dropped by dedupe, or all below score floor) — caller should keep original chunks """ if not rejected_chunks: return None rejected_bucket = rejected_chunks[0].get("bucket") rejected_texts = {c.get("text") for c in rejected_chunks if c.get("text")} if not rejected_bucket: return None try: # Pull a wider net (top_k * 2) so dedupe + bucket-exclusion still leaves # enough candidates to fill rerank_k. fresh = retrieve( query=query, user_id=user_id, top_k=settings.retrieval_top_k * 2, rerank_k=settings.retrieval_top_k * 2, bucket_filter=None, ) except Exception as exc: _log.warning("turnaround re-retrieval failed: %r", exc) return None filtered = [ c for c in fresh if c.get("bucket") != rejected_bucket and c.get("text") not in rejected_texts and float(c.get("score", 0.0)) >= settings.turnaround_min_score ] if not filtered: _log.info( "turnaround re-retrieval found no chunks above score floor %.2f — " "keeping original chunks", settings.turnaround_min_score, ) return None return filtered[: settings.retrieval_rerank_k] # ── Routes ───────────────────────────────────────────────────────────────────── @app.get("/health") def health(): return {"status": "ok", "models_ready": _models_ready} @app.get("/debug/config") def debug_config(): """Return active model + key settings for the debug panel.""" return { "active_llm_tier": settings.active_llm_tier, "active_model": active_model(), "thinking_mode": settings.thinking_mode, "embed_model": settings.embed_model, "retrieval_top_k": settings.retrieval_top_k, "retrieval_rerank_k": settings.retrieval_rerank_k, "fallback_latency_threshold": settings.fallback_latency_threshold, "slo_target_s": settings.slo_target_s, } @app.get("/users") def list_users(): try: with open(settings.users_json) as f: return json.load(f) except FileNotFoundError as e: raise HTTPException( status_code=503, detail="users.json not found — run setup.sh" ) from e @app.post("/session/reset") def reset_session(user_id: str): _sessions.pop(user_id, None) return {"status": "reset", "user_id": user_id} def _compute_and_persist_evals( run_id: str | None, user_id: str, turn_id: int, response: str, chunks: list[dict], latency_log: dict, affect: str, gesture_tag: str | None, gaze_bucket: str | None, query: str = "", candidates: list[dict] | None = None, ) -> dict | None: if not settings.evals_enabled or not run_id: return None try: scores = compute_evals( response=response, chunks=chunks, latency_log=latency_log, affect=affect, gesture_tag=gesture_tag, gaze_bucket=gaze_bucket, slo_target=settings.slo_target_s, query=query, candidates=candidates, ) except Exception: _log.exception("evals scoring failed for run %s", run_id) _remember_eval(run_id, None) return None _remember_eval(run_id, scores) try: entry = { "run_id": run_id, "ts": time.time(), "user_id": user_id, "turn_id": turn_id, **scores, } logs_dir = Path(settings.logs_dir) logs_dir.mkdir(parents=True, exist_ok=True) with open(logs_dir / "evals.jsonl", "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") except Exception: _log.exception("evals JSONL persist failed for run %s", run_id) return scores @app.post("/chat", response_model=ChatResponse) def chat(req: ChatRequest, background_tasks: BackgroundTasks): guard = check_input(req.query) if not guard["allowed"]: return ChatResponse( user_id=req.user_id, query=req.query, response=guard["fallback"], candidates=[], affect="NEUTRAL", llm_tier="none", llm_model="none", retrieval_mode="none", latency={}, guardrail_passed=False, turn_id=0, ) session = _get_or_init_session(req.user_id) initial_state = _build_initial_state(req, session) result: PipelineState = run_pipeline(initial_state) session["session_history"] = result["session_history"] session["bucket_priors"] = result["bucket_priors"] session["type_priors"] = result["type_priors"] session["last_state"] = result affect_emotion = (result.get("affect") or {}).get("emotion", "NEUTRAL") run_id = result.get("run_id") # Evals (NLI cross-encoder) run off the response path; UI polls /evals. if run_id and settings.evals_enabled: _reserve_eval_slot(run_id) background_tasks.add_task( _compute_and_persist_evals, run_id=run_id, user_id=req.user_id, turn_id=result["turn_id"], response=result["selected_response"] or "", chunks=list(result.get("retrieved_chunks") or []), latency_log=dict(result.get("latency_log") or {}), affect=affect_emotion, gesture_tag=req.gesture_tag, gaze_bucket=req.gaze_bucket, query=req.query, candidates=_candidate_dicts(result), ) return ChatResponse( user_id=req.user_id, query=req.query, response=result["selected_response"] or "", candidates=[CandidateOut(**c) for c in result.get("candidates") or []], affect=affect_emotion, llm_tier=result.get("llm_tier_used", "unknown"), llm_model=result.get("llm_model_used", "unknown"), retrieval_mode=result.get("retrieval_mode_used", "unknown"), latency=result.get("latency_log") or {}, guardrail_passed=result.get("guardrail_passed", True), run_id=run_id, turn_id=result["turn_id"], eval_scores=None, ) @app.post("/chat/stream") def chat_stream(req: ChatRequest): """Server-Sent Events version of /chat. Runs intent + retrieval synchronously, then streams planner candidate tokens as they arrive. Final event carries the full ChatResponse-shaped payload. """ guard = check_input(req.query) if not guard["allowed"]: payload = { "user_id": req.user_id, "query": req.query, "response": guard["fallback"], "candidates": [], "affect": "NEUTRAL", "llm_tier": "none", "llm_model": "none", "retrieval_mode": "none", "latency": {}, "guardrail_passed": False, "turn_id": 0, "run_id": None, "eval_scores": None, } def _one_event(): yield _sse({"type": "complete", "response": payload}) return StreamingResponse(_one_event(), media_type="text/event-stream") session = _get_or_init_session(req.user_id) initial_state = _build_initial_state(req, session) def _gen(): state = run_until_planner(initial_state) tier = choose_planner_tier(state) completion: dict | None = None for evt in planner_node._run_stream(state, tier=tier): if evt["type"] == "complete": completion = evt["planner_update"] break yield _sse(evt) if completion is None: yield _sse({"type": "error", "message": "planner produced no completion"}) return state.update(completion) # type: ignore[typeddict-item] state.update(feedback_node.run(state)) # type: ignore[typeddict-item] session["session_history"] = state["session_history"] session["bucket_priors"] = state["bucket_priors"] session["type_priors"] = state["type_priors"] session["last_state"] = state affect_emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL") run_id = state.get("run_id") # Evals run off the response path; UI polls GET /evals/{run_id}. cand_dicts = _candidate_dicts(state) if run_id and settings.evals_enabled: _reserve_eval_slot(run_id) threading.Thread( target=_compute_and_persist_evals, kwargs=dict( run_id=run_id, user_id=req.user_id, turn_id=state["turn_id"], response=state["selected_response"] or "", chunks=list(state.get("retrieved_chunks") or []), latency_log=dict(state.get("latency_log") or {}), affect=affect_emotion, gesture_tag=req.gesture_tag, gaze_bucket=req.gaze_bucket, query=req.query, candidates=cand_dicts, ), daemon=True, ).start() final = { "user_id": req.user_id, "query": req.query, "response": state["selected_response"] or "", "candidates": cand_dicts, "affect": affect_emotion, "llm_tier": state.get("llm_tier_used", "unknown"), "llm_model": state.get("llm_model_used", "unknown"), "retrieval_mode": state.get("retrieval_mode_used", "unknown"), "latency": state.get("latency_log") or {}, "guardrail_passed": state.get("guardrail_passed", True), "run_id": run_id, "turn_id": state["turn_id"], "eval_scores": None, } yield _sse({"type": "complete", "response": final}) return StreamingResponse( _gen(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) def _sse(data: dict) -> str: return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" @app.get("/evals/{run_id}") def get_evals(run_id: str): if not _RUN_ID_RE.match(run_id): raise HTTPException(status_code=400, detail="invalid run_id") with _eval_lock: entry = _eval_results.get(run_id) if entry is None: return {"status": "unknown", "run_id": run_id, "eval_scores": None} if entry is _EVAL_FAILED: return {"status": "failed", "run_id": run_id, "eval_scores": None} if not entry: return {"status": "pending", "run_id": run_id, "eval_scores": None} return {"status": "ready", "run_id": run_id, "eval_scores": entry} @app.post("/chat/turnaround", response_model=ChatResponse) def chat_turnaround(req: TurnaroundRequest, background_tasks: BackgroundTasks): if req.user_id not in _sessions: raise HTTPException(status_code=404, detail="no active session") session = _sessions[req.user_id] last: PipelineState | None = session.get("last_state") if last is None: raise HTTPException(status_code=409, detail="no prior turn to rephrase") if req.turn_id is not None and req.turn_id != last["turn_id"]: raise HTTPException(status_code=409, detail="stale turn_id") # feedback.run will re-append (partner, aac_user) for this turn, so strip # both of those tail entries to avoid duplicating the partner line. The # rejected aac_user text is also excluded from the re-plan context this way. trimmed_history = list(last.get("session_history") or []) if trimmed_history and trimmed_history[-1].get("role") == "aac_user": trimmed_history.pop() if trimmed_history and trimmed_history[-1].get("role") == "partner": trimmed_history.pop() intent_kind = classify_intent_kind(last.get("intent_route")) gen_cfg = dict(last.get("generation_config") or {}) if intent_kind == "present_state": gen_cfg["persona_mod"] = "present_state_retry" gen_cfg["tone_tag"] = "[TONE:HONEST_UNCERTAIN]" else: gen_cfg["persona_mod"] = "reverse_stance" gen_cfg.setdefault("tone_tag", "[TONE:CLARIFYING_REPHRASE]") replan_state: PipelineState = dict(last) # type: ignore[assignment] replan_state["session_history"] = trimmed_history replan_state["generation_config"] = gen_cfg replan_state["head_signal"] = req.head_signal or last.get("head_signal") replan_state["turnaround_triggered"] = True replan_state["latency_log"] = { "t_sensing": 0.0, "t_intent": 0.0, "t_retrieval": 0.0, "t_generation": 0.0, "t_total": 0.0, } # For PERSONAL turnarounds, pull fresh chunks excluding the bucket and # exact texts of the rejected response — same chunks would just produce # the same wrong answer. _re_retrieve_excluding returns None when the # fresh batch is no better than what we already had, in which case we # keep the original chunks rather than degrade to lower-relevance ones. if intent_kind == "memory": fresh_chunks = _re_retrieve_excluding( query=last["raw_query"], user_id=last["user_id"], rejected_chunks=last.get("retrieved_chunks") or [], ) if fresh_chunks is not None: replan_state["retrieved_chunks"] = fresh_chunks replan_state["retrieval_mode_used"] = "turnaround_rebucket" planner_update = planner_node.run_primary(replan_state) replan_state.update(planner_update) # type: ignore[typeddict-item] feedback_update = feedback_node.run(replan_state) replan_state.update(feedback_update) # type: ignore[typeddict-item] session["session_history"] = replan_state["session_history"] session["bucket_priors"] = replan_state["bucket_priors"] session["type_priors"] = replan_state["type_priors"] session["last_state"] = replan_state affect_emotion = (replan_state.get("affect") or {}).get("emotion", "NEUTRAL") run_id = replan_state.get("run_id") if run_id and settings.evals_enabled: _reserve_eval_slot(run_id) background_tasks.add_task( _compute_and_persist_evals, run_id=run_id, user_id=req.user_id, turn_id=replan_state["turn_id"], response=replan_state["selected_response"] or "", chunks=list(replan_state.get("retrieved_chunks") or []), latency_log=dict(replan_state.get("latency_log") or {}), affect=affect_emotion, gesture_tag=replan_state.get("gesture_tag"), gaze_bucket=replan_state.get("gaze_bucket"), query=replan_state.get("raw_query") or "", candidates=_candidate_dicts(replan_state), ) return ChatResponse( user_id=req.user_id, query=replan_state["raw_query"], response=replan_state["selected_response"] or "", candidates=[CandidateOut(**c) for c in replan_state.get("candidates") or []], affect=affect_emotion, llm_tier=replan_state.get("llm_tier_used", "unknown"), llm_model=replan_state.get("llm_model_used", "unknown"), retrieval_mode=replan_state.get("retrieval_mode_used", "unknown"), latency=replan_state.get("latency_log") or {}, guardrail_passed=replan_state.get("guardrail_passed", True), run_id=run_id, turn_id=replan_state["turn_id"], eval_scores=None, ) def _find_turn_from_jsonl(run_id: str) -> dict | None: """Scan turns.jsonl from the end for a matching run_id. Used as fallback when the session's last_state has already moved on.""" path = Path(settings.logs_dir) / "turns.jsonl" if not path.exists(): return None try: with open(path, encoding="utf-8") as f: lines = f.readlines() except OSError: return None for line in reversed(lines[-500:]): # bounded tail scan try: row = json.loads(line) except json.JSONDecodeError: continue if row.get("run_id") == run_id: return row return None @app.post("/chat/pick") def pick_candidate(req: PickRequest): if not _RUN_ID_RE.match(req.run_id): raise HTTPException(status_code=400, detail="invalid run_id") session = _sessions.get(req.user_id) or {} last = session.get("last_state") or {} candidates = last.get("candidates") or [] query_text = last.get("raw_query") or "" # Fallback: last_state already advanced past this run_id — read from JSONL if last.get("run_id") != req.run_id or not candidates: row = _find_turn_from_jsonl(req.run_id) if not row: raise HTTPException(status_code=404, detail="turn not found") candidates = row.get("candidates") or [] query_text = row.get("query") or query_text if req.picked_idx >= len(candidates): raise HTTPException(status_code=400, detail="picked_idx out of range") picked = candidates[req.picked_idx] picked_text = picked.get("text", "") strategy = picked.get("strategy", "unknown") picked_buckets = [ b for b in (picked.get("grounded_buckets") or []) if b and b != "open_domain" ] if query_text and picked_text: try: pick_index.add( query=query_text, user_id=req.user_id, strategy=strategy, picked_text=picked_text, picked_buckets=picked_buckets, ) except Exception as exc: _log.warning("pick_index add failed: %r", exc) logs_dir = Path(settings.logs_dir) logs_dir.mkdir(parents=True, exist_ok=True) entry = { "ts": time.time(), "run_id": req.run_id, "user_id": req.user_id, "picked_idx": req.picked_idx, "strategy": strategy, "picked_text": picked_text, "query": query_text, } with open(logs_dir / "picks.jsonl", "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") return {"status": "ok", "strategy": strategy} @app.post("/chat/regenerate/stream") def chat_regenerate_stream(req: RegenerateRequest): """Streaming regenerate — same as /chat/stream but reuses last_state and marks all prior candidates as rejected.""" if req.user_id not in _sessions: raise HTTPException(status_code=404, detail="no active session") session = _sessions[req.user_id] last: PipelineState | None = session.get("last_state") if last is None: raise HTTPException(status_code=409, detail="no prior turn to regenerate") if req.turn_id is not None and req.turn_id != last["turn_id"]: raise HTTPException(status_code=409, detail="stale turn_id") gen_cfg = dict(last.get("generation_config") or {}) gen_cfg["persona_mod"] = "all_rejected" gen_cfg.setdefault("tone_tag", "[TONE:TRY_DIFFERENT_ANGLE]") prior_rejected = [c.get("text", "") for c in (last.get("candidates") or [])] merged = ( list(last.get("rejected_candidates") or []) + [t for t in prior_rejected if t] + [t for t in req.rejected_texts if t] ) seen: set[str] = set() rejected: list[str] = [] for t in merged: key = t.strip().lower() if key and key not in seen: seen.add(key) rejected.append(t) trimmed_history = list(last.get("session_history") or []) if trimmed_history and trimmed_history[-1].get("role") == "aac_user": trimmed_history.pop() if trimmed_history and trimmed_history[-1].get("role") == "partner": trimmed_history.pop() replan_state: PipelineState = dict(last) # type: ignore[assignment] replan_state["session_history"] = trimmed_history replan_state["generation_config"] = gen_cfg replan_state["rejected_candidates"] = rejected replan_state["turnaround_triggered"] = False replan_state["latency_log"] = { "t_sensing": 0.0, "t_intent": 0.0, "t_retrieval": 0.0, "t_generation": 0.0, "t_total": 0.0, } def _gen(): completion: dict | None = None for evt in planner_node._run_stream(replan_state, tier="primary"): if evt["type"] == "complete": completion = evt["planner_update"] break yield _sse(evt) if completion is None: yield _sse({"type": "error", "message": "planner produced no completion"}) return replan_state.update(completion) # type: ignore[typeddict-item] replan_state.update(feedback_node.run(replan_state)) # type: ignore[typeddict-item] session["session_history"] = replan_state["session_history"] session["bucket_priors"] = replan_state["bucket_priors"] session["type_priors"] = replan_state["type_priors"] session["last_state"] = replan_state affect_emotion = (replan_state.get("affect") or {}).get("emotion", "NEUTRAL") run_id = replan_state.get("run_id") cand_dicts = _candidate_dicts(replan_state) eval_scores = _compute_and_persist_evals( run_id=run_id, user_id=req.user_id, turn_id=replan_state["turn_id"], response=replan_state["selected_response"] or "", chunks=list(replan_state.get("retrieved_chunks") or []), latency_log=dict(replan_state.get("latency_log") or {}), affect=affect_emotion, gesture_tag=replan_state.get("gesture_tag"), gaze_bucket=replan_state.get("gaze_bucket"), query=replan_state.get("raw_query") or "", candidates=cand_dicts, ) final = { "user_id": req.user_id, "query": replan_state["raw_query"], "response": replan_state["selected_response"] or "", "candidates": cand_dicts, "affect": affect_emotion, "llm_tier": replan_state.get("llm_tier_used", "unknown"), "llm_model": replan_state.get("llm_model_used", "unknown"), "retrieval_mode": replan_state.get("retrieval_mode_used", "unknown"), "latency": replan_state.get("latency_log") or {}, "guardrail_passed": replan_state.get("guardrail_passed", True), "run_id": run_id, "turn_id": replan_state["turn_id"], "eval_scores": eval_scores, } yield _sse({"type": "complete", "response": final}) return StreamingResponse( _gen(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) @app.post("/chat/regenerate", response_model=ChatResponse) def chat_regenerate(req: RegenerateRequest): """Re-run the planner for the same turn with all prior candidates marked rejected. Does NOT advance turn_id — same partner query, fresh fan-out of candidates. """ if req.user_id not in _sessions: raise HTTPException(status_code=404, detail="no active session") session = _sessions[req.user_id] last: PipelineState | None = session.get("last_state") if last is None: raise HTTPException(status_code=409, detail="no prior turn to regenerate") if req.turn_id is not None and req.turn_id != last["turn_id"]: raise HTTPException(status_code=409, detail="stale turn_id") gen_cfg = dict(last.get("generation_config") or {}) gen_cfg["persona_mod"] = "all_rejected" gen_cfg.setdefault("tone_tag", "[TONE:TRY_DIFFERENT_ANGLE]") prior_rejected = [c.get("text", "") for c in (last.get("candidates") or [])] merged_rejected = ( list(last.get("rejected_candidates") or []) + [t for t in prior_rejected if t] + [t for t in req.rejected_texts if t] ) # Dedupe while preserving order. seen: set[str] = set() rejected: list[str] = [] for t in merged_rejected: key = t.strip().lower() if key and key not in seen: seen.add(key) rejected.append(t) # Strip the tail (partner, aac_user) so feedback doesn't stack duplicate # history entries on every regenerate — the user hasn't committed yet. trimmed_history = list(last.get("session_history") or []) if trimmed_history and trimmed_history[-1].get("role") == "aac_user": trimmed_history.pop() if trimmed_history and trimmed_history[-1].get("role") == "partner": trimmed_history.pop() replan_state: PipelineState = dict(last) # type: ignore[assignment] replan_state["session_history"] = trimmed_history replan_state["generation_config"] = gen_cfg replan_state["rejected_candidates"] = rejected replan_state["turnaround_triggered"] = False # keep multi-shot replan_state["latency_log"] = { "t_sensing": 0.0, "t_intent": 0.0, "t_retrieval": 0.0, "t_generation": 0.0, "t_total": 0.0, } planner_update = planner_node.run_primary(replan_state) replan_state.update(planner_update) # type: ignore[typeddict-item] # Feedback node rewrites history + assigns a new run_id. Each regenerate # is its own row in turns.jsonl for the eval record. feedback_update = feedback_node.run(replan_state) replan_state.update(feedback_update) # type: ignore[typeddict-item] session["session_history"] = replan_state["session_history"] session["bucket_priors"] = replan_state["bucket_priors"] session["type_priors"] = replan_state["type_priors"] session["last_state"] = replan_state affect_emotion = (replan_state.get("affect") or {}).get("emotion", "NEUTRAL") run_id = replan_state.get("run_id") eval_scores = _compute_and_persist_evals( run_id=run_id, user_id=req.user_id, turn_id=replan_state["turn_id"], response=replan_state["selected_response"] or "", chunks=list(replan_state.get("retrieved_chunks") or []), latency_log=dict(replan_state.get("latency_log") or {}), affect=affect_emotion, gesture_tag=replan_state.get("gesture_tag"), gaze_bucket=replan_state.get("gaze_bucket"), query=replan_state.get("raw_query") or "", candidates=_candidate_dicts(replan_state), ) return ChatResponse( user_id=req.user_id, query=replan_state["raw_query"], response=replan_state["selected_response"] or "", candidates=[CandidateOut(**c) for c in replan_state.get("candidates") or []], affect=affect_emotion, llm_tier=replan_state.get("llm_tier_used", "unknown"), llm_model=replan_state.get("llm_model_used", "unknown"), retrieval_mode=replan_state.get("retrieval_mode_used", "unknown"), latency=replan_state.get("latency_log") or {}, guardrail_passed=replan_state.get("guardrail_passed", True), run_id=run_id, turn_id=replan_state["turn_id"], eval_scores=eval_scores, ) @app.post("/feedback/rating") def submit_rating(req: RatingRequest): if not _RUN_ID_RE.match(req.run_id): raise HTTPException(status_code=400, detail="invalid run_id") logs_dir = Path(settings.logs_dir) logs_dir.mkdir(parents=True, exist_ok=True) entry = { "ts": time.time(), "run_id": req.run_id, "user_id": req.user_id, "authenticity": req.authenticity, "rater_id": req.rater_id, "notes": req.notes, } with open(logs_dir / "ratings.jsonl", "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") return {"status": "ok"} class InkRecognizeRequest(BaseModel): image_base64: str @lru_cache(maxsize=1) def _get_vision_client(): from openai import OpenAI as _OpenAI return _OpenAI( base_url=settings.ink_vision_base_url, api_key=settings.ink_vision_api_key or "unused", ) @app.post("/ink/recognize") def ink_recognize(req: InkRecognizeRequest): if not req.image_base64: return {"text": ""} if not settings.ink_vision_api_key: _log.warning("/ink/recognize called but INK_VISION_API_KEY is not set") raise HTTPException(status_code=503, detail="INK_VISION_API_KEY not configured") try: client = _get_vision_client() response = client.chat.completions.create( model=settings.ink_vision_model, messages=[ { "role": "user", "content": [ { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{req.image_base64}" }, }, { "type": "text", "text": ( "This is a single handwritten character or short word " "drawn in the air. Reply with ONLY the character or " "word, nothing else." ), }, ], } ], max_tokens=64, temperature=0.0, ) raw = response.choices[0].message.content or "" _log.info("/ink/recognize raw → %r", raw[:200]) # Strip blocks emitted by reasoning models, harmless on others. text = re.sub(r".*?", "", raw, flags=re.DOTALL).strip() _log.info("/ink/recognize → %r", text) return {"text": text} except Exception as exc: _log.exception("/ink/recognize failed: %r", exc) raise HTTPException(status_code=500, detail=str(exc)) from exc # Serve React frontend — must be last so API routes take priority _frontend_dist = Path(__file__).parent.parent.parent / "frontend" / "dist" if _frontend_dist.exists(): app.mount("/", StaticFiles(directory=str(_frontend_dist), html=True), name="static")