aac-chatbot / backend /main.py
akashkolte's picture
-
79a823e
# CLI entry point for the AAC chatbot pipeline.
import argparse
import copy
import json
import os
import sys
import time
from backend.config.settings import settings
from backend.guardrails.checks import check_input
from backend.pipeline.graph import run_pipeline
from backend.pipeline.nodes.intent import _AFFECT_CONFIG
from backend.pipeline.state import GenerationConfig, PipelineState
from backend.retrieval.priors import BUCKETS, CHUNK_TYPES, uniform
from backend.retrieval.vector_store import _get_embedder
from backend.sensing.bucket_keywords import infer_bucket
def parse_args() -> argparse.Namespace:
    """Parse the command-line flags of the AAC chatbot CLI.

    Returns:
        A namespace with ``user`` (str or None), ``debug`` (bool),
        ``fast`` (bool) and ``tier`` ("primary", "fallback" or None).
    """
    parser = argparse.ArgumentParser(description="AAC Chatbot CLI")

    # Persona selection; when omitted the caller decides how to pick one.
    parser.add_argument("--user", type=str, default=None, help="Persona user_id")

    # Boolean switches.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print latency table each turn",
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        help="Skip LLM intent call — use keyword routing instead (faster local dev)",
    )

    # Restricted to the two known tier names; argparse rejects anything else.
    parser.add_argument(
        "--tier",
        type=str,
        default=None,
        choices=["primary", "fallback"],
        help="Override LLM tier (default: settings.active_llm_tier)",
    )

    return parser.parse_args()
# ── Fast keyword-based intent routing (bypasses the slow LLM intent call) ──────
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
    """Build an intent route from simple keyword checks instead of an LLM call.

    Args:
        query: The raw partner utterance.

    Returns:
        A ``(route, gen_config)`` pair. ``route`` holds a single sub-intent
        (``CONTEXTUAL`` when the query contains a back-reference phrase,
        otherwise ``PERSONAL``) with the bucket hint from ``infer_bucket``
        and a fixed ``NEUTRAL`` affect. ``gen_config`` is a private deep copy
        of the ``NEUTRAL`` entry of ``_AFFECT_CONFIG``.
    """
    bucket_hint = infer_bucket(query)

    # Phrases that point back at something said earlier in the conversation.
    back_reference_cues = ("you just said", "earlier", "you mentioned")
    lowered = query.lower()
    if any(cue in lowered for cue in back_reference_cues):
        kind = "CONTEXTUAL"
    else:
        kind = "PERSONAL"

    sub_intent = {
        "type": kind,
        "query": query,
        "bucket_hint": bucket_hint,
        "priority": "normal",
    }

    # `style_constraints` is kept only for shape compatibility; per the
    # original note, the planner treats `generation_config` as authoritative.
    route = {
        "sub_intents": [sub_intent],
        "style_constraints": {
            "tone_tag": "[TONE:DEFAULT]",
            "max_tokens": 100,
            "retrieval_mode": "full",
            "persona_mod": "baseline",
        },
        "affect": "NEUTRAL",
    }

    # Copy so downstream mutation can never touch the shared module constant.
    neutral_config: GenerationConfig = copy.deepcopy(_AFFECT_CONFIG["NEUTRAL"])
    return route, neutral_config
def load_users() -> dict[str, dict]:
    """Load the persona registry and index it by user id.

    Reads the JSON file at ``settings.users_json`` and maps each entry of its
    top-level ``"users"`` list by that entry's ``"id"`` value.

    Returns:
        A dict of ``id -> user entry``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``"users"`` or an entry's ``"id"`` is missing.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the locale default and can mis-decode non-ASCII text.
    with open(settings.users_json, encoding="utf-8") as f:
        registry = json.load(f)
    return {u["id"]: u for u in registry["users"]}
def load_persona_profile(user_id: str) -> dict:
    """Load the ``"profile"`` section of one persona's memory file.

    The file is ``<settings.memories_dir>/<user_id>.json``.

    Args:
        user_id: Persona identifier; used verbatim as the file stem.

    Returns:
        The value stored under the top-level ``"profile"`` key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the ``"profile"`` key is missing.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the locale default and can mis-decode non-ASCII text.
    with open(settings.memories_dir / f"{user_id}.json", encoding="utf-8") as f:
        return json.load(f)["profile"]
def select_user(users: dict[str, dict], user_arg: str | None) -> str:
if user_arg:
if user_arg not in users:
print(f"Unknown user '{user_arg}'. Available: {list(users)}")
sys.exit(1)
return user_arg
print("\nAvailable personas:")
for uid, u in users.items():
print(f" {uid:20s}{u['name']} ({u['condition']})")
uid = input("\nSelect user id: ").strip()
if uid not in users:
print("Invalid id.")
sys.exit(1)
return uid
def print_latency(log: dict, turn: int) -> None:
    """Print a two-row, column-aligned latency table for one turn.

    Args:
        log: Timings in seconds keyed by ``t_sensing``, ``t_intent``,
            ``t_retrieval``, ``t_generation`` and ``t_total``; a missing key
            is shown as zero.
        turn: Turn number shown in the table heading.
    """
    columns = (
        ("sensing", "t_sensing"),
        ("intent", "t_intent"),
        ("retrieval", "t_retrieval"),
        ("generation", "t_generation"),
        ("TOTAL", "t_total"),
    )

    header_cells: list[str] = []
    value_cells: list[str] = []
    for label, key in columns:
        shown = f"{log.get(key, 0):.3f}s"
        # Each column is as wide as the longer of its label and its value.
        width = max(len(label), len(shown))
        header_cells.append(label.ljust(width))
        value_cells.append(shown.ljust(width))

    print(f"\n[turn {turn} latency]")
    print(" | ".join(header_cells))
    print(" | ".join(value_cells))
def main() -> None:
    """Run the interactive AAC chatbot loop on the terminal.

    Parses the CLI flags, picks a persona, warms up the embedder, then
    repeatedly reads a partner utterance, checks it with ``check_input``,
    runs it through ``run_pipeline`` and prints the selected response.
    Session history and both prior tables are carried from one turn to the
    next. The loop ends on ``quit``/``exit``/``q``, EOF, or Ctrl-C at the
    prompt.
    """
    cli = parse_args()

    # Optional runtime override of the LLM tier.
    # NOTE(review): `settings` is already imported when this runs — confirm it
    # reads ACTIVE_LLM_TIER lazily, otherwise this override has no effect.
    if cli.tier:
        os.environ["ACTIVE_LLM_TIER"] = cli.tier

    persona_id = select_user(load_users(), cli.user)
    persona = load_persona_profile(persona_id)

    # Warm up models before the first prompt.
    print(f"\nLoading models for {persona['name']} …", end=" ", flush=True)
    _get_embedder()
    print("ready.\n")

    # State carried across turns.
    history: list[dict] = []
    bucket_weights = uniform(BUCKETS)
    type_weights = uniform(CHUNK_TYPES)
    turn_no = 0
    exit_words = {"quit", "exit", "q"}

    print(f"Chatting as {persona['name']}. Type 'quit' to exit.\n")

    while True:
        try:
            query = input("Partner: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break

        if query.lower() in exit_words:
            break
        if not query:
            continue

        verdict = check_input(query)
        if not verdict["allowed"]:
            # Blocked input: show the guardrail's fallback; no turn is counted.
            print(f"AAC Bot: {verdict['fallback']}\n")
            continue

        turn_no += 1

        # --fast: resolve intent via keywords and time it here; otherwise
        # leave both fields empty for the pipeline to fill in.
        route, gen_cfg, intent_secs = None, None, 0.0
        if cli.fast:
            started = time.perf_counter()
            route, gen_cfg = _keyword_intent(query)
            intent_secs = time.perf_counter() - started

        timings = {
            "t_sensing": 0.0,
            "t_intent": round(intent_secs, 4),
            "t_retrieval": 0.0,
            "t_generation": 0.0,
            "t_total": 0.0,
        }

        state = PipelineState(
            user_id=persona_id,
            persona_profile=persona,
            session_history=history,
            turn_id=turn_no,
            affect=None,
            gesture_tag=None,
            gaze_bucket=None,
            air_written_text=None,
            voice_text=None,
            resolved_intent=None,
            raw_query=query,
            # Pre-filled route: per the original note, the intent node sees it
            # and skips its LLM call.
            intent_route=route,
            generation_config=gen_cfg,
            retrieved_chunks=[],
            bucket_priors=bucket_weights,
            type_priors=type_weights,
            retrieval_mode_used="",
            augmented_prompt=None,
            candidates=[],
            rejected_candidates=[],
            selected_response=None,
            llm_tier_used="",
            latency_log=timings,
            run_id=None,
            guardrail_passed=True,
        )

        result: PipelineState = run_pipeline(state)
        print(f"AAC Bot: {result['selected_response']}\n")

        # Feed the pipeline's updated state into the next turn.
        history = result["session_history"]
        bucket_weights = result["bucket_priors"]
        type_weights = result["type_priors"]

        if not cli.debug:
            continue

        print_latency(result.get("latency_log") or {}, turn_no)
        tier_used = result.get("llm_tier_used")
        mode_used = result.get("retrieval_mode_used")
        emotion = (result.get("affect") or {}).get("emotion", "?")
        print(
            f" tier={tier_used} | "
            f"retrieval={mode_used} | "
            f"affect={emotion}\n"
        )
# Start the interactive CLI only when this file is executed as a script;
# importing the module must not launch the chat loop.
if __name__ == "__main__":
    main()