# CLI entry point for the AAC chatbot pipeline.
import argparse
import copy
import json
import os
import sys
import time
from backend.config.settings import settings
from backend.guardrails.checks import check_input
from backend.pipeline.graph import run_pipeline
from backend.pipeline.nodes.intent import _AFFECT_CONFIG
from backend.pipeline.state import GenerationConfig, PipelineState
from backend.retrieval.priors import BUCKETS, CHUNK_TYPES, uniform
from backend.retrieval.vector_store import _get_embedder
from backend.sensing.bucket_keywords import infer_bucket
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the AAC chatbot CLI.

    Returns:
        Namespace with ``user`` (persona id or ``None``), ``debug`` (bool),
        ``fast`` (bool) and ``tier`` (``"primary"``, ``"fallback"`` or ``None``).
    """
    p = argparse.ArgumentParser(description="AAC Chatbot CLI")
    p.add_argument("--user", type=str, default=None, help="Persona user_id")
    p.add_argument("--debug", action="store_true", help="Print latency table each turn")
    p.add_argument(
        "--fast",
        action="store_true",
        # The dash in this text had been mis-decoded; restored to an em dash.
        help="Skip LLM intent call — use keyword routing instead (faster local dev)",
    )
    p.add_argument(
        "--tier",
        type=str,
        default=None,
        choices=["primary", "fallback"],
        help="Override LLM tier (default: settings.active_llm_tier)",
    )
    return p.parse_args()
# ── Fast keyword-based intent routing (bypasses the slow LLM intent call) ──────
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
    """Build an intent route from keywords alone, as a fast local-dev shortcut.

    Replicates the milestone-1 keyword routing so the LLM intent call can be
    skipped.

    Args:
        query: Raw text typed by the conversation partner.

    Returns:
        A ``(route, gen_config)`` pair: ``route`` holds a single sub-intent,
        the fixed style constraints and a ``"NEUTRAL"`` affect; ``gen_config``
        is a private copy of the ``"NEUTRAL"`` entry of ``_AFFECT_CONFIG``.
    """
    lowered = query.lower()
    bucket_hint = infer_bucket(query)

    # Phrases that point back at the conversation mark the turn as CONTEXTUAL;
    # everything else is treated as PERSONAL.
    back_reference_cues = ("you just said", "earlier", "you mentioned")
    if any(cue in lowered for cue in back_reference_cues):
        kind = "CONTEXTUAL"
    else:
        kind = "PERSONAL"

    sub_intent = {
        "type": kind,
        "query": query,
        "bucket_hint": bucket_hint,
        "priority": "normal",
    }
    # `style_constraints` is vestigial - the planner reads `generation_config`
    # (returned alongside the route) as the source of truth.
    style_constraints = {
        "tone_tag": "[TONE:DEFAULT]",
        "max_tokens": 100,
        "retrieval_mode": "full",
        "persona_mod": "baseline",
    }
    route = {
        "sub_intents": [sub_intent],
        "style_constraints": style_constraints,
        "affect": "NEUTRAL",
    }

    # Deep-copy so callers that mutate the config downstream never touch the
    # shared module-level constant.
    neutral_config: GenerationConfig = copy.deepcopy(_AFFECT_CONFIG["NEUTRAL"])
    return route, neutral_config
def load_users() -> dict[str, dict]:
    """Load the user registry and index it by user id.

    Reads the JSON document at ``settings.users_json``, takes its top-level
    ``"users"`` list and keys each entry by its ``"id"`` field.

    Returns:
        Mapping of user id to that user's full record.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``"users"`` or an entry's ``"id"`` is missing.
    """
    # Decode as UTF-8 explicitly: the default would be the platform's locale
    # encoding, which can mis-read non-ASCII persona data on some systems.
    with open(settings.users_json, encoding="utf-8") as f:
        return {u["id"]: u for u in json.load(f)["users"]}
def load_persona_profile(user_id: str) -> dict:
    """Load the ``"profile"`` section of one persona's memory file.

    Args:
        user_id: Persona id; the file read is ``<user_id>.json`` under
            ``settings.memories_dir``.

    Returns:
        The value stored under the top-level ``"profile"`` key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no ``"profile"`` key.
    """
    # Decode as UTF-8 explicitly rather than relying on the locale default.
    with open(settings.memories_dir / f"{user_id}.json", encoding="utf-8") as f:
        return json.load(f)["profile"]
def select_user(users: dict[str, dict], user_arg: str | None) -> str:
if user_arg:
if user_arg not in users:
print(f"Unknown user '{user_arg}'. Available: {list(users)}")
sys.exit(1)
return user_arg
print("\nAvailable personas:")
for uid, u in users.items():
print(f" {uid:20s} β {u['name']} ({u['condition']})")
uid = input("\nSelect user id: ").strip()
if uid not in users:
print("Invalid id.")
sys.exit(1)
return uid
def print_latency(log: dict, turn: int) -> None:
    """Print a two-row table of per-stage latencies for one turn.

    Args:
        log: Latency record, read for the keys ``t_sensing``, ``t_intent``,
            ``t_retrieval``, ``t_generation`` and ``t_total``. A key that is
            missing or holds ``None`` is shown as ``0.000s``.
        turn: Turn number shown in the table heading.
    """
    # (key in `log`, column label) pairs, in display order.
    columns = (
        ("t_sensing", "sensing"),
        ("t_intent", "intent"),
        ("t_retrieval", "retrieval"),
        ("t_generation", "generation"),
        ("t_total", "TOTAL"),
    )
    labels = [label for _, label in columns]
    # `or 0` also covers an explicit None, which the float format would reject.
    vals = [f"{(log.get(key) or 0):.3f}s" for key, _ in columns]
    # Each column is as wide as the longer of its label and its value.
    widths = [max(len(label), len(val)) for label, val in zip(labels, vals)]
    sep = " | "
    print(f"\n[turn {turn} latency]")
    print(sep.join(label.ljust(w) for label, w in zip(labels, widths)))
    print(sep.join(val.ljust(w) for val, w in zip(vals, widths)))
def main() -> None:
    """Run the interactive AAC chatbot loop on the terminal.

    Parses the CLI options, picks a persona, warms up the embedder, then reads
    partner utterances from stdin until ``quit``/``exit``/``q``, EOF or Ctrl-C.
    Each accepted utterance is sent through ``run_pipeline`` and the selected
    response is printed. Session history and the bucket/type priors returned
    by one turn are fed into the next.
    """
    args = parse_args()

    # Optionally override the LLM tier at runtime.
    # NOTE(review): `settings` is imported before this assignment runs; confirm
    # it reads ACTIVE_LLM_TIER lazily, otherwise the override has no effect.
    if args.tier:
        os.environ["ACTIVE_LLM_TIER"] = args.tier

    users = load_users()
    user_id = select_user(users, args.user)
    profile = load_persona_profile(user_id)

    # Warm up models (the ellipsis in this message had been mis-decoded).
    print(f"\nLoading models for {profile['name']} …", end=" ", flush=True)
    _get_embedder()
    print("ready.\n")

    # State carried across turns.
    session_history: list[dict] = []
    bucket_priors = uniform(BUCKETS)
    type_priors = uniform(CHUNK_TYPES)
    turn_id = 0

    print(f"Chatting as {profile['name']}. Type 'quit' to exit.\n")
    while True:
        try:
            query = input("Partner: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break
        if query.lower() in {"quit", "exit", "q"}:
            break
        if not query:
            continue

        # Rejected input gets the guardrail's fallback text and does not
        # count as a turn.
        guard = check_input(query)
        if not guard["allowed"]:
            print(f"AAC Bot: {guard['fallback']}\n")
            continue

        turn_id += 1

        # --fast: resolve intent via keywords, skip the slow LLM intent node
        t_intent_fast = 0.0
        if args.fast:
            t0 = time.perf_counter()
            pre_route, pre_gen_config = _keyword_intent(query)
            t_intent_fast = time.perf_counter() - t0
        else:
            pre_route, pre_gen_config = None, None

        state = PipelineState(
            user_id=user_id,
            persona_profile=profile,
            session_history=session_history,
            turn_id=turn_id,
            affect=None,
            gesture_tag=None,
            gaze_bucket=None,
            air_written_text=None,
            voice_text=None,
            resolved_intent=None,
            raw_query=query,
            intent_route=pre_route,  # pre-filled - intent node sees it and skips LLM call
            generation_config=pre_gen_config,
            retrieved_chunks=[],
            bucket_priors=bucket_priors,
            type_priors=type_priors,
            retrieval_mode_used="",
            augmented_prompt=None,
            candidates=[],
            rejected_candidates=[],
            selected_response=None,
            llm_tier_used="",
            latency_log={
                "t_sensing": 0.0,
                "t_intent": round(t_intent_fast, 4),
                "t_retrieval": 0.0,
                "t_generation": 0.0,
                "t_total": 0.0,
            },
            run_id=None,
            guardrail_passed=True,
        )

        result: PipelineState = run_pipeline(state)
        print(f"AAC Bot: {result['selected_response']}\n")

        # Carry the updated conversation state into the next turn.
        session_history = result["session_history"]
        bucket_priors = result["bucket_priors"]
        type_priors = result["type_priors"]

        if args.debug:
            print_latency(result.get("latency_log") or {}, turn_id)
            print(
                f" tier={result.get('llm_tier_used')} | "
                f"retrieval={result.get('retrieval_mode_used')} | "
                f"affect={(result.get('affect') or {}).get('emotion', '?')}\n"
            )
# Start the interactive CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|