# CLI entry point for the AAC chatbot pipeline.

import argparse
import copy
import json
import os
import sys
import time

from backend.config.settings import settings
from backend.guardrails.checks import check_input
from backend.pipeline.graph import run_pipeline
from backend.pipeline.nodes.intent import _AFFECT_CONFIG
from backend.pipeline.state import GenerationConfig, PipelineState
from backend.retrieval.priors import BUCKETS, CHUNK_TYPES, uniform
from backend.retrieval.vector_store import _get_embedder
from backend.sensing.bucket_keywords import infer_bucket


def parse_args() -> argparse.Namespace:
    """Parse the command-line flags: --user, --debug, --fast and --tier."""
    p = argparse.ArgumentParser(description="AAC Chatbot CLI")
    p.add_argument("--user", type=str, default=None, help="Persona user_id")
    p.add_argument("--debug", action="store_true", help="Print latency table each turn")
    p.add_argument(
        "--fast",
        action="store_true",
        help="Skip LLM intent call — use keyword routing instead (faster local dev)",
    )
    p.add_argument(
        "--tier",
        type=str,
        default=None,
        choices=["primary", "fallback"],
        help="Override LLM tier (default: settings.active_llm_tier)",
    )
    return p.parse_args()


# ── Fast keyword-based intent routing (bypasses the slow LLM intent call) ──────
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
    """Replicate milestone-1 keyword routing as a fast local-dev shortcut.

    Args:
        query: Raw partner utterance.

    Returns:
        A ``(route, gen_config)`` pair: ``route`` holds a single sub-intent
        (CONTEXTUAL when the query contains a back-reference phrase, otherwise
        PERSONAL) with NEUTRAL affect; ``gen_config`` is a private copy of the
        NEUTRAL entry of ``_AFFECT_CONFIG``.
    """
    q = query.lower()
    bucket = infer_bucket(query)
    intent_type = (
        "CONTEXTUAL"
        if any(w in q for w in ["you just said", "earlier", "you mentioned"])
        else "PERSONAL"
    )
    # `style_constraints` is vestigial — planner reads `generation_config` (below) as the source of truth.
    route = {
        "sub_intents": [
            {
                "type": intent_type,
                "query": query,
                "bucket_hint": bucket,
                "priority": "normal",
            }
        ],
        "style_constraints": {
            "tone_tag": "[TONE:DEFAULT]",
            "max_tokens": 100,
            "retrieval_mode": "full",
            "persona_mod": "baseline",
        },
        "affect": "NEUTRAL",
    }
    # Deep-copy: callers may mutate gen_config downstream; never hand them the shared constant.
    gen_config: GenerationConfig = copy.deepcopy(_AFFECT_CONFIG["NEUTRAL"])
    return route, gen_config


def load_users() -> dict[str, dict]:
    """Load the users file and return its ``users`` entries keyed by ``id``."""
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(settings.users_json, encoding="utf-8") as f:
        return {u["id"]: u for u in json.load(f)["users"]}


def load_persona_profile(user_id: str) -> dict:
    """Return the ``profile`` object from ``<memories_dir>/<user_id>.json``."""
    with open(settings.memories_dir / f"{user_id}.json", encoding="utf-8") as f:
        return json.load(f)["profile"]


def select_user(users: dict[str, dict], user_arg: str | None) -> str:
    """Resolve the persona id, from ``--user`` or an interactive prompt.

    Args:
        users: Known users keyed by id (as returned by ``load_users``).
        user_arg: Value of ``--user``; falsy means "ask interactively".

    Returns:
        A user id that is a key of ``users``.

    Exits the process with status 1 when the supplied or typed id is unknown.
    """
    if user_arg:
        if user_arg not in users:
            print(f"Unknown user '{user_arg}'. Available: {list(users)}")
            sys.exit(1)
        return user_arg

    print("\nAvailable personas:")
    for uid, u in users.items():
        print(f" {uid:20s} — {u['name']} ({u['condition']})")
    uid = input("\nSelect user id: ").strip()
    if uid not in users:
        print("Invalid id.")
        sys.exit(1)
    return uid


def print_latency(log: dict, turn: int) -> None:
    """Print a two-row table (labels, then seconds) of per-stage latencies.

    Args:
        log: Latency mapping; stages missing from it are shown as 0.
        turn: Turn number shown in the table header.
    """
    fields = ["t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total"]
    labels = ["sensing", "intent", "retrieval", "generation", "TOTAL"]
    vals = [f"{log.get(field, 0):.3f}s" for field in fields]
    # Each column is as wide as the longer of its label and its value.
    widths = [max(len(label), len(val)) for label, val in zip(labels, vals)]
    sep = " | "
    print(f"\n[turn {turn} latency]")
    print(sep.join(label.ljust(w) for label, w in zip(labels, widths)))
    print(sep.join(val.ljust(w) for val, w in zip(vals, widths)))


def main() -> None:
    """Run the interactive chat loop for one persona until quit or EOF."""
    args = parse_args()

    # Optionally override the LLM tier at runtime
    # NOTE(review): `settings` is imported before this assignment runs; if it
    # snapshots the environment at import time the override has no effect —
    # confirm against backend.config.settings.
    if args.tier:
        os.environ["ACTIVE_LLM_TIER"] = args.tier

    users = load_users()
    user_id = select_user(users, args.user)
    profile = load_persona_profile(user_id)

    # Warm up models
    print(f"\nLoading models for {profile['name']} …", end=" ", flush=True)
    _get_embedder()
    print("ready.\n")

    # Conversation state carried from one turn to the next.
    session_history: list[dict] = []
    bucket_priors = uniform(BUCKETS)
    type_priors = uniform(CHUNK_TYPES)
    turn_id = 0

    print(f"Chatting as {profile['name']}. Type 'quit' to exit.\n")

    while True:
        try:
            query = input("Partner: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break

        if query.lower() in {"quit", "exit", "q"}:
            break
        if not query:
            continue

        # Blocked input gets the guardrail's fallback and does not count as a turn.
        guard = check_input(query)
        if not guard["allowed"]:
            print(f"AAC Bot: {guard['fallback']}\n")
            continue

        turn_id += 1

        # --fast: resolve intent via keywords, skip the slow LLM intent node
        t_intent_fast = 0.0
        if args.fast:
            t0 = time.perf_counter()
            pre_route, pre_gen_config = _keyword_intent(query)
            t_intent_fast = time.perf_counter() - t0
        else:
            pre_route, pre_gen_config = None, None

        state = PipelineState(
            user_id=user_id,
            persona_profile=profile,
            session_history=session_history,
            turn_id=turn_id,
            affect=None,
            gesture_tag=None,
            gaze_bucket=None,
            air_written_text=None,
            voice_text=None,
            resolved_intent=None,
            raw_query=query,
            intent_route=pre_route,  # pre-filled → intent node sees it and skips LLM call
            generation_config=pre_gen_config,
            retrieved_chunks=[],
            bucket_priors=bucket_priors,
            type_priors=type_priors,
            retrieval_mode_used="",
            augmented_prompt=None,
            candidates=[],
            rejected_candidates=[],
            selected_response=None,
            llm_tier_used="",
            latency_log={
                "t_sensing": 0.0,
                "t_intent": round(t_intent_fast, 4),
                "t_retrieval": 0.0,
                "t_generation": 0.0,
                "t_total": 0.0,
            },
            run_id=None,
            guardrail_passed=True,
        )

        result: PipelineState = run_pipeline(state)

        print(f"AAC Bot: {result['selected_response']}\n")

        # Feed the pipeline's updated history and priors into the next turn.
        session_history = result["session_history"]
        bucket_priors = result["bucket_priors"]
        type_priors = result["type_priors"]

        if args.debug:
            print_latency(result.get("latency_log") or {}, turn_id)
            print(
                f" tier={result.get('llm_tier_used')} | "
                f"retrieval={result.get('retrieval_mode_used')} | "
                f"affect={(result.get('affect') or {}).get('emotion', '?')}\n"
            )


if __name__ == "__main__":
    main()