Spaces:
Sleeping
Sleeping
Commit ·
7fd8c8a
1
Parent(s): 09fe9bc
Route sub-intents to their own pools, rip out the LLM intent router
Browse filesPartner queries now get split on conjunctions and each fragment goes
to the right place: personal memories, session history, or general
knowledge. The old LLM router was burning ~100s per turn on retry
loops; replaced it with a BGE cosine match against seed sentences
(~30ms).
Also:
- CONTEXTUAL pulls persona memory too, so "what did I just ask" still
sounds like them
- Planner prompt split into system + user so the character sheet gets
cached between turns
- Tightened anti-meta rules after Gemma4 leaked its own character
sheet into the response
- THINKING_MODE=suppress so it stops thinking out loud
- run.sh --debug now forwards to the CLI
- README: three mermaid diagrams, rewritten intro, updated for the 14
personas
- README.md +194 -66
- backend/guardrails/checks.py +8 -0
- backend/main.py +2 -15
- backend/pipeline/nodes/feedback.py +10 -4
- backend/pipeline/nodes/intent.py +148 -120
- backend/pipeline/nodes/planner.py +105 -38
- backend/pipeline/nodes/retrieval.py +99 -30
- backend/pipeline/state.py +3 -2
- backend/retrieval/contextual.py +57 -0
- backend/retrieval/vector_store.py +5 -0
- backend/sensing/bucket_keywords.py +15 -0
- run.sh +6 -0
README.md
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
# Multimodal AAC Chatbot
|
| 2 |
|
| 3 |
-
|
| 4 |
-
it fuses real-time multimodal non-verbal signals — facial expressions, hand gestures, gaze, and
|
| 5 |
-
air writing — with personal memory retrieval to generate responses in that person's authentic voice.
|
| 6 |
|
| 7 |
-
|
| 8 |
-
with two conditional branches (no LangGraph / LangChain), torch-tensor
|
| 9 |
-
retrieval (no FAISS), and JSONL turn logging (no MLflow).
|
| 10 |
|
| 11 |
---
|
| 12 |
|
|
@@ -26,44 +22,181 @@ retrieval (no FAISS), and JSONL turn logging (no MLflow).
|
|
| 26 |
|
| 27 |
## What is AAC?
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
a personalized digital twin that communicates on their behalf.
|
| 33 |
|
| 34 |
---
|
| 35 |
|
| 36 |
## System Architecture
|
| 37 |
|
|
|
|
|
|
|
| 38 |
```
|
| 39 |
-
React
|
| 40 |
-
MediaPipe JS
|
| 41 |
-
Chat UI ────────
|
| 42 |
-
Webcam
|
| 43 |
-
|
| 44 |
```
|
| 45 |
|
|
|
|
|
|
|
| 46 |
| Layer | Module | What it does |
|
| 47 |
|-------|--------|-------------|
|
| 48 |
-
| L1 | `frontend/src/hooks/useSensing.ts` |
|
| 49 |
-
| L2 | `backend/pipeline/nodes/intent.py` |
|
| 50 |
-
| L3 | `backend/pipeline/nodes/retrieval.py` |
|
| 51 |
-
| L4 | `backend/pipeline/nodes/planner.py` |
|
| 52 |
-
| L5 | `backend/pipeline/nodes/feedback.py` | JSONL turn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
---
|
| 59 |
|
| 60 |
## Prerequisites
|
| 61 |
|
| 62 |
-
- Python
|
| 63 |
-
- Node.js
|
| 64 |
-
- An [Ollama Cloud](https://ollama.com) account —
|
| 65 |
-
|
| 66 |
-
- A webcam (for live sensing; optional for CLI mode)
|
| 67 |
|
| 68 |
---
|
| 69 |
|
|
@@ -75,57 +208,60 @@ cd multimodal_aac_chatbot
|
|
| 75 |
bash setup.sh
|
| 76 |
```
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
- `.env` file creation from template
|
| 82 |
-
- Vector index building (downloads BGE-small embedder on first run, saves
|
| 83 |
-
per-user `vectors.pt` under `data/vector_store/`)
|
| 84 |
-
- Frontend dependency installation (pnpm)
|
| 85 |
|
| 86 |
---
|
| 87 |
|
| 88 |
## Configuration
|
| 89 |
|
| 90 |
-
|
| 91 |
|
| 92 |
-
| Variable | Default |
|
| 93 |
|----------|---------|-------------|
|
| 94 |
-
| `ACTIVE_LLM_TIER` | `primary` | `primary`
|
| 95 |
-
| `PRIMARY_MODEL` | `gemma4:31b-cloud` | Ollama Cloud model for primary tier |
|
| 96 |
-
| `FALLBACK_MODEL` | `gemma4:31b-cloud` |
|
| 97 |
-
| `PRIMARY_BASE_URL` | `http://localhost:11434/v1` |
|
| 98 |
-
| `FALLBACK_LATENCY_THRESHOLD` | `3.5` |
|
| 99 |
-
| `LOGS_DIR` | `logs` | Where per-turn JSONL
|
| 100 |
|
| 101 |
---
|
| 102 |
|
| 103 |
## Running the Project
|
| 104 |
|
| 105 |
-
### Full stack
|
| 106 |
|
| 107 |
```bash
|
| 108 |
bash run.sh
|
| 109 |
```
|
| 110 |
|
| 111 |
-
|
| 112 |
-
Open [http://localhost:7550](http://localhost:7550) in your browser.
|
| 113 |
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
```bash
|
| 117 |
conda activate aac-chatbot
|
| 118 |
python -m backend.main --debug
|
| 119 |
```
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
```bash
|
| 124 |
conda activate aac-chatbot
|
| 125 |
uvicorn backend.api.main:app --reload
|
| 126 |
```
|
| 127 |
|
| 128 |
-
Example request:
|
| 129 |
```bash
|
| 130 |
curl -X POST http://localhost:8000/chat \
|
| 131 |
-H "Content-Type: application/json" \
|
|
@@ -173,38 +309,30 @@ multimodal_aac_chatbot/
|
|
| 173 |
|
| 174 |
## Personas
|
| 175 |
|
| 176 |
-
|
| 177 |
-
Parkinson's, locked-in syndrome, aphasia, Alzheimer's, cerebral palsy, non-verbal autism,
|
| 178 |
-
savant autism, intellectual disability, and spinal cord injury.
|
| 179 |
|
| 180 |
| ID | Source | Condition |
|
| 181 |
|----|--------|-----------|
|
| 182 |
| `stephen_hawking` | Real — *My Brief History* + interviews | ALS (mid-stage) |
|
| 183 |
-
| `michael_j_fox` | Real —
|
| 184 |
| `wendy_mitchell` | Real — *Somebody I Used to Know* + blog | Early-onset Alzheimer's |
|
| 185 |
| `christopher_reeve` | Real — *Still Me* | C4 spinal cord injury |
|
| 186 |
| `christy_brown` | Real — *My Left Foot* | Cerebral palsy (adult) |
|
| 187 |
| `gabby_giffords` | Real — *Gabby* memoir | Aphasia + TBI |
|
| 188 |
| `jason_becker` | Real — *Not Dead Yet* doc | Late-stage ALS |
|
| 189 |
| `jean_dominique_bauby` | Real — *The Diving Bell and the Butterfly* | Locked-in syndrome |
|
| 190 |
-
| `tito_mukhopadhyay` | Real —
|
| 191 |
| `abed_nadir` | Fictional — *Community* | Autism (verbal) |
|
| 192 |
| `allie_calhoun` | Fictional — *The Notebook* | Late-stage Alzheimer's |
|
| 193 |
| `forrest_gump` | Fictional — *Forrest Gump* | Intellectual disability |
|
| 194 |
| `walter_jr_white` | Fictional — *Breaking Bad* | Cerebral palsy (teen) |
|
| 195 |
| `raymond_babbitt` | Fictional — *Rain Man* | Savant autism |
|
| 196 |
|
| 197 |
-
Each persona has ~120
|
| 198 |
-
(`family`, `medical`, `hobbies`, `daily_routine`, `social`) and 3 chunk types
|
| 199 |
-
(`narrative`, `social_post`, `chat_log`). Total: ~2,300 chunks.
|
| 200 |
|
| 201 |
-
|
| 202 |
-
bibliography of memoirs, films, interviews, and other canonical sources behind every
|
| 203 |
-
persona, plus ethics notes on living-persons treatment.
|
| 204 |
|
| 205 |
-
|
| 206 |
-
existing persona, then run `python data/generate_users.py` and
|
| 207 |
-
`python -m backend.retrieval.vector_store`.
|
| 208 |
|
| 209 |
---
|
| 210 |
|
|
@@ -235,10 +363,10 @@ Heads up: all camera/sensing stuff is in the frontend (MediaPipe JS). Backend ju
|
|
| 235 |
|
| 236 |
### Intent decomposition
|
| 237 |
|
| 238 |
-
> Current state:
|
| 239 |
|
| 240 |
-
- [
|
| 241 |
-
- [
|
| 242 |
|
| 243 |
### Retrieval
|
| 244 |
|
|
|
|
| 1 |
# Multimodal AAC Chatbot
|
| 2 |
|
| 3 |
+
A chatbot that **speaks as an AAC user, not to them.** You pick a persona (Mia, Gerald, or Arjun) and the partner talks to them — the bot replies in that person's voice, using their memories, and adjusts what it says based on what the webcam sees: facial expression, hand gestures, where they're looking, and letters they trace in the air.
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
It's a training-free agentic RAG pipeline — a plain Python function chain with two branching points, torch matmul for retrieval, JSONL for logging. The goal was to keep every piece simple enough to read top-to-bottom in an afternoon.
|
|
|
|
|
|
|
| 6 |
|
| 7 |
---
|
| 8 |
|
|
|
|
| 22 |
|
| 23 |
## What is AAC?
|
| 24 |
|
| 25 |
+
AAC (Augmentative and Alternative Communication) covers the tools people use when spoken or written communication is hard for them — cerebral palsy, ALS, autism, stroke recovery, and so on. Usually that's a tablet with a symbol grid, or an eye-tracker, or a switch. The slow part isn't the typing — it's that most devices don't know *you*. Every conversation starts from scratch.
|
| 26 |
+
|
| 27 |
+
This project is a small attempt at the other direction: give each user a persona their device already knows, and let the device reply in their voice.
|
|
|
|
| 28 |
|
| 29 |
---
|
| 30 |
|
| 31 |
## System Architecture
|
| 32 |
|
| 33 |
+
The browser does all the camera work. MediaPipe JS runs inside React, classifies what it sees into small labels (`affect`, `gesture_tag`, `gaze_bucket`, `air_written_text`), and sends those alongside the partner's text to `/chat`. The backend never touches pixels.
|
| 34 |
+
|
| 35 |
```
|
| 36 |
+
React (browser) Backend (Python)
|
| 37 |
+
MediaPipe JS ──┐
|
| 38 |
+
Chat UI ────────┼── POST /chat ──► FastAPI ──► run_pipeline()
|
| 39 |
+
Webcam ─────────┘ │
|
| 40 |
+
Intent ──► Retrieval ──► Planner ──► Feedback
|
| 41 |
```
|
| 42 |
|
| 43 |
+
Five layers, each a tiny file:
|
| 44 |
+
|
| 45 |
| Layer | Module | What it does |
|
| 46 |
|-------|--------|-------------|
|
| 47 |
+
| L1 | `frontend/src/hooks/useSensing.ts` | Watches the webcam. Turns faces/hands/gaze/air-writing into string labels. Purely frontend. |
|
| 48 |
+
| L2 | `backend/pipeline/nodes/intent.py` | Splits the partner's question on conjunctions and punctuation, then classifies each fragment as PERSONAL, CONTEXTUAL, or OPEN_DOMAIN using cosine similarity against a handful of seed sentences. No LLM call. ~30ms per turn. |
|
| 49 |
+
| L3 | `backend/pipeline/nodes/retrieval.py` | Each sub-intent goes to its own pool. Personal → the user's memory vector store. Contextual → persona memory + relevant in-session turns layered on top (so "what did I just ask" still sounds like *them*). Open-domain → a stub chunk telling the LLM to answer from its own knowledge (web search is deliberately out of scope). |
|
| 50 |
+
| L4 | `backend/pipeline/nodes/planner.py` | Builds the prompt, calls the LLM, picks a response. Tone and max_tokens are shaped by the detected affect. |
|
| 51 |
+
| L5 | `backend/pipeline/nodes/feedback.py` | Writes one JSONL row per turn and bumps the Bayesian priors over which memory bucket was useful. |
|
| 52 |
+
|
| 53 |
+
Two places the pipeline branches:
|
| 54 |
+
- **Frustrated affect** → use the fast retrieval path (k=2, skip the reranker). The user wants an answer, not a thesis.
|
| 55 |
+
- **Cumulative latency past 3.5s** → switch to the smaller fallback model for generation.
|
| 56 |
+
|
| 57 |
+
### End-to-end: from partner speaking to response rendered
|
| 58 |
+
|
| 59 |
+
One diagram, left to right, every step a turn goes through. Follow the arrows.
|
| 60 |
+
|
| 61 |
+
```mermaid
|
| 62 |
+
flowchart LR
|
| 63 |
+
subgraph S1["① Partner side (browser)"]
|
| 64 |
+
direction TB
|
| 65 |
+
IN1[Partner types or speaks a question]
|
| 66 |
+
IN2[Webcam frame]
|
| 67 |
+
IN1 --> UI[Chat UI]
|
| 68 |
+
IN2 --> MP[MediaPipe JS<br/>face + hands + gaze]
|
| 69 |
+
MP --> LAB[Classify into labels<br/>affect, gesture_tag,<br/>gaze_bucket, air_written_text]
|
| 70 |
+
UI --> REQ
|
| 71 |
+
LAB --> REQ[POST /chat<br/>query + labels]
|
| 72 |
+
end
|
| 73 |
+
|
| 74 |
+
REQ ==> S2
|
| 75 |
+
|
| 76 |
+
subgraph S2["② Backend pipeline (Python)"]
|
| 77 |
+
direction TB
|
| 78 |
+
HYD[Hydrate PipelineState<br/>session_history, priors, profile] --> INT
|
| 79 |
+
INT[Intent node<br/>split query + classify fragments] --> BR1{FRUSTRATED?}
|
| 80 |
+
BR1 -- yes --> RFAST[Fast retrieval<br/>k=2]
|
| 81 |
+
BR1 -- no --> RFULL[Full retrieval<br/>k=5 → rerank 3]
|
| 82 |
+
RFAST --> POOL
|
| 83 |
+
RFULL --> POOL[Dispatch per sub-intent]
|
| 84 |
+
POOL --> PP[PERSONAL<br/>BGE vector store]
|
| 85 |
+
POOL --> PC[CONTEXTUAL<br/>personal + BGE over history]
|
| 86 |
+
POOL --> PO[OPEN_DOMAIN<br/>stub chunk]
|
| 87 |
+
PP --> MERGE[Merge + dedupe chunks]
|
| 88 |
+
PC --> MERGE
|
| 89 |
+
PO --> MERGE
|
| 90 |
+
MERGE --> PLAN[Planner<br/>build prompt with<br/>3 retrieval blocks + tone tag]
|
| 91 |
+
PLAN --> BR2{Total latency<br/>> 3.5s?}
|
| 92 |
+
BR2 -- yes --> LLMF[Fallback LLM<br/>Ollama Cloud, smaller]
|
| 93 |
+
BR2 -- no --> LLMP[Primary LLM<br/>Ollama Cloud]
|
| 94 |
+
LLMF --> GRD[Guardrail check<br/>persona breaks,<br/>unsupported claims]
|
| 95 |
+
LLMP --> GRD
|
| 96 |
+
GRD --> FB[Feedback node<br/>log turn to JSONL,<br/>bump bucket priors,<br/>append to session history]
|
| 97 |
+
end
|
| 98 |
+
|
| 99 |
+
FB ==> S3
|
| 100 |
+
|
| 101 |
+
subgraph S3["③ Back to partner"]
|
| 102 |
+
direction TB
|
| 103 |
+
RESP[Response in persona's voice<br/>+ latency breakdown<br/>+ eval scores]
|
| 104 |
+
RESP --> RENDER[Chat UI renders it]
|
| 105 |
+
end
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
**A concrete example.** Partner says *"how are you, and what's the capital of France?"* while the webcam reads a relaxed face:
|
| 109 |
+
|
| 110 |
+
1. Browser sends `{query, affect: NEUTRAL, gesture_tag: null, …}`.
|
| 111 |
+
2. Intent node splits on `,` and ` and ` → two fragments. Classifier tags them `PERSONAL` and `OPEN_DOMAIN`.
|
| 112 |
+
3. Affect isn't FRUSTRATED, so full retrieval runs.
|
| 113 |
+
4. Dispatcher hits the persona store for fragment one, emits the open-domain stub for fragment two, merges both.
|
| 114 |
+
5. Planner drops the two chunks into separate prompt blocks and calls the primary LLM.
|
| 115 |
+
6. Guardrail passes, feedback writes the row, the response — in Mia's voice — comes back through the same `/chat` response.
|
| 116 |
+
|
| 117 |
+
Total wall time is normally under 6 seconds end-to-end; the slow part is the LLM call, not anything you wrote.
|
| 118 |
+
|
| 119 |
+
### What a single turn actually looks like
|
| 120 |
+
|
| 121 |
+
```mermaid
|
| 122 |
+
flowchart TD
|
| 123 |
+
A[Partner types or speaks] --> B[React captures query<br/>+ webcam labels]
|
| 124 |
+
B --> C[POST /chat]
|
| 125 |
+
C --> D[Intent node<br/>split + classify]
|
| 126 |
+
D --> E{Any FRUSTRATED<br/>affect signal?}
|
| 127 |
+
E -- yes --> F[Fast retrieval<br/>k=2, no reranker]
|
| 128 |
+
E -- no --> G[Full retrieval<br/>k=5 → rerank to 3]
|
| 129 |
+
F --> H{Cumulative<br/>latency > 3.5s?}
|
| 130 |
+
G --> H
|
| 131 |
+
H -- yes --> I[Fallback LLM<br/>smaller, faster]
|
| 132 |
+
H -- no --> J[Primary LLM]
|
| 133 |
+
I --> K[Guardrail check]
|
| 134 |
+
J --> K
|
| 135 |
+
K --> L[Feedback node<br/>JSONL log + priors]
|
| 136 |
+
L --> M[Response in persona's voice]
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### How sub-intents fan out
|
| 140 |
+
|
| 141 |
+
This is the part that took a few iterations to get right. Each partner query can be *multiple* questions stitched together with "and" / "but" / punctuation. Each fragment gets classified separately and sent to its own retrieval pool.
|
| 142 |
+
|
| 143 |
+
```mermaid
|
| 144 |
+
flowchart LR
|
| 145 |
+
Q[""how are you,<br/>and what's the<br/>capital of France?""] --> S[Split on conjunctions<br/>and punctuation]
|
| 146 |
+
S --> F1[fragment:<br/>how are you]
|
| 147 |
+
S --> F2[fragment:<br/>capital of France]
|
| 148 |
+
|
| 149 |
+
F1 --> CL[BGE zero-shot<br/>cosine vs exemplars]
|
| 150 |
+
F2 --> CL
|
| 151 |
|
| 152 |
+
CL --> P[PERSONAL<br/>→ persona memory vectors]
|
| 153 |
+
CL --> CX[CONTEXTUAL<br/>→ persona memory +<br/>relevant session history]
|
| 154 |
+
CL --> OD[OPEN_DOMAIN<br/>→ stub, LLM answers<br/>from own knowledge]
|
| 155 |
+
|
| 156 |
+
P --> MERGE[Merge, dedupe,<br/>hand to planner]
|
| 157 |
+
CX --> MERGE
|
| 158 |
+
OD --> MERGE
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
The classifier is just cosine similarity against 5 seed sentences per class — no LLM, ~30ms per turn. The old version called an LLM and retried up to 3× on JSON errors; on a bad day that was 100+ seconds of dead time.
|
| 162 |
+
|
| 163 |
+
### State that flows between nodes
|
| 164 |
+
|
| 165 |
+
Every node takes a `PipelineState` dict and returns a partial update. Nothing is global.
|
| 166 |
+
|
| 167 |
+
```mermaid
|
| 168 |
+
flowchart LR
|
| 169 |
+
subgraph "set at turn start"
|
| 170 |
+
A[user_id, persona_profile,<br/>session_history, turn_id]
|
| 171 |
+
B[affect, gesture_tag,<br/>gaze_bucket, air_written_text]
|
| 172 |
+
C[raw_query]
|
| 173 |
+
end
|
| 174 |
+
|
| 175 |
+
subgraph "filled in by the pipeline"
|
| 176 |
+
D[intent_route,<br/>generation_config]
|
| 177 |
+
E[retrieved_chunks,<br/>retrieval_mode_used]
|
| 178 |
+
F[candidates,<br/>selected_response,<br/>llm_tier_used]
|
| 179 |
+
G[latency_log,<br/>run_id,<br/>guardrail_passed]
|
| 180 |
+
end
|
| 181 |
+
|
| 182 |
+
A --> D
|
| 183 |
+
B --> D
|
| 184 |
+
C --> D
|
| 185 |
+
D --> E
|
| 186 |
+
B --> E
|
| 187 |
+
D --> F
|
| 188 |
+
E --> F
|
| 189 |
+
F --> G
|
| 190 |
+
```
|
| 191 |
|
| 192 |
---
|
| 193 |
|
| 194 |
## Prerequisites
|
| 195 |
|
| 196 |
+
- Python 3.10+ (we use conda; 3.12 is what the env ships with)
|
| 197 |
+
- Node.js 22+ and pnpm
|
| 198 |
+
- An [Ollama Cloud](https://ollama.com) account. Generation hits cloud models — you don't need a local Ollama daemon running.
|
| 199 |
+
- A webcam if you want to play with the full stack. The CLI works without one.
|
|
|
|
| 200 |
|
| 201 |
---
|
| 202 |
|
|
|
|
| 208 |
bash setup.sh
|
| 209 |
```
|
| 210 |
|
| 211 |
+
`setup.sh` takes care of everything on the first run: creates the `aac-chatbot` conda env, installs Python and frontend deps, copies `.env.example` → `.env` for you to fill in, and builds the per-persona vector indexes under `data/vector_store/`. The first build downloads the BGE-small embedder (~130MB), so expect a short wait.
|
| 212 |
+
|
| 213 |
+
If you edit a persona later, rebuild the indexes: `python -m backend.retrieval.vector_store`.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
---
|
| 216 |
|
| 217 |
## Configuration
|
| 218 |
|
| 219 |
+
Everything is a Pydantic setting in [backend/config/settings.py](backend/config/settings.py) with a `.env` override. The knobs you'll actually touch:
|
| 220 |
|
| 221 |
+
| Variable | Default | What it does |
|
| 222 |
|----------|---------|-------------|
|
| 223 |
+
| `ACTIVE_LLM_TIER` | `primary` | Which tier to start on — `primary` or `fallback`. The pipeline switches automatically if a turn is slow. |
|
| 224 |
+
| `PRIMARY_MODEL` | `gemma4:31b-cloud` | Ollama Cloud model for the primary tier. |
|
| 225 |
+
| `FALLBACK_MODEL` | `gemma4:31b-cloud` | Smaller/faster model for the fallback tier. Point this at whatever smaller cloud model you have access to. |
|
| 226 |
+
| `PRIMARY_BASE_URL` | `http://localhost:11434/v1` | OpenAI-compatible endpoint. Defaults to the local Ollama proxy. |
|
| 227 |
+
| `FALLBACK_LATENCY_THRESHOLD` | `3.5` | If intent+retrieval already took this many seconds, skip the primary tier. |
|
| 228 |
+
| `LOGS_DIR` | `logs` | Where the per-turn JSONL goes. |
|
| 229 |
|
| 230 |
---
|
| 231 |
|
| 232 |
## Running the Project
|
| 233 |
|
| 234 |
+
### Full stack
|
| 235 |
|
| 236 |
```bash
|
| 237 |
bash run.sh
|
| 238 |
```
|
| 239 |
|
| 240 |
+
Starts FastAPI on `:8000` and the React dev server on `:7550`. Open [http://localhost:7550](http://localhost:7550). This is the mode you want for the webcam + sensing demo.
|
|
|
|
| 241 |
|
| 242 |
+
Pass any `backend.main` flag to `run.sh` and it drops the full stack and runs the CLI with those flags instead — handy for fast iteration:
|
| 243 |
+
|
| 244 |
+
```bash
|
| 245 |
+
bash run.sh --debug # CLI with per-turn state dumps
|
| 246 |
+
bash run.sh --user mia_chen --debug # jump straight to Mia
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
### CLI directly
|
| 250 |
|
| 251 |
```bash
|
| 252 |
conda activate aac-chatbot
|
| 253 |
python -m backend.main --debug
|
| 254 |
```
|
| 255 |
|
| 256 |
+
The CLI prints the full `PipelineState` after each turn — useful when you want to see what the classifier did or which chunks came back from which pool.
|
| 257 |
+
|
| 258 |
+
### API directly
|
| 259 |
|
| 260 |
```bash
|
| 261 |
conda activate aac-chatbot
|
| 262 |
uvicorn backend.api.main:app --reload
|
| 263 |
```
|
| 264 |
|
|
|
|
| 265 |
```bash
|
| 266 |
curl -X POST http://localhost:8000/chat \
|
| 267 |
-H "Content-Type: application/json" \
|
|
|
|
| 309 |
|
| 310 |
## Personas
|
| 311 |
|
| 312 |
+
Fourteen personas — nine anchored in real memoirs, five in canonical fiction. Together they span ALS, Parkinson's, locked-in syndrome, aphasia, Alzheimer's, cerebral palsy, non-verbal and savant autism, intellectual disability, and spinal cord injury. The point isn't to represent any one person — it's to give the model a wide enough range of voices that "sound like Mia" is a harder target than "sound helpful."
|
|
|
|
|
|
|
| 313 |
|
| 314 |
| ID | Source | Condition |
|
| 315 |
|----|--------|-----------|
|
| 316 |
| `stephen_hawking` | Real — *My Brief History* + interviews | ALS (mid-stage) |
|
| 317 |
+
| `michael_j_fox` | Real — four memoirs | Young-onset Parkinson's |
|
| 318 |
| `wendy_mitchell` | Real — *Somebody I Used to Know* + blog | Early-onset Alzheimer's |
|
| 319 |
| `christopher_reeve` | Real — *Still Me* | C4 spinal cord injury |
|
| 320 |
| `christy_brown` | Real — *My Left Foot* | Cerebral palsy (adult) |
|
| 321 |
| `gabby_giffords` | Real — *Gabby* memoir | Aphasia + TBI |
|
| 322 |
| `jason_becker` | Real — *Not Dead Yet* doc | Late-stage ALS |
|
| 323 |
| `jean_dominique_bauby` | Real — *The Diving Bell and the Butterfly* | Locked-in syndrome |
|
| 324 |
+
| `tito_mukhopadhyay` | Real — three+ books | Non-verbal autism |
|
| 325 |
| `abed_nadir` | Fictional — *Community* | Autism (verbal) |
|
| 326 |
| `allie_calhoun` | Fictional — *The Notebook* | Late-stage Alzheimer's |
|
| 327 |
| `forrest_gump` | Fictional — *Forrest Gump* | Intellectual disability |
|
| 328 |
| `walter_jr_white` | Fictional — *Breaking Bad* | Cerebral palsy (teen) |
|
| 329 |
| `raymond_babbitt` | Fictional — *Rain Man* | Savant autism |
|
| 330 |
|
| 331 |
+
Each persona has ~120–210 memory chunks (canon-driven, no filler) across five buckets — `family`, `medical`, `hobbies`, `daily_routine`, `social` — and three chunk types: `narrative`, `social_post`, `chat_log`. Somewhere around 2,300 chunks total across the set.
|
|
|
|
|
|
|
| 332 |
|
| 333 |
+
Data provenance is documented. See [references.md](references.md) for the bibliography — memoirs, films, interviews — and the ethics notes on living-persons treatment.
|
|
|
|
|
|
|
| 334 |
|
| 335 |
+
Adding a new persona: drop a JSON file into `data/memories/` following the schema of any existing one, then run `python data/generate_users.py` and `python -m backend.retrieval.vector_store`.
|
|
|
|
|
|
|
| 336 |
|
| 337 |
---
|
| 338 |
|
|
|
|
| 363 |
|
| 364 |
### Intent decomposition
|
| 365 |
|
| 366 |
+
> Current state: regex-splits the partner query on conjunctions/punctuation into fragments, then runs each fragment through a BGE zero-shot classifier (cosine vs. 5 seed exemplars per class). No LLM call, no retries. Runs in ~10–30ms per turn. Bucket hints for `PERSONAL` fragments come from a shared keyword helper in [backend/sensing/bucket_keywords.py](backend/sensing/bucket_keywords.py). Earlier versions used an LLM with Pydantic validation + 3 retries, which cost ~100s per turn on Ollama Cloud when the model emitted bad JSON.
|
| 367 |
|
| 368 |
+
- [x] **[Core]** Personal / Contextual / Open-domain dispatch to distinct pools (personal → BGE vector store; contextual → persona memory + relevant in-session turns layered on top; open-domain → stub chunk, LLM answers from its own general knowledge — web search is intentionally out of scope).
|
| 369 |
+
- [x] intent node latency — split + BGE zero-shot classifier replaces the LLM router. Parallelising sub-query retrieval is still open.
|
| 370 |
|
| 371 |
### Retrieval
|
| 372 |
|
backend/guardrails/checks.py
CHANGED
|
@@ -14,6 +14,14 @@ PERSONA_BREAK_SIGNALS = [
|
|
| 14 |
"as your assistant",
|
| 15 |
"i was trained",
|
| 16 |
"my training data",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
]
|
| 18 |
|
| 19 |
OUT_OF_SCOPE_SIGNALS = [
|
|
|
|
| 14 |
"as your assistant",
|
| 15 |
"i was trained",
|
| 16 |
"my training data",
|
| 17 |
+
# meta-narration / brief leakage
|
| 18 |
+
"the user wants me",
|
| 19 |
+
"the user is asking",
|
| 20 |
+
"roleplay as",
|
| 21 |
+
"role-play as",
|
| 22 |
+
"key characteristics",
|
| 23 |
+
"character sheet",
|
| 24 |
+
"reference only",
|
| 25 |
]
|
| 26 |
|
| 27 |
OUT_OF_SCOPE_SIGNALS = [
|
backend/main.py
CHANGED
|
@@ -13,6 +13,7 @@ from backend.pipeline.graph import run_pipeline
|
|
| 13 |
from backend.pipeline.state import GenerationConfig, PipelineState
|
| 14 |
from backend.retrieval.bucket_priors import uniform_priors
|
| 15 |
from backend.retrieval.vector_store import _get_embedder
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def parse_args() -> argparse.Namespace:
|
|
@@ -40,21 +41,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 40 |
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
|
| 41 |
"""Replicate milestone-1 keyword routing as a fast local-dev shortcut."""
|
| 42 |
q = query.lower()
|
| 43 |
-
bucket
|
| 44 |
-
|
| 45 |
-
if any(
|
| 46 |
-
w in q
|
| 47 |
-
for w in ["medication", "medicine", "doctor", "health", "allergic", "therapy"]
|
| 48 |
-
):
|
| 49 |
-
bucket = "medical"
|
| 50 |
-
elif any(w in q for w in ["family", "mom", "dad", "brother", "sister", "parents"]):
|
| 51 |
-
bucket = "family"
|
| 52 |
-
elif any(w in q for w in ["hobby", "like to do", "enjoy", "weekend", "fun"]):
|
| 53 |
-
bucket = "hobbies"
|
| 54 |
-
elif any(w in q for w in ["routine", "morning", "wake", "sleep", "daily"]):
|
| 55 |
-
bucket = "daily_routine"
|
| 56 |
-
elif any(w in q for w in ["friend", "social", "people", "party", "community"]):
|
| 57 |
-
bucket = "social"
|
| 58 |
|
| 59 |
intent_type = (
|
| 60 |
"CONTEXTUAL"
|
|
|
|
| 13 |
from backend.pipeline.state import GenerationConfig, PipelineState
|
| 14 |
from backend.retrieval.bucket_priors import uniform_priors
|
| 15 |
from backend.retrieval.vector_store import _get_embedder
|
| 16 |
+
from backend.sensing.bucket_keywords import infer_bucket
|
| 17 |
|
| 18 |
|
| 19 |
def parse_args() -> argparse.Namespace:
|
|
|
|
| 41 |
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
|
| 42 |
"""Replicate milestone-1 keyword routing as a fast local-dev shortcut."""
|
| 43 |
q = query.lower()
|
| 44 |
+
bucket = infer_bucket(query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
intent_type = (
|
| 47 |
"CONTEXTUAL"
|
backend/pipeline/nodes/feedback.py
CHANGED
|
@@ -33,6 +33,7 @@ def _log_to_jsonl(state: PipelineState, run_id: str) -> None:
|
|
| 33 |
|
| 34 |
latency = state.get("latency_log") or {}
|
| 35 |
affect = (state.get("affect") or {}).get("emotion", "UNKNOWN")
|
|
|
|
| 36 |
|
| 37 |
entry = {
|
| 38 |
"run_id": run_id,
|
|
@@ -43,7 +44,12 @@ def _log_to_jsonl(state: PipelineState, run_id: str) -> None:
|
|
| 43 |
"retrieval_mode": state.get("retrieval_mode_used", "unknown"),
|
| 44 |
"affect": affect,
|
| 45 |
"guardrail_passed": state.get("guardrail_passed", True),
|
| 46 |
-
"num_chunks": len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"latency": {
|
| 48 |
"t_sensing": latency.get("t_sensing", 0.0),
|
| 49 |
"t_intent": latency.get("t_intent", 0.0),
|
|
@@ -60,11 +66,11 @@ def _log_to_jsonl(state: PipelineState, run_id: str) -> None:
|
|
| 60 |
|
| 61 |
def _update_bucket_priors(state: PipelineState) -> dict[str, float]:
|
| 62 |
chunks = state.get("retrieved_chunks") or []
|
| 63 |
-
|
|
|
|
| 64 |
return state.get("bucket_priors") or {}
|
| 65 |
|
| 66 |
-
|
| 67 |
-
top_bucket = chunks[0].get("bucket")
|
| 68 |
if not top_bucket:
|
| 69 |
return state.get("bucket_priors") or {}
|
| 70 |
|
|
|
|
| 33 |
|
| 34 |
latency = state.get("latency_log") or {}
|
| 35 |
affect = (state.get("affect") or {}).get("emotion", "UNKNOWN")
|
| 36 |
+
chunks = state.get("retrieved_chunks") or []
|
| 37 |
|
| 38 |
entry = {
|
| 39 |
"run_id": run_id,
|
|
|
|
| 44 |
"retrieval_mode": state.get("retrieval_mode_used", "unknown"),
|
| 45 |
"affect": affect,
|
| 46 |
"guardrail_passed": state.get("guardrail_passed", True),
|
| 47 |
+
"num_chunks": len(chunks),
|
| 48 |
+
"num_personal": sum(
|
| 49 |
+
1 for c in chunks if c.get("source", "personal") == "personal"
|
| 50 |
+
),
|
| 51 |
+
"num_contextual": sum(1 for c in chunks if c.get("source") == "contextual"),
|
| 52 |
+
"num_open_domain": sum(1 for c in chunks if c.get("source") == "open_domain"),
|
| 53 |
"latency": {
|
| 54 |
"t_sensing": latency.get("t_sensing", 0.0),
|
| 55 |
"t_intent": latency.get("t_intent", 0.0),
|
|
|
|
| 66 |
|
| 67 |
def _update_bucket_priors(state: PipelineState) -> dict[str, float]:
|
| 68 |
chunks = state.get("retrieved_chunks") or []
|
| 69 |
+
personal = [c for c in chunks if c.get("source", "personal") == "personal"]
|
| 70 |
+
if not personal:
|
| 71 |
return state.get("bucket_priors") or {}
|
| 72 |
|
| 73 |
+
top_bucket = personal[0].get("bucket")
|
|
|
|
| 74 |
if not top_bucket:
|
| 75 |
return state.get("bucket_priors") or {}
|
| 76 |
|
backend/pipeline/nodes/intent.py
CHANGED
|
@@ -1,45 +1,73 @@
|
|
| 1 |
-
# Intent decomposition node —
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import re
|
| 5 |
import time
|
| 6 |
-
from
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
from backend.config.settings import settings
|
| 11 |
-
from backend.
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
_AFFECT_CONFIG: dict[str, GenerationConfig] = {
|
| 45 |
"HAPPY": {
|
|
@@ -68,41 +96,61 @@ _AFFECT_CONFIG: dict[str, GenerationConfig] = {
|
|
| 68 |
},
|
| 69 |
}
|
| 70 |
|
| 71 |
-
# ── System prompt ──────────────────────────────────────────────────────────────
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
Intent types:
|
| 80 |
-
- PERSONAL: requires autobiographical memory retrieval
|
| 81 |
-
- CONTEXTUAL: answerable from session history
|
| 82 |
-
- OPEN_DOMAIN: answerable from general knowledge (no retrieval needed)
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
) -> str:
|
| 94 |
-
air_note = (
|
| 95 |
-
f'\nAir-written supplement: "{air_written_text}"' if air_written_text else ""
|
| 96 |
-
)
|
| 97 |
-
return (
|
| 98 |
-
f"Persona: {persona_name}\n"
|
| 99 |
-
f"Affect: {affect}\n"
|
| 100 |
-
f"Partner query: {query}{air_note}\n\n"
|
| 101 |
-
"Produce the IntentRoute JSON:"
|
| 102 |
-
)
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
|
| 108 |
def run(state: PipelineState) -> dict:
|
|
@@ -110,78 +158,58 @@ def run(state: PipelineState) -> dict:
|
|
| 110 |
|
| 111 |
# --fast mode: intent_route already resolved by keyword routing in main.py
|
| 112 |
if state.get("intent_route") and state.get("generation_config"):
|
| 113 |
-
return {}
|
| 114 |
|
| 115 |
affect_state = state.get("affect") or {}
|
| 116 |
emotion: str = affect_state.get("emotion", "NEUTRAL")
|
| 117 |
query: str = state["raw_query"]
|
| 118 |
-
persona_name: str = state["persona_profile"].get("name", "unknown")
|
| 119 |
-
|
| 120 |
gen_config = _AFFECT_CONFIG.get(emotion, _AFFECT_CONFIG["NEUTRAL"])
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
{
|
| 129 |
-
"
|
| 130 |
-
"
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
air_written_text=state.get("air_written_text"),
|
| 135 |
-
),
|
| 136 |
-
},
|
| 137 |
-
]
|
| 138 |
-
if attempt > 0:
|
| 139 |
-
messages.append(
|
| 140 |
-
{
|
| 141 |
-
"role": "user",
|
| 142 |
-
"content": f"Validation error: {last_error}. Fix and retry.",
|
| 143 |
-
}
|
| 144 |
-
)
|
| 145 |
-
|
| 146 |
-
raw = chat_complete(
|
| 147 |
-
messages=messages,
|
| 148 |
-
max_tokens=512,
|
| 149 |
-
temperature=0.0,
|
| 150 |
)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
"
|
| 159 |
-
"style_constraints": parsed.style_constraints.model_dump(),
|
| 160 |
-
"affect": parsed.affect,
|
| 161 |
}
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
t_intent = time.perf_counter() - t0
|
| 182 |
|
| 183 |
latency_log = dict(state.get("latency_log") or {})
|
| 184 |
-
latency_log["t_intent"] = round(
|
| 185 |
|
| 186 |
return {
|
| 187 |
"intent_route": route,
|
|
|
|
| 1 |
+
# Intent decomposition node — regex-split fragments + BGE zero-shot classifier.
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import re
|
| 5 |
import time
|
| 6 |
+
from functools import lru_cache
|
| 7 |
|
| 8 |
+
import torch
|
| 9 |
|
| 10 |
from backend.config.settings import settings
|
| 11 |
+
from backend.pipeline.state import (
|
| 12 |
+
GenerationConfig,
|
| 13 |
+
IntentRoute,
|
| 14 |
+
PipelineState,
|
| 15 |
+
SubIntent,
|
| 16 |
+
)
|
| 17 |
+
from backend.retrieval.vector_store import get_device, get_embedder
|
| 18 |
+
from backend.sensing.bucket_keywords import infer_bucket
|
| 19 |
+
|
| 20 |
+
_CLASS_EXEMPLARS: dict[str, list[str]] = {
|
| 21 |
+
"PERSONAL": [
|
| 22 |
+
"how are you today",
|
| 23 |
+
"what is your favourite food",
|
| 24 |
+
"tell me about your family",
|
| 25 |
+
"what do you do in the mornings",
|
| 26 |
+
"did you enjoy the weekend",
|
| 27 |
+
],
|
| 28 |
+
"CONTEXTUAL": [
|
| 29 |
+
"what did you just say",
|
| 30 |
+
"what did I ask earlier",
|
| 31 |
+
"you mentioned something before",
|
| 32 |
+
"can you repeat that",
|
| 33 |
+
"what were we talking about",
|
| 34 |
+
],
|
| 35 |
+
"OPEN_DOMAIN": [
|
| 36 |
+
"what is the capital of france",
|
| 37 |
+
"how many planets are there",
|
| 38 |
+
"who wrote hamlet",
|
| 39 |
+
"when was world war two",
|
| 40 |
+
"what does photosynthesis mean",
|
| 41 |
+
],
|
| 42 |
+
}
|
| 43 |
|
| 44 |
+
_CLASSIFIER_THRESHOLD = (
|
| 45 |
+
0.45 # below this → PERSONAL fallback (safe default for OOV / typos / short input)
|
| 46 |
+
)
|
| 47 |
+
_CONTEXTUAL_MARGIN_MIN = (
|
| 48 |
+
0.08 # CONTEXTUAL must beat runner-up by at least this — it over-matches without it
|
| 49 |
+
)
|
| 50 |
+
_MIN_FRAGMENT_WORDS = 3
|
| 51 |
+
_MAX_FRAGMENTS = 4
|
| 52 |
+
|
| 53 |
+
_CONTEXTUAL_MARKERS = (
|
| 54 |
+
"earlier",
|
| 55 |
+
"before",
|
| 56 |
+
"mentioned",
|
| 57 |
+
"said",
|
| 58 |
+
"asked",
|
| 59 |
+
"just",
|
| 60 |
+
"repeat",
|
| 61 |
+
)
|
| 62 |
+
_CONTEXTUAL_MARKER_PATTERN = re.compile(
|
| 63 |
+
r"\b(" + "|".join(_CONTEXTUAL_MARKERS) + r")\b",
|
| 64 |
+
flags=re.IGNORECASE,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
_SPLIT_PATTERN = re.compile(
|
| 68 |
+
r"\s+(?:and|but|also|plus)\s+|[;.?!]+\s+|,\s+(?=\w)",
|
| 69 |
+
flags=re.IGNORECASE,
|
| 70 |
+
)
|
| 71 |
|
| 72 |
_AFFECT_CONFIG: dict[str, GenerationConfig] = {
|
| 73 |
"HAPPY": {
|
|
|
|
| 96 |
},
|
| 97 |
}
|
| 98 |
|
|
|
|
| 99 |
|
| 100 |
+
@lru_cache(maxsize=1)
|
| 101 |
+
def _exemplar_matrices() -> dict[str, torch.Tensor]:
|
| 102 |
+
embedder = get_embedder()
|
| 103 |
+
device = get_device()
|
| 104 |
+
return {
|
| 105 |
+
cls: embedder.encode(
|
| 106 |
+
exemplars,
|
| 107 |
+
convert_to_tensor=True,
|
| 108 |
+
normalize_embeddings=True,
|
| 109 |
+
device=device,
|
| 110 |
+
)
|
| 111 |
+
for cls, exemplars in _CLASS_EXEMPLARS.items()
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _split_query(query: str) -> list[str]:
|
| 116 |
+
raw = [p.strip() for p in _SPLIT_PATTERN.split(query) if p and p.strip()]
|
| 117 |
+
keep = [p for p in raw if len(p.split()) >= _MIN_FRAGMENT_WORDS]
|
| 118 |
+
if not keep:
|
| 119 |
+
keep = [query.strip()] if query.strip() else []
|
| 120 |
+
return keep[:_MAX_FRAGMENTS]
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
def _classify(fragment: str) -> str:
|
| 124 |
+
embedder = get_embedder()
|
| 125 |
+
device = get_device()
|
| 126 |
+
vec = embedder.encode(
|
| 127 |
+
[fragment],
|
| 128 |
+
convert_to_tensor=True,
|
| 129 |
+
normalize_embeddings=True,
|
| 130 |
+
device=device,
|
| 131 |
+
)[0]
|
| 132 |
|
| 133 |
+
scores: dict[str, float] = {}
|
| 134 |
+
for cls, mat in _exemplar_matrices().items():
|
| 135 |
+
scores[cls] = float((mat @ vec).max())
|
| 136 |
|
| 137 |
+
ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
|
| 138 |
+
best_cls, best_score = ranked[0]
|
| 139 |
+
runner_up_score = ranked[1][1]
|
| 140 |
|
| 141 |
+
if best_score < _CLASSIFIER_THRESHOLD:
|
| 142 |
+
return "PERSONAL" # conservative default: treat as a question about the persona
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
# CONTEXTUAL is the riskiest class — if wrong, we lose all persona grounding.
|
| 145 |
+
# Require it to clearly beat the runner-up and for the fragment to mention
|
| 146 |
+
# prior discourse (matched at word boundaries, so "just" doesn't match "unjust").
|
| 147 |
+
if best_cls == "CONTEXTUAL":
|
| 148 |
+
margin = best_score - runner_up_score
|
| 149 |
+
has_discourse_marker = bool(_CONTEXTUAL_MARKER_PATTERN.search(fragment))
|
| 150 |
+
if margin < _CONTEXTUAL_MARGIN_MIN or not has_discourse_marker:
|
| 151 |
+
return "PERSONAL"
|
| 152 |
|
| 153 |
+
return best_cls
|
| 154 |
|
| 155 |
|
| 156 |
def run(state: PipelineState) -> dict:
|
|
|
|
| 158 |
|
| 159 |
# --fast mode: intent_route already resolved by keyword routing in main.py
|
| 160 |
if state.get("intent_route") and state.get("generation_config"):
|
| 161 |
+
return {}
|
| 162 |
|
| 163 |
affect_state = state.get("affect") or {}
|
| 164 |
emotion: str = affect_state.get("emotion", "NEUTRAL")
|
| 165 |
query: str = state["raw_query"]
|
|
|
|
|
|
|
| 166 |
gen_config = _AFFECT_CONFIG.get(emotion, _AFFECT_CONFIG["NEUTRAL"])
|
| 167 |
|
| 168 |
+
fragments = _split_query(query)
|
| 169 |
+
priority = "fast" if emotion == "FRUSTRATED" else "normal"
|
| 170 |
|
| 171 |
+
sub_intents: list[SubIntent] = []
|
| 172 |
+
for frag in fragments:
|
| 173 |
+
cls = _classify(frag)
|
| 174 |
+
bucket_hint = infer_bucket(frag) if cls == "PERSONAL" else None
|
| 175 |
+
sub_intents.append(
|
| 176 |
{
|
| 177 |
+
"type": cls,
|
| 178 |
+
"query": frag,
|
| 179 |
+
"bucket_hint": bucket_hint,
|
| 180 |
+
"priority": priority,
|
| 181 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
)
|
| 183 |
|
| 184 |
+
if not sub_intents:
|
| 185 |
+
sub_intents = [
|
| 186 |
+
{
|
| 187 |
+
"type": "PERSONAL",
|
| 188 |
+
"query": query,
|
| 189 |
+
"bucket_hint": None,
|
| 190 |
+
"priority": priority,
|
|
|
|
|
|
|
| 191 |
}
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
air_written = state.get("air_written_text")
|
| 195 |
+
if air_written:
|
| 196 |
+
sub_intents.append(
|
| 197 |
+
{
|
| 198 |
+
"type": "PERSONAL",
|
| 199 |
+
"query": air_written,
|
| 200 |
+
"bucket_hint": infer_bucket(air_written),
|
| 201 |
+
"priority": priority,
|
| 202 |
+
}
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
route: IntentRoute = {
|
| 206 |
+
"sub_intents": sub_intents,
|
| 207 |
+
"style_constraints": dict(gen_config),
|
| 208 |
+
"affect": emotion,
|
| 209 |
+
}
|
|
|
|
|
|
|
| 210 |
|
| 211 |
latency_log = dict(state.get("latency_log") or {})
|
| 212 |
+
latency_log["t_intent"] = round(time.perf_counter() - t0, 4)
|
| 213 |
|
| 214 |
return {
|
| 215 |
"intent_route": route,
|
backend/pipeline/nodes/planner.py
CHANGED
|
@@ -53,7 +53,7 @@ def _run(state: PipelineState, tier: str) -> dict:
|
|
| 53 |
)
|
| 54 |
gesture_tag = state.get("gesture_tag")
|
| 55 |
air_written_text = state.get("air_written_text")
|
| 56 |
-
|
| 57 |
profile,
|
| 58 |
chunks,
|
| 59 |
history,
|
|
@@ -65,7 +65,7 @@ def _run(state: PipelineState, tier: str) -> dict:
|
|
| 65 |
)
|
| 66 |
|
| 67 |
selected = chat_complete(
|
| 68 |
-
messages=
|
| 69 |
max_tokens=gen_cfg.get("max_tokens", settings.max_tokens_neutral),
|
| 70 |
temperature=0.4,
|
| 71 |
tier=tier,
|
|
@@ -86,8 +86,9 @@ def _run(state: PipelineState, tier: str) -> dict:
|
|
| 86 |
4,
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
return {
|
| 90 |
-
"augmented_prompt":
|
| 91 |
"candidates": [selected],
|
| 92 |
"selected_response": selected,
|
| 93 |
"llm_tier_used": tier,
|
|
@@ -101,7 +102,7 @@ def _resolve_tone_tag(user_id: str, affect: str, default_tag: str) -> str:
|
|
| 101 |
return _PERSONA_TONE_OVERRIDES.get(user_id, {}).get(affect, default_tag)
|
| 102 |
|
| 103 |
|
| 104 |
-
def
|
| 105 |
profile: dict,
|
| 106 |
chunks: list[dict],
|
| 107 |
history: list[dict],
|
|
@@ -110,19 +111,29 @@ def _build_prompt(
|
|
| 110 |
gen_cfg: dict,
|
| 111 |
gesture_tag: str | None = None,
|
| 112 |
air_written_text: str | None = None,
|
| 113 |
-
) ->
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
|
|
|
| 126 |
prefs = profile.get("stylistic_preferences") or {}
|
| 127 |
style_bits = []
|
| 128 |
if prefs.get("tone"):
|
|
@@ -139,9 +150,73 @@ def _build_prompt(
|
|
| 139 |
|
| 140 |
exemplars = prefs.get("example_phrases") or []
|
| 141 |
style_exemplar = "\n ".join(exemplars) if exemplars else "(no exemplar)"
|
| 142 |
-
|
| 143 |
access = (profile.get("access_needs") or {}).get("input_method") or "an AAC device"
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
gesture_line = ""
|
| 146 |
if gesture_tag:
|
| 147 |
g_tag = GESTURE_TO_TAG.get(gesture_tag, f"[GESTURE:{gesture_tag}]")
|
|
@@ -151,35 +226,27 @@ def _build_prompt(
|
|
| 151 |
if air_written_text:
|
| 152 |
air_writing_line = f'\nThe user air-wrote: "{air_written_text}" — treat as supplementary intent.'
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
"
|
| 157 |
-
|
| 158 |
-
"baseline": "Use your natural communication style.",
|
| 159 |
-
"add_confirmation": "Add a clarifying question or confirmation at the end.",
|
| 160 |
-
}.get(persona_mod, "Use your natural communication style.")
|
| 161 |
|
| 162 |
return f"""\
|
| 163 |
-
You are {profile["name"]}. You have {profile["condition"]} and communicate through {access}, but your voice and thoughts are fully your own.
|
| 164 |
-
Communication style: {style_summary}
|
| 165 |
{tone_tag}{gesture_line}{air_writing_line}
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
{style_exemplar}
|
| 169 |
-
|
| 170 |
-
Personal memories (use ONLY these for personal facts; each tagged [bucket/type] where type is narrative, social_post, or chat_log):
|
| 171 |
{memory_block}
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
Recent conversation:
|
| 174 |
{history_block}
|
| 175 |
|
| 176 |
-
Partner
|
| 177 |
-
|
| 178 |
-
Instructions:
|
| 179 |
-
- Speak in first person as {profile["name"]}.
|
| 180 |
-
- {persona_instruction}
|
| 181 |
-
- Keep response to 1-3 sentences.
|
| 182 |
-
- If the answer isn't in your memories, say "I don't know."
|
| 183 |
-
- Do NOT say "As an AI" or break persona.
|
| 184 |
|
| 185 |
-
|
|
|
|
| 53 |
)
|
| 54 |
gesture_tag = state.get("gesture_tag")
|
| 55 |
air_written_text = state.get("air_written_text")
|
| 56 |
+
messages = _build_messages(
|
| 57 |
profile,
|
| 58 |
chunks,
|
| 59 |
history,
|
|
|
|
| 65 |
)
|
| 66 |
|
| 67 |
selected = chat_complete(
|
| 68 |
+
messages=messages,
|
| 69 |
max_tokens=gen_cfg.get("max_tokens", settings.max_tokens_neutral),
|
| 70 |
temperature=0.4,
|
| 71 |
tier=tier,
|
|
|
|
| 86 |
4,
|
| 87 |
)
|
| 88 |
|
| 89 |
+
augmented_prompt = "\n\n".join(m["content"] for m in messages)
|
| 90 |
return {
|
| 91 |
+
"augmented_prompt": augmented_prompt,
|
| 92 |
"candidates": [selected],
|
| 93 |
"selected_response": selected,
|
| 94 |
"llm_tier_used": tier,
|
|
|
|
| 102 |
return _PERSONA_TONE_OVERRIDES.get(user_id, {}).get(affect, default_tag)
|
| 103 |
|
| 104 |
|
| 105 |
+
def _build_messages(
|
| 106 |
profile: dict,
|
| 107 |
chunks: list[dict],
|
| 108 |
history: list[dict],
|
|
|
|
| 111 |
gen_cfg: dict,
|
| 112 |
gesture_tag: str | None = None,
|
| 113 |
air_written_text: str | None = None,
|
| 114 |
+
) -> list[dict]:
|
| 115 |
+
# Split into a stable system message (same per persona — gets cached by the
|
| 116 |
+
# provider) and a turn-specific user message. Anything that changes per
|
| 117 |
+
# turn (retrieval, affect, gesture, partner query) must live in the user
|
| 118 |
+
# message or the prefix cache invalidates.
|
| 119 |
+
system_content = _build_system(profile)
|
| 120 |
+
user_content = _build_user(
|
| 121 |
+
chunks,
|
| 122 |
+
history,
|
| 123 |
+
query,
|
| 124 |
+
tone_tag,
|
| 125 |
+
gen_cfg,
|
| 126 |
+
gesture_tag,
|
| 127 |
+
air_written_text,
|
| 128 |
+
profile["name"],
|
| 129 |
)
|
| 130 |
+
return [
|
| 131 |
+
{"role": "system", "content": system_content},
|
| 132 |
+
{"role": "user", "content": user_content},
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
|
| 136 |
+
def _build_system(profile: dict) -> str:
|
| 137 |
prefs = profile.get("stylistic_preferences") or {}
|
| 138 |
style_bits = []
|
| 139 |
if prefs.get("tone"):
|
|
|
|
| 150 |
|
| 151 |
exemplars = prefs.get("example_phrases") or []
|
| 152 |
style_exemplar = "\n ".join(exemplars) if exemplars else "(no exemplar)"
|
|
|
|
| 153 |
access = (profile.get("access_needs") or {}).get("input_method") or "an AAC device"
|
| 154 |
|
| 155 |
+
return f"""\
|
| 156 |
+
You are {profile["name"]}. Reply in first person as them, in 1–3 sentences. \
|
| 157 |
+
Never narrate, analyze, describe, or list traits about your character. \
|
| 158 |
+
Never say "As an AI", "The user wants me to", "Key characteristics", or anything meta. \
|
| 159 |
+
Just speak.
|
| 160 |
+
|
| 161 |
+
--- Character sheet (reference only — do NOT quote or paraphrase this block) ---
|
| 162 |
+
Condition: {profile["condition"]}
|
| 163 |
+
Access: {access}
|
| 164 |
+
Voice: {style_summary}
|
| 165 |
+
|
| 166 |
+
Style examples (match this register when you speak):
|
| 167 |
+
{style_exemplar}
|
| 168 |
+
|
| 169 |
+
Answering rules:
|
| 170 |
+
- For personal questions: use ONLY the memories in the user message; if they don't cover it, say "I don't know."
|
| 171 |
+
- For general-knowledge questions: answer from what you know, in your voice.
|
| 172 |
+
- Keep it to 1–3 sentences, first person, no meta-commentary.
|
| 173 |
+
--- end character sheet ---"""
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
_PERSONA_MOD_INSTRUCTIONS = {
|
| 177 |
+
"amplify_quirks": "Amplify your characteristic style and personality.",
|
| 178 |
+
"suppress_humor": "Be direct and supportive. Suppress humor.",
|
| 179 |
+
"baseline": "Use your natural communication style.",
|
| 180 |
+
"add_confirmation": "Add a clarifying question or confirmation at the end.",
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _build_user(
|
| 185 |
+
chunks: list[dict],
|
| 186 |
+
history: list[dict],
|
| 187 |
+
query: str,
|
| 188 |
+
tone_tag: str,
|
| 189 |
+
gen_cfg: dict,
|
| 190 |
+
gesture_tag: str | None,
|
| 191 |
+
air_written_text: str | None,
|
| 192 |
+
persona_name: str,
|
| 193 |
+
) -> str:
|
| 194 |
+
personal_chunks = [c for c in chunks if c.get("source", "personal") == "personal"]
|
| 195 |
+
contextual_chunks = [c for c in chunks if c.get("source") == "contextual"]
|
| 196 |
+
open_domain_chunks = [c for c in chunks if c.get("source") == "open_domain"]
|
| 197 |
+
|
| 198 |
+
memory_block = (
|
| 199 |
+
"\n".join(
|
| 200 |
+
f" [{c['bucket']}/{c.get('type', 'narrative')}] {c['text']}"
|
| 201 |
+
for c in personal_chunks
|
| 202 |
+
)
|
| 203 |
+
or " (no memories retrieved)"
|
| 204 |
+
)
|
| 205 |
+
contextual_block = (
|
| 206 |
+
"\n".join(f" {c['text']}" for c in contextual_chunks)
|
| 207 |
+
or " (nothing relevant from this session)"
|
| 208 |
+
)
|
| 209 |
+
open_domain_note = (
|
| 210 |
+
" Treat this sub-query as general knowledge; answer from what you know.\n"
|
| 211 |
+
+ "\n".join(f" {c['text']}" for c in open_domain_chunks)
|
| 212 |
+
if open_domain_chunks
|
| 213 |
+
else " (none)"
|
| 214 |
+
)
|
| 215 |
+
history_block = (
|
| 216 |
+
"\n".join(f" {h.get('role', '?')}: {h.get('content', '')}" for h in history)
|
| 217 |
+
or " (start of session)"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
gesture_line = ""
|
| 221 |
if gesture_tag:
|
| 222 |
g_tag = GESTURE_TO_TAG.get(gesture_tag, f"[GESTURE:{gesture_tag}]")
|
|
|
|
| 226 |
if air_written_text:
|
| 227 |
air_writing_line = f'\nThe user air-wrote: "{air_written_text}" — treat as supplementary intent.'
|
| 228 |
|
| 229 |
+
persona_instruction = _PERSONA_MOD_INSTRUCTIONS.get(
|
| 230 |
+
gen_cfg.get("persona_mod", "baseline"),
|
| 231 |
+
_PERSONA_MOD_INSTRUCTIONS["baseline"],
|
| 232 |
+
)
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
return f"""\
|
|
|
|
|
|
|
| 235 |
{tone_tag}{gesture_line}{air_writing_line}
|
| 236 |
+
{persona_instruction}
|
| 237 |
|
| 238 |
+
Personal memories:
|
|
|
|
|
|
|
|
|
|
| 239 |
{memory_block}
|
| 240 |
|
| 241 |
+
From earlier in this conversation:
|
| 242 |
+
{contextual_block}
|
| 243 |
+
|
| 244 |
+
General knowledge note:
|
| 245 |
+
{open_domain_note}
|
| 246 |
+
|
| 247 |
Recent conversation:
|
| 248 |
{history_block}
|
| 249 |
|
| 250 |
+
Partner just said: {query}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
+
Your reply as {persona_name} (1–3 sentences, first person):"""
|
backend/pipeline/nodes/retrieval.py
CHANGED
|
@@ -1,59 +1,128 @@
|
|
| 1 |
-
# Retrieval node —
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import time
|
| 5 |
|
| 6 |
from backend.config.settings import settings
|
| 7 |
-
from backend.pipeline.state import PipelineState, RetrievedChunk
|
|
|
|
| 8 |
from backend.retrieval.vector_store import retrieve
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def run_fast(state: PipelineState) -> dict:
|
| 12 |
"""Fast retrieval path for FRUSTRATED affect (k=2, no reranker)."""
|
| 13 |
t0 = time.perf_counter()
|
| 14 |
-
|
| 15 |
-
priors = state["bucket_priors"]
|
| 16 |
-
prior_vals = list(priors.values()) if priors else []
|
| 17 |
-
priors_uniform = prior_vals and (max(prior_vals) - min(prior_vals)) < 0.05
|
| 18 |
-
bucket_hint = (
|
| 19 |
-
state.get("gaze_bucket")
|
| 20 |
-
if priors_uniform and state.get("gaze_bucket")
|
| 21 |
-
else _top_prior_bucket(priors)
|
| 22 |
-
)
|
| 23 |
-
chunks = retrieve(
|
| 24 |
-
query=state["raw_query"],
|
| 25 |
-
user_id=state["user_id"],
|
| 26 |
-
top_k=settings.retrieval_fast_k,
|
| 27 |
-
rerank_k=settings.retrieval_fast_k,
|
| 28 |
-
bucket_filter=bucket_hint,
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
return _build_return(state, chunks, "fast", t0)
|
| 32 |
|
| 33 |
|
| 34 |
def run_full(state: PipelineState) -> dict:
|
| 35 |
"""Full retrieval path: top_k cosine matches narrowed to rerank_k."""
|
| 36 |
t0 = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
route = state.get("intent_route") or {}
|
| 40 |
-
sub_intents = route.get("sub_intents"
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
| 47 |
user_id=state["user_id"],
|
| 48 |
-
top_k=
|
| 49 |
-
rerank_k=
|
| 50 |
bucket_filter=bucket_hint,
|
| 51 |
)
|
| 52 |
|
| 53 |
-
return _build_return(state, chunks, "full", t0)
|
| 54 |
-
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def _top_prior_bucket(priors: dict[str, float]) -> str | None:
|
|
|
|
| 1 |
+
# Retrieval node — dispatches each sub-intent to its pool, merges results.
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import time
|
| 5 |
|
| 6 |
from backend.config.settings import settings
|
| 7 |
+
from backend.pipeline.state import PipelineState, RetrievedChunk, SubIntent
|
| 8 |
+
from backend.retrieval.contextual import retrieve_from_history
|
| 9 |
from backend.retrieval.vector_store import retrieve
|
| 10 |
|
| 11 |
+
_OPEN_DOMAIN_STUB_TEXT = (
|
| 12 |
+
"(no external knowledge source wired — answer from general knowledge)"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
|
| 16 |
def run_fast(state: PipelineState) -> dict:
|
| 17 |
"""Fast retrieval path for FRUSTRATED affect (k=2, no reranker)."""
|
| 18 |
t0 = time.perf_counter()
|
| 19 |
+
chunks = _dispatch_all(state, per_intent_k=settings.retrieval_fast_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
return _build_return(state, chunks, "fast", t0)
|
| 21 |
|
| 22 |
|
| 23 |
def run_full(state: PipelineState) -> dict:
|
| 24 |
"""Full retrieval path: top_k cosine matches narrowed to rerank_k."""
|
| 25 |
t0 = time.perf_counter()
|
| 26 |
+
chunks = _dispatch_all(state, per_intent_k=settings.retrieval_rerank_k)
|
| 27 |
+
return _build_return(state, chunks, "full", t0)
|
| 28 |
+
|
| 29 |
|
| 30 |
+
def _dispatch_all(state: PipelineState, per_intent_k: int) -> list[RetrievedChunk]:
|
| 31 |
route = state.get("intent_route") or {}
|
| 32 |
+
sub_intents: list[SubIntent] = route.get("sub_intents") or []
|
| 33 |
+
|
| 34 |
+
if not sub_intents:
|
| 35 |
+
sub_intents = [
|
| 36 |
+
{
|
| 37 |
+
"type": "PERSONAL",
|
| 38 |
+
"query": state["raw_query"],
|
| 39 |
+
"bucket_hint": None,
|
| 40 |
+
"priority": "normal",
|
| 41 |
+
}
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
merged: list[RetrievedChunk] = []
|
| 45 |
+
for sub in sub_intents:
|
| 46 |
+
kind = (sub.get("type") or "PERSONAL").upper()
|
| 47 |
+
if kind == "PERSONAL":
|
| 48 |
+
merged.extend(_retrieve_personal(sub, state, per_intent_k))
|
| 49 |
+
elif kind == "CONTEXTUAL":
|
| 50 |
+
merged.extend(_retrieve_contextual(sub, state, per_intent_k))
|
| 51 |
+
elif kind == "OPEN_DOMAIN":
|
| 52 |
+
merged.extend(_retrieve_open_domain(sub))
|
| 53 |
+
else:
|
| 54 |
+
merged.extend(_retrieve_personal(sub, state, per_intent_k))
|
| 55 |
+
|
| 56 |
+
return _dedupe(merged)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _retrieve_personal(
|
| 60 |
+
sub: SubIntent, state: PipelineState, k: int
|
| 61 |
+
) -> list[RetrievedChunk]:
|
| 62 |
+
priors = state["bucket_priors"]
|
| 63 |
+
prior_vals = list(priors.values()) if priors else []
|
| 64 |
+
priors_uniform = prior_vals and (max(prior_vals) - min(prior_vals)) < 0.05
|
| 65 |
+
|
| 66 |
+
bucket_hint = (
|
| 67 |
+
state.get("gaze_bucket")
|
| 68 |
+
or sub.get("bucket_hint")
|
| 69 |
+
or (_top_prior_bucket(priors) if not priors_uniform else None)
|
| 70 |
)
|
| 71 |
|
| 72 |
+
top_k = max(k, settings.retrieval_top_k) if k >= settings.retrieval_rerank_k else k
|
| 73 |
+
return retrieve(
|
| 74 |
+
query=sub["query"],
|
| 75 |
user_id=state["user_id"],
|
| 76 |
+
top_k=top_k,
|
| 77 |
+
rerank_k=k,
|
| 78 |
bucket_filter=bucket_hint,
|
| 79 |
)
|
| 80 |
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
_CONTEXTUAL_MIN_SCORE = (
|
| 83 |
+
0.5 # empirical: below this, history matches are usually spurious
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _retrieve_contextual(
|
| 88 |
+
sub: SubIntent, state: PipelineState, k: int
|
| 89 |
+
) -> list[RetrievedChunk]:
|
| 90 |
+
# CONTEXTUAL means "this turn leans on the recent conversation" — but the
|
| 91 |
+
# persona's memories are still the primary grounding. Always pull personal
|
| 92 |
+
# chunks; add contextual ones on top when the session history is relevant.
|
| 93 |
+
personal_chunks = _retrieve_personal(sub, state, k)
|
| 94 |
+
history = state.get("session_history") or []
|
| 95 |
+
history_chunks = retrieve_from_history(query=sub["query"], history=history, top_k=k)
|
| 96 |
+
relevant_history = [
|
| 97 |
+
c for c in history_chunks if c["score"] >= _CONTEXTUAL_MIN_SCORE
|
| 98 |
+
]
|
| 99 |
+
return personal_chunks + relevant_history
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _retrieve_open_domain(sub: SubIntent) -> list[RetrievedChunk]:
|
| 103 |
+
# Intentionally a stub — web search is out of scope. See README "Intent decomposition".
|
| 104 |
+
return [
|
| 105 |
+
RetrievedChunk(
|
| 106 |
+
text=f'{_OPEN_DOMAIN_STUB_TEXT} (sub-query: "{sub["query"]}")',
|
| 107 |
+
bucket="open_domain",
|
| 108 |
+
type="narrative",
|
| 109 |
+
user="",
|
| 110 |
+
score=0.0,
|
| 111 |
+
source="open_domain",
|
| 112 |
+
)
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _dedupe(chunks: list[RetrievedChunk]) -> list[RetrievedChunk]:
|
| 117 |
+
seen: set[tuple[str, str]] = set()
|
| 118 |
+
out: list[RetrievedChunk] = []
|
| 119 |
+
for c in chunks:
|
| 120 |
+
key = (c["source"], c["text"])
|
| 121 |
+
if key in seen:
|
| 122 |
+
continue
|
| 123 |
+
seen.add(key)
|
| 124 |
+
out.append(c)
|
| 125 |
+
return out
|
| 126 |
|
| 127 |
|
| 128 |
def _top_prior_bucket(priors: dict[str, float]) -> str | None:
|
backend/pipeline/state.py
CHANGED
|
@@ -23,10 +23,11 @@ class AffectState(TypedDict):
|
|
| 23 |
|
| 24 |
class RetrievedChunk(TypedDict):
|
| 25 |
text: str
|
| 26 |
-
bucket: str # family | medical | hobbies | daily_routine | social
|
| 27 |
-
type: str # narrative | social_post | chat_log
|
| 28 |
user: str
|
| 29 |
score: float # cosine similarity from the embedder
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
class SubIntent(TypedDict):
|
|
|
|
| 23 |
|
| 24 |
class RetrievedChunk(TypedDict):
|
| 25 |
text: str
|
| 26 |
+
bucket: str # family | medical | hobbies | daily_routine | social | contextual | open_domain
|
| 27 |
+
type: str # narrative | social_post | chat_log (personal chunks only)
|
| 28 |
user: str
|
| 29 |
score: float # cosine similarity from the embedder
|
| 30 |
+
source: str # "personal" | "contextual" | "open_domain"
|
| 31 |
|
| 32 |
|
| 33 |
class SubIntent(TypedDict):
|
backend/retrieval/contextual.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from backend.pipeline.state import RetrievedChunk
|
| 4 |
+
from backend.retrieval.vector_store import get_device, get_embedder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def retrieve_from_history(
|
| 8 |
+
query: str,
|
| 9 |
+
history: list[dict],
|
| 10 |
+
top_k: int = 3,
|
| 11 |
+
recent_window: int = 20,
|
| 12 |
+
) -> list[RetrievedChunk]:
|
| 13 |
+
if not history or top_k <= 0:
|
| 14 |
+
return []
|
| 15 |
+
|
| 16 |
+
window = history[-recent_window:]
|
| 17 |
+
texts = [_format_turn(h) for h in window]
|
| 18 |
+
if not any(texts):
|
| 19 |
+
return []
|
| 20 |
+
|
| 21 |
+
embedder = get_embedder()
|
| 22 |
+
device = get_device()
|
| 23 |
+
|
| 24 |
+
q_vec = embedder.encode(
|
| 25 |
+
[query],
|
| 26 |
+
convert_to_tensor=True,
|
| 27 |
+
normalize_embeddings=True,
|
| 28 |
+
device=device,
|
| 29 |
+
)[0]
|
| 30 |
+
h_vecs = embedder.encode(
|
| 31 |
+
texts,
|
| 32 |
+
convert_to_tensor=True,
|
| 33 |
+
normalize_embeddings=True,
|
| 34 |
+
device=device,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
scores = h_vecs @ q_vec
|
| 38 |
+
k = min(top_k, scores.shape[0])
|
| 39 |
+
top_scores, top_idxs = torch.topk(scores, k)
|
| 40 |
+
|
| 41 |
+
return [
|
| 42 |
+
RetrievedChunk(
|
| 43 |
+
text=texts[int(idx)],
|
| 44 |
+
bucket="contextual",
|
| 45 |
+
type="chat_log",
|
| 46 |
+
user="",
|
| 47 |
+
score=float(score),
|
| 48 |
+
source="contextual",
|
| 49 |
+
)
|
| 50 |
+
for score, idx in zip(top_scores.tolist(), top_idxs.tolist())
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _format_turn(turn: dict) -> str:
|
| 55 |
+
role = turn.get("role", "?")
|
| 56 |
+
content = (turn.get("content") or "").strip()
|
| 57 |
+
return f"{role}: {content}" if content else ""
|
backend/retrieval/vector_store.py
CHANGED
|
@@ -31,6 +31,10 @@ def _get_embedder():
|
|
| 31 |
return SentenceTransformer(settings.embed_model, device=_DEVICE)
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# Index cache: one (vectors_tensor, meta) per user_id.
|
| 35 |
_index_cache: dict[str, tuple[torch.Tensor, list[dict]]] = {}
|
| 36 |
|
|
@@ -88,6 +92,7 @@ def retrieve(
|
|
| 88 |
type=c.get("type", "narrative"),
|
| 89 |
user=c["user"],
|
| 90 |
score=float(s),
|
|
|
|
| 91 |
)
|
| 92 |
for s, c in candidates[:rerank_k]
|
| 93 |
]
|
|
|
|
| 31 |
return SentenceTransformer(settings.embed_model, device=_DEVICE)
|
| 32 |
|
| 33 |
|
| 34 |
+
def get_embedder():
|
| 35 |
+
return _get_embedder()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
# Index cache: one (vectors_tensor, meta) per user_id.
|
| 39 |
_index_cache: dict[str, tuple[torch.Tensor, list[dict]]] = {}
|
| 40 |
|
|
|
|
| 92 |
type=c.get("type", "narrative"),
|
| 93 |
user=c["user"],
|
| 94 |
score=float(s),
|
| 95 |
+
source="personal",
|
| 96 |
)
|
| 97 |
for s, c in candidates[:rerank_k]
|
| 98 |
]
|
backend/sensing/bucket_keywords.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_BUCKET_KEYWORDS: list[tuple[str, tuple[str, ...]]] = [
|
| 2 |
+
("medical", ("medication", "medicine", "doctor", "health", "allergic", "therapy")),
|
| 3 |
+
("family", ("family", "mom", "dad", "brother", "sister", "parents")),
|
| 4 |
+
("hobbies", ("hobby", "like to do", "enjoy", "weekend", "fun")),
|
| 5 |
+
("daily_routine", ("routine", "morning", "wake", "sleep", "daily")),
|
| 6 |
+
("social", ("friend", "social", "people", "party", "community")),
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def infer_bucket(query: str) -> str | None:
|
| 11 |
+
q = query.lower()
|
| 12 |
+
for bucket, words in _BUCKET_KEYWORDS:
|
| 13 |
+
if any(w in q for w in words):
|
| 14 |
+
return bucket
|
| 15 |
+
return None
|
run.sh
CHANGED
|
@@ -13,6 +13,12 @@ fi
|
|
| 13 |
eval "$(conda shell.bash hook)"
|
| 14 |
conda activate "$CONDA_ENV"
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
PIDS=()
|
| 17 |
|
| 18 |
cleanup() {
|
|
|
|
| 13 |
eval "$(conda shell.bash hook)"
|
| 14 |
conda activate "$CONDA_ENV"
|
| 15 |
|
| 16 |
+
# If any args were passed (e.g. --debug, --user mia_chen), run the CLI
|
| 17 |
+
# instead of the full stack and forward them verbatim.
|
| 18 |
+
if [ "$#" -gt 0 ]; then
|
| 19 |
+
exec python -m backend.main "$@"
|
| 20 |
+
fi
|
| 21 |
+
|
| 22 |
PIDS=()
|
| 23 |
|
| 24 |
cleanup() {
|