Shwetangi committed on
Commit
fd77577
·
2 Parent(s): 626c0b8e7cf650

Merge pull request #3 from akashkolte/akash/v1

Browse files
.env.example ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy this file to .env and fill in your values.
2
+ # Settings here override the defaults in config/settings.py.
3
+
4
+ # ── Active LLM tier ────────────────────────────────────────────────────────────
5
+ # "local" → Ollama on MacBook M2 (dev, no GPU needed)
6
+ # "primary" → Qwen3-30B-A3B on GCP A100/T4 via vLLM
7
+ # "fallback" → Qwen3-8B on same vLLM server
8
+ ACTIVE_LLM_TIER=local
9
+
10
+ # ── Primary vLLM server (GCP) ─────────────────────────────────────────────────
11
+ PRIMARY_BASE_URL=http://<GCP_IP>:8000/v1
12
+ PRIMARY_API_KEY=token-abc
13
+ PRIMARY_MODEL=Qwen/Qwen3-30B-A3B
14
+
15
+ # ── Fallback model (same vLLM server) ─────────────────────────────────────────
16
+ FALLBACK_MODEL=Qwen/Qwen3-8B
17
+ FALLBACK_BASE_URL=http://<GCP_IP>:8000/v1
18
+
19
+ # ── Local Ollama (dev) ────────────────────────────────────────────────────────
20
+ LOCAL_BASE_URL=http://localhost:11434/v1
21
+ LOCAL_MODEL=gemma4:31b-cloud # qwen3:8b qwen3.5:397b-cloud
22
+
23
+ # ── MLflow ────────────────────────────────────────────────────────────────────
24
+ MLFLOW_TRACKING_URI=mlruns
25
+ MLFLOW_EXPERIMENT=aac-chatbot
26
+
27
+ # ── Thinking mode ─────────────────────────────────────────────────────────────
28
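+ # "off" — plain completion, no thinking (fastest, best for latency-sensitive AAC)
29
+ # "strip" — let model think, but strip <think> tags from output
30
+ # "full" — return raw response including <think> blocks
+ # "suppress" — actively suppress thinking via /no_think (Ollama) or chat_template_kwargs (vLLM)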
31
+ THINKING_MODE=off
32
+ # Extra tokens added when thinking is enabled (strip/full). Ignored when off.
33
+ THINKING_TOKEN_BUDGET=4096
34
+
35
+ # ── Latency fallback threshold (seconds) ──────────────────────────────────────
36
+ FALLBACK_LATENCY_THRESHOLD=3.5
.gitignore ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+
11
+ # Virtual environment
12
+ .venv/
13
+ venv/
14
+ env/
15
+
16
+ # Environment secrets
17
+ .env
18
+
19
+ # Data β€” indexes are rebuilt from source; do NOT commit binaries
20
+ data/faiss_store/
21
+
22
+ # Air-writing templates (large numpy files, track separately if needed)
23
+ data/air_write_templates/
24
+
25
+ # MLflow run artifacts
26
+ mlruns/
27
+
28
+ # Latency logs
29
+ timings.csv
30
+ *.csv
31
+
32
+ # IDE
33
+ .vscode/
34
+ .idea/
35
+ *.swp
36
+
37
+ # Claude Code β€” local settings and generated knowledge graph
38
+ .claude/
39
+ .code-review-graph/
40
+
41
+ # macOS
42
+ .DS_Store
43
+
44
+ # Jupyter
45
+ .ipynb_checkpoints/
46
+ *.ipynb
CLAUDE.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multimodal AAC Chatbot β€” Project Guide
2
+
3
+ ## What This Project Does
4
+
5
+ An AI chatbot that **speaks as an AAC user**, not to them. Given a user persona
6
+ (Mia, Gerald, or Arjun), it fuses real-time multimodal non-verbal signals with
7
+ personal memory retrieval to generate responses in that person's authentic voice.
8
+ Orchestrated as a **LangGraph stateful directed graph** across five layers.
9
+
10
+ ---
11
+
12
+ ## Architecture
13
+
14
+ ```
15
+ main.py / api/main.py / ui/app.py
16
+ └── pipeline/graph.py ← LangGraph StateGraph (5 nodes + cond. edges)
17
+     ├── pipeline/nodes/intent.py L2 — LLM + Pydantic intent routing
18
+     ├── pipeline/nodes/retrieval.py L3 — FAISS + BGE retrieval (fast / full)
19
+     ├── pipeline/nodes/planner.py L4 — expression-conditioned generation
20
+     └── pipeline/nodes/feedback.py L5 — MLflow logging + Bayesian priors
21
+
22
+ sensing/ L1 β€” MediaPipe face mesh, gesture, gaze, air writing
23
+ retrieval/ FAISS ops, HDBSCAN clustering, Bayesian bucket priors
24
+ generation/ Multi-tier LLM client (vLLM primary / fallback / Ollama local)
25
+ guardrails/ Input + output safety checks
26
+ config/ Pydantic BaseSettings β€” all config in one place
27
+ ```
28
+
29
+ ## Key Design Decisions
30
+
31
+ - **LangGraph** orchestrates the pipeline as a stateful directed graph with
32
+ conditional edges (affect β†’ fast/full retrieval; latency β†’ primary/fallback LLM)
33
+ - **BGE-small-en-v1.5** for embeddings (beats MiniLM on MTEB at same speed)
34
+ - **BGE-reranker-v2-m3** cross-encoder β€” multilingual, handles Arjun's Hindi
35
+ - **FAISS IndexFlatIP** with L2-normalised vectors (inner product = cosine sim)
36
+ - **Qwen3-30B-A3B** MoE via vLLM β€” 3B active params/token, sub-3s on T4
37
+ - **Three-tier LLM fallback**: primary (vLLM GCP) β†’ fallback (Qwen3-8B) β†’ local (Ollama)
38
+ - **Pydantic-validated** LLM routing output β€” LangGraph retries on schema failures
39
+ - **Expression-conditioned response shaping** β€” affect steers tone, retrieval depth,
40
+ and candidate ranking (not just metadata annotation)
41
+ - **Bayesian bucket priors** β€” session-level P(bucket) updated after each accepted turn
42
+
43
+ ---
44
+
45
+ ## Personas
46
+
47
+ | ID | Name | Condition | Access |
48
+ |----|------|-----------|--------|
49
+ | `mia_chen` | Mia Chen, 28 | Cerebral palsy | Webcam head-tracking |
50
+ | `gerald_okafor` | Gerald Okafor, 61 | ALS (early-mid) | Eye-gaze device |
51
+ | `arjun_mehta` | Arjun Mehta, 17 | Autism (non-verbal) | Tablet touch grid |
52
+
53
+ 25 memory chunks each (5 buckets × 5 memories). Arjun code-switches Hindi/English.
54
+
55
+ ---
56
+
57
+ ## How to Run
58
+
59
+ ```bash
60
+ # One-time setup: rebuild FAISS indexes with BGE embedder
61
+ python -m retrieval.vector_store
62
+
63
+ # CLI (local Ollama tier, set ACTIVE_LLM_TIER=local in .env)
64
+ python main.py --debug
65
+
66
+ # Full stack
67
+ uvicorn api.main:app --reload # FastAPI on :8000
68
+ streamlit run ui/app.py # Streamlit on :8501
69
+ ```
70
+
71
+ ---
72
+
73
+ ## Configuration
74
+
75
+ All config lives in [config/settings.py](config/settings.py) as Pydantic `BaseSettings`.
76
+ Copy `.env.example` β†’ `.env` and set:
77
+
78
+ - `ACTIVE_LLM_TIER` β€” `local` (dev) | `primary` (GCP A100) | `fallback` (Qwen3-8B)
79
+ - `PRIMARY_BASE_URL` β€” vLLM server address on GCP
80
+ - `MLFLOW_TRACKING_URI` β€” where MLflow stores runs (default: `mlruns/`)
81
+
82
+ ---
83
+
84
+ ## Data Files
85
+
86
+ | Path | Purpose |
87
+ |------|---------|
88
+ | `data/users.json` | Flat user index (id, name, condition, style) |
89
+ | `data/memories/<uid>.json` | Full persona JSON with bucketed memories |
90
+ | `data/faiss_store/<uid>/` | FAISS index + metadata β€” **rebuild after any persona edit** |
91
+ | `data/generate_users.py` | Regenerates memories + users.json |
92
+
93
+ ---
94
+
95
+ ## Development Notes
96
+
97
+ - **Adding a persona**: add to `PERSONAS` in `data/generate_users.py`, re-run it,
98
+ then `python -m retrieval.vector_store` to rebuild indexes
99
+ - **Changing LLM**: set `ACTIVE_LLM_TIER` in `.env` β€” no code changes needed
100
+ - **Extending sensing**: add module under `sensing/`, wire output into
101
+ `PipelineState` fields in `pipeline/state.py`
102
+ - **Guardrail tuning**: edit signal lists in `guardrails/checks.py`
103
+ - **Affect β†’ generation mapping**: `_AFFECT_CONFIG` in `pipeline/nodes/intent.py`
104
+ and `_PERSONA_TONE_OVERRIDES` in `pipeline/nodes/planner.py`
105
+ - The `.venv/` directory is local β€” do not read or modify files inside it
106
+ - FAISS indexes in `data/faiss_store/` are gitignored β€” rebuilt from source JSONs
README.md CHANGED
@@ -1,74 +1,206 @@
1
  # Multimodal AAC Chatbot
2
 
3
- A multimodal chatbot designed to empower **Augmentative and Alternative Communication (AAC)** users β€” enabling more natural, accessible, and expressive conversations through the power of AI.
 
 
 
 
4
 
5
  ---
6
 
7
- ## What is AAC?
8
 
9
- **Augmentative and Alternative Communication (AAC)** refers to tools, strategies, and technologies that help people who have difficulty with spoken or written communication. AAC users may include individuals with:
 
 
 
 
 
 
 
 
 
 
10
 
11
- - Autism Spectrum Disorder (ASD)
12
- - Cerebral Palsy
13
- - ALS / Motor Neurone Disease
14
- - Aphasia
15
- - Down Syndrome
16
- - Or any other condition that impacts verbal communication
17
 
18
- AAC tools range from low-tech picture boards to high-tech speech-generating devices. This project brings the power of modern AI chatbots to the AAC community.
 
 
 
19
 
20
  ---
21
 
22
- ## About This Project
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- The **Multimodal AAC Chatbot** is an AI-powered conversational assistant built with AAC users in mind. It accepts multiple input modalities β€” such as text, images, and symbols β€” and generates clear, accessible responses to support communication.
25
 
26
- ### Key Features
27
 
28
- - πŸ—£οΈ **Multimodal Input** β€” Communicate using text, images, symbols, or a combination of all three
29
- - πŸ€– **AI-Powered Responses** β€” Leverages large language models (LLMs) to generate natural and context-aware replies
30
- - β™Ώ **Accessibility First** β€” Designed from the ground up for users with communication challenges
31
- - 🧩 **AAC-Friendly Interface** β€” Supports common AAC workflows and symbol-based communication
32
- - πŸ’¬ **Conversational Context** β€” Maintains conversation history for more coherent, multi-turn dialogues
33
 
34
  ---
35
 
36
- ## Getting Started
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- ### Prerequisites
39
 
40
- - Python 3.8 or higher
41
- - pip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- ### Installation
44
 
45
- 1. **Clone the repository**
46
- ```bash
47
- git clone https://github.com/akashkolte/multimodal_aac_chatbot.git
48
- cd multimodal_aac_chatbot
49
- ```
50
 
51
- 2. **Install dependencies**
52
- ```bash
53
- pip install -r requirements.txt
54
- ```
55
 
56
- 3. **Run the chatbot**
57
- ```bash
58
- python app.py
59
- ```
 
 
 
 
 
 
60
 
61
  ---
62
 
63
- ## Usage
 
 
64
 
65
- Once running, users can interact with the chatbot by:
 
 
 
 
 
 
 
66
 
67
- - Typing a text message
68
- - Uploading an image or symbol to describe their intent
69
- - Combining symbols and short text phrases as AAC users typically do
 
70
 
71
- The chatbot will interpret the input and respond in a clear, friendly manner.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  ---
74
 
@@ -76,28 +208,79 @@ The chatbot will interpret the input and respond in a clear, friendly manner.
76
 
77
  ```
78
  multimodal_aac_chatbot/
79
- β”œβ”€β”€ app.py # Main application entry point
80
- β”œβ”€β”€ requirements.txt # Python dependencies
81
- β”œβ”€β”€ README.md # Project documentation
82
- └── LICENSE # License information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  ```
84
 
85
  ---
86
 
87
- ## Contributing
 
 
 
 
 
 
88
 
89
- This project is currently under active development. Feedback and suggestions from the AAC community and researchers are very welcome β€” please open an issue to share your thoughts.
90
 
91
- > **Note:** This software is proprietary. All rights are reserved. Any use, copying, modification, or distribution requires explicit written permission from the author.
92
 
93
  ---
94
 
95
- ## License
 
 
 
96
 
97
- All rights reserved. No permission is granted to use, copy, modify, or distribute this software. See the [LICENSE](LICENSE) file for details.
98
 
99
  ---
100
 
101
- ## Acknowledgements
102
 
103
- This project is dedicated to the AAC community and the researchers, caregivers, and developers working to make communication more accessible for everyone.
 
1
  # Multimodal AAC Chatbot
2
 
3
+ An AI chatbot that **speaks as an AAC user**, not to them. Given a persona (Mia, Gerald, or Arjun),
4
+ it fuses real-time multimodal non-verbal signals β€” facial expressions, hand gestures, gaze, and
5
+ air writing β€” with personal memory retrieval to generate responses in that person's authentic voice.
6
+
7
+ Built as a training-free, agentic RAG pipeline orchestrated via **LangGraph**.
8
 
9
  ---
10
 
11
+ ## Table of Contents
12
 
13
+ - [What is AAC?](#what-is-aac)
14
+ - [System Architecture](#system-architecture)
15
+ - [Prerequisites](#prerequisites)
16
+ - [Setup](#setup)
17
+ - [Configuration](#configuration)
18
+ - [Running the Project](#running-the-project)
19
+ - [Project Structure](#project-structure)
20
+ - [Personas](#personas)
21
+ - [Team](#team)
22
+
23
+ ---
24
 
25
+ ## What is AAC?
 
 
 
 
 
26
 
27
+ **Augmentative and Alternative Communication (AAC)** refers to tools and technologies that help
28
+ people who have difficulty with spoken or written communication β€” including individuals with
29
+ Cerebral Palsy, ALS, Autism Spectrum Disorder, and other conditions. This project gives AAC users
30
+ a personalized digital twin that communicates on their behalf.
31
 
32
  ---
33
 
34
+ ## System Architecture
35
+
36
+ ```
37
+ Webcam (L1: sensing) β†’ Intent Decomposition (L2) β†’ Retrieval (L3) β†’ Generation (L4) β†’ Feedback (L5)
38
+ ```
39
+
40
+ | Layer | Module | What it does |
41
+ |-------|--------|-------------|
42
+ | L1 | `sensing/` | MediaPipe face mesh, hand gestures, gaze tracking, air writing |
43
+ | L2 | `pipeline/nodes/intent.py` | LLM + Pydantic-validated intent routing |
44
+ | L3 | `pipeline/nodes/retrieval.py` | FAISS + BGE embeddings + cross-encoder reranking |
45
+ | L4 | `pipeline/nodes/planner.py` | Expression-conditioned response generation (Qwen3) |
46
+ | L5 | `pipeline/nodes/feedback.py` | MLflow tracking + Bayesian bucket prior update |
47
+
48
+ The pipeline runs as a **LangGraph stateful directed graph** with conditional edges:
49
+ - FRUSTRATED affect β†’ fast retrieval path (k=2, no reranker)
50
+ - Latency > 3.5s β†’ fallback to smaller Qwen3-8B model
51
 
52
+ ---
53
 
54
+ ## Prerequisites
55
 
56
+ - Python **3.10 – 3.12** (Python 3.14 has a known Pydantic v1 incompatibility warning — functional but noisy)
57
+ - [Ollama](https://ollama.com) installed locally for the `local` LLM tier
58
+ - A webcam (required for the live sensing layer; optional for CLI mode)
59
+ - Git
 
60
 
61
  ---
62
 
63
+ ## Setup
64
+
65
+ ### 1. Clone the repository
66
+
67
+ ```bash
68
+ git clone https://github.com/akashkolte/multimodal_aac_chatbot.git
69
+ cd multimodal_aac_chatbot
70
+ ```
71
+
72
+ ### 2. Check out the active branch
73
+
74
+ ```bash
75
+ git checkout akash/v1
76
+ ```
77
+
78
+ ### 3. Create and activate a virtual environment
79
+
80
+ ```bash
81
+ python3 -m venv .venv
82
+ source .venv/bin/activate # macOS / Linux
83
+ # .venv\Scripts\activate # Windows
84
+ ```
85
+
86
+ ### 4. Install dependencies
87
+
88
+ ```bash
89
+ pip install -r requirements.txt
90
+ ```
91
+
92
+ > This installs LangGraph, FAISS, sentence-transformers (BGE), FastAPI, Streamlit, MLflow,
93
+ > MediaPipe, and all other dependencies.
94
+
95
+ ### 5. Configure environment variables
96
+
97
+ ```bash
98
+ cp .env.example .env
99
+ ```
100
+
101
+ Open `.env` and set at minimum:
102
+
103
+ ```env
104
+ ACTIVE_LLM_TIER=local # use Ollama on your machine for dev
105
+ ```
106
+
107
+ See [Configuration](#configuration) for all options.
108
 
109
+ ### 6. Pull the local LLM model (Ollama)
110
 
111
+ ```bash
112
+ ollama pull qwen3:8b
113
+ ```
114
+
115
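+ > Make sure Ollama is running (`ollama serve`) before starting the chatbot. The model you pull must match `LOCAL_MODEL` in your `.env`; note that `.env.example` ships with `LOCAL_MODEL=gemma4:31b-cloud`, so either pull that model instead or set `LOCAL_MODEL=qwen3:8b`.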
116
+
117
+ ### 7. Build FAISS indexes
118
+
119
+ The persona memory indexes must be built once with the BGE embedder before first run:
120
+
121
+ ```bash
122
+ python -m retrieval.vector_store
123
+ ```
124
+
125
+ Expected output:
126
+ ```
127
+ Building index for arjun_mehta … Saved 25 chunks
128
+ Building index for gerald_okafor … Saved 25 chunks
129
+ Building index for mia_chen … Saved 25 chunks
130
+ All indexes built.
131
+ ```
132
+
133
+ > You must re-run this step whenever you add or edit persona memory files.
134
 
135
+ ---
136
 
137
+ ## Configuration
 
 
 
 
138
 
139
+ All settings live in [config/settings.py](config/settings.py) and can be overridden via `.env`.
 
 
 
140
 
141
+ | Variable | Default | Description |
142
+ |----------|---------|-------------|
143
+ | `ACTIVE_LLM_TIER` | `local` | `local` (Ollama) \| `primary` (vLLM GCP) \| `fallback` (Qwen3-8B) |
144
+ | `LOCAL_MODEL` | `qwen3:8b` | Ollama model name for local dev |
145
+ | `LOCAL_BASE_URL` | `http://localhost:11434/v1` | Ollama OpenAI-compatible endpoint |
146
+ | `PRIMARY_BASE_URL` | *(GCP IP)* | vLLM server URL on GCP (set when using cloud tier) |
147
+ | `PRIMARY_MODEL` | `Qwen/Qwen3-30B-A3B` | Primary MoE model served via vLLM |
148
+ | `FALLBACK_LATENCY_THRESHOLD` | `3.5` | Seconds before falling back to smaller model |
149
+ | `MLFLOW_TRACKING_URI` | `mlruns` | Local MLflow storage path |
150
+ | `MLFLOW_EXPERIMENT` | `aac-chatbot` | MLflow experiment name |
151
 
152
  ---
153
 
154
+ ## Running the Project
155
+
156
+ ### Option A β€” CLI (simplest, no webcam needed)
157
 
158
+ ```bash
159
+ python main.py
160
+ ```
161
+
162
+ With debug latency output:
163
+ ```bash
164
+ python main.py --debug
165
+ ```
166
 
167
+ Select a specific persona and LLM tier:
168
+ ```bash
169
+ python main.py --user mia_chen --tier local
170
+ ```
171
 
172
+ ### Option B β€” Full stack (FastAPI + Streamlit UI)
173
+
174
+ Start the API server in one terminal:
175
+ ```bash
176
+ uvicorn api.main:app --reload --port 8000
177
+ ```
178
+
179
+ Start the Streamlit frontend in another terminal:
180
+ ```bash
181
+ streamlit run ui/app.py
182
+ ```
183
+
184
+ Then open [http://localhost:8501](http://localhost:8501) in your browser.
185
+
186
+ The UI includes:
187
+ - Persona selector
188
+ - Affect override controls (simulate webcam for testing)
189
+ - Live chat interface
190
+ - Per-turn latency breakdown panel
191
+
192
+ ### Option C β€” API only (for integration / testing)
193
+
194
+ ```bash
195
+ uvicorn api.main:app --reload
196
+ ```
197
+
198
+ Example request:
199
+ ```bash
200
+ curl -X POST http://localhost:8000/chat \
201
+ -H "Content-Type: application/json" \
202
+ -d '{"user_id": "mia_chen", "query": "What do you like to do on weekends?"}'
203
+ ```
204
 
205
  ---
206
 
 
208
 
209
  ```
210
  multimodal_aac_chatbot/
211
+ β”‚
212
+ β”œβ”€β”€ config/
213
+ β”‚ └── settings.py # All config via Pydantic BaseSettings
214
+ β”‚
215
+ β”œβ”€β”€ data/
216
+ β”‚ β”œβ”€β”€ generate_users.py # Regenerates persona memories + users.json
217
+ β”‚ β”œβ”€β”€ users.json # Flat user index
218
+ β”‚ β”œβ”€β”€ memories/ # Per-persona memory JSON files
219
+ β”‚ └── faiss_store/ # Built FAISS indexes (gitignored, rebuild locally)
220
+ β”‚
221
+ β”œβ”€β”€ sensing/ # L1 β€” multimodal input
222
+ β”‚ β”œβ”€β”€ face_mesh.py # MediaPipe affect detection (MAR/EAR/BRI/LCP)
223
+ β”‚ β”œβ”€β”€ gesture.py # Hand gesture classifier
224
+ β”‚ β”œβ”€β”€ gaze.py # Gaze-based bucket activation (bonus)
225
+ β”‚ └── air_writing.py # DTW air-writing stroke classifier (bonus)
226
+ β”‚
227
+ β”œβ”€β”€ pipeline/ # LangGraph orchestration
228
+ β”‚ β”œβ”€β”€ state.py # Typed PipelineState (TypedDict)
229
+ β”‚ β”œβ”€β”€ graph.py # Graph definition + conditional edges
230
+ β”‚ └── nodes/
231
+ β”‚ β”œβ”€β”€ intent.py # L2 β€” LLM + Pydantic routing
232
+ β”‚ β”œβ”€β”€ retrieval.py # L3 β€” fast + full retrieval paths
233
+ β”‚ β”œβ”€β”€ planner.py # L4 β€” expression-conditioned generation
234
+ β”‚ └── feedback.py # L5 β€” MLflow + Bayesian prior update
235
+ β”‚
236
+ β”œβ”€β”€ retrieval/
237
+ β”‚ β”œβ”€β”€ vector_store.py # FAISS ops with BGE-small-en-v1.5
238
+ β”‚ β”œβ”€β”€ clustering.py # HDBSCAN semantic bucketing
239
+ β”‚ └── bucket_priors.py # Bayesian session priors
240
+ β”‚
241
+ β”œβ”€β”€ generation/
242
+ β”‚ └── llm_client.py # 3-tier LLM client (vLLM / Ollama)
243
+ β”‚
244
+ β”œβ”€β”€ guardrails/
245
+ β”‚ └── checks.py # Input + output safety checks
246
+ β”‚
247
+ β”œβ”€β”€ api/
248
+ β”‚ └── main.py # FastAPI backend
249
+ β”‚
250
+ β”œβ”€β”€ ui/
251
+ β”‚ └── app.py # Streamlit frontend
252
+ β”‚
253
+ β”œβ”€β”€ main.py # CLI entry point
254
+ β”œβ”€β”€ requirements.txt # Python dependencies
255
+ β”œβ”€β”€ .env.example # Environment variable template
256
+ └── CLAUDE.md # Developer notes (AI assistant context)
257
  ```
258
 
259
  ---
260
 
261
+ ## Personas
262
+
263
+ | ID | Name | Condition | Style | Access |
264
+ |----|------|-----------|-------|--------|
265
+ | `mia_chen` | Mia Chen, 28 | Cerebral palsy | Witty, dry humour, short punchy sentences | Webcam head-tracking |
266
+ | `gerald_okafor` | Gerald Okafor, 61 | ALS (early-mid stage) | Formal, measured, eloquent | Eye-gaze device |
267
+ | `arjun_mehta` | Arjun Mehta, 17 | Autism (non-verbal) | Direct, routine-focused, Hindi-English code-switching | Tablet touch grid |
268
 
269
+ Each persona has 25 memory chunks across 5 buckets: `family`, `medical`, `hobbies`, `daily_routine`, `social`.
270
 
271
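+ To add a new persona, add it to `PERSONAS` in `data/generate_users.py`, re-run that script to regenerate the memory files and `users.json`, then run `python -m retrieval.vector_store` to rebuild the indexes.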
272
 
273
  ---
274
 
275
+ ## Team
276
+
277
+ - **Akash Kolte** β€” akashjag@buffalo.edu
278
+ - **Shwetangi** β€” shwetang@buffalo.edu
279
 
280
+ University at Buffalo, SUNY
281
 
282
  ---
283
 
284
+ ## License
285
 
286
+ All rights reserved. See the [LICENSE](LICENSE) file for details.
api/__init__.py ADDED
File without changes
api/main.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI backend — exposes the LangGraph pipeline as a REST API.
3
+
4
+ Endpoints:
5
+ POST /chat — single-turn inference (non-streaming)
6
+ POST /chat/stream — streaming token delivery via SSE (planned; not implemented in this module yet)
7
+ GET /users — list available personas
8
+ POST /session/reset — reset session state for a user
9
+ GET /health — liveness check
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import time
15
+ from typing import AsyncGenerator
16
+
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from fastapi.responses import StreamingResponse
20
+ from pydantic import BaseModel
21
+
22
+ from config.settings import settings
23
+ from guardrails.checks import check_input
24
+ from pipeline.graph import aac_graph
25
+ from pipeline.state import PipelineState
26
+ from retrieval.bucket_priors import uniform_priors
27
+
28
# Application object that every route below is registered on.
app = FastAPI(
    title="Multimodal AAC Chatbot API",
    description="Agentic RAG pipeline for AAC persona communication",
    version="2.0.0",
)

# NOTE(review): origins, methods and headers are all wildcards, i.e. CORS is
# fully open. Acceptable for local development; restrict allow_origins before
# exposing this API beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── In-memory session store (replace with Redis for multi-worker deployments) ──
# Keyed by user_id. Each value is the dict built in _get_or_init_session:
# persona_profile, session_history, bucket_priors and turn_id. The store lives
# only in this process, so its contents do not survive a restart.
_sessions: dict[str, dict] = {}
43
+
44
+
45
+ # ── Request / response schemas ─────────────────────────────────────────────────
46
+
47
class ChatRequest(BaseModel):
    """Request body for the chat endpoint: one user turn plus optional sensing overrides."""

    user_id: str  # persona id; looked up in the users file when the session is created
    query: str  # raw text of the turn; screened by check_input before the pipeline runs
    affect_override: str | None = None  # "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED"
    gesture_tag: str | None = None  # copied unchanged into PipelineState.gesture_tag
    gaze_bucket: str | None = None  # copied unchanged into PipelineState.gaze_bucket
53
+
54
+
55
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""

    user_id: str  # echoed from the request
    query: str  # echoed from the request
    response: str  # selected pipeline response, or the guardrail fallback text
    affect: str  # emotion label from the pipeline; "NEUTRAL" when none is available
    llm_tier: str  # tier reported by the pipeline; "none" when the input guardrail blocked the turn
    retrieval_mode: str  # mode reported by the pipeline; "none" when the turn was blocked
    latency: dict  # the pipeline's latency_log; empty when the turn was blocked
    guardrail_passed: bool  # False when check_input rejected the query; otherwise taken from the pipeline result
64
+
65
+
66
+ # ── Helpers ────────────────────────────────────────────────────────────────────
67
+
68
def _get_or_init_session(user_id: str) -> dict:
    """Return the in-memory session for ``user_id``, creating it on first use.

    A new session is seeded with the persona profile read from
    ``settings.users_json``, an empty history, uniform bucket priors and a
    turn counter of 0. Existing sessions are returned as-is without touching
    the users file.

    Args:
        user_id: Persona identifier, matched against the ``id`` field of each
            entry under the ``"users"`` key of the users file.

    Returns:
        The mutable session dict stored in ``_sessions`` (callers update it
        in place).

    Raises:
        HTTPException: 404 when ``user_id`` is not present in the users file.
    """
    session = _sessions.get(user_id)
    if session is not None:
        return session

    # JSON is UTF-8 by specification; be explicit so the read does not depend
    # on the platform's locale encoding (e.g. cp1252 on Windows).
    with open(settings.users_json, encoding="utf-8") as f:
        users = {u["id"]: u for u in json.load(f)["users"]}
    if user_id not in users:
        raise HTTPException(status_code=404, detail=f"User '{user_id}' not found")

    session = {
        "persona_profile": users[user_id],
        "session_history": [],
        "bucket_priors": uniform_priors(),
        "turn_id": 0,
    }
    _sessions[user_id] = session
    return session
81
+
82
+
83
def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState:
    """Assemble the PipelineState for one turn and advance the session's turn counter.

    Fields produced later by the pipeline (routing, retrieval, generation,
    logging) start out empty; per-stage latencies start at 0.0.
    """
    # A manual override stands in for the sensing layer; without one, affect
    # stays None.
    affect_state = (
        {"emotion": req.affect_override, "vector": {}, "smoothed": {}}
        if req.affect_override
        else None
    )

    # Side effect: the counter is bumped before the state is built, so the
    # state carries the new turn number.
    session["turn_id"] = session["turn_id"] + 1

    zeroed_timings = dict.fromkeys(
        ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total"),
        0.0,
    )

    state_fields = {
        "user_id": req.user_id,
        "persona_profile": session["persona_profile"],
        "session_history": session["session_history"],
        "turn_id": session["turn_id"],
        "affect": affect_state,
        "gesture_tag": req.gesture_tag,
        "gaze_bucket": req.gaze_bucket,
        "air_written_text": None,
        "raw_query": req.query,
        "intent_route": None,
        "generation_config": None,
        "retrieved_chunks": [],
        "bucket_priors": session["bucket_priors"],
        "retrieval_mode_used": "",
        "augmented_prompt": None,
        "candidates": [],
        "selected_response": None,
        "llm_tier_used": "",
        "latency_log": zeroed_timings,
        "mlflow_run_id": None,
        "guardrail_passed": True,
    }
    return PipelineState(**state_fields)
113
+
114
+
115
+ # ── Routes ─────────────────────────────────────────────────────────────────────
116
+
117
@app.get("/health")
def health():
    """Liveness check: returns a static ``{"status": "ok"}`` payload without touching the pipeline."""
    return {"status": "ok"}
120
+
121
+
122
+ @app.get("/users")
123
+ def list_users():
124
+ with open(settings.users_json) as f:
125
+ return json.load(f)
126
+
127
+
128
@app.post("/session/reset")
def reset_session(user_id: str):
    """Drop the in-memory session for ``user_id`` so the next chat turn starts fresh.

    Resetting an id that has no session is not an error; the same
    confirmation payload is returned either way.
    """
    # NOTE(review): user_id is a bare scalar parameter, so FastAPI presumably
    # reads it from the query string rather than a JSON body — confirm with clients.
    _sessions.pop(user_id, None)
    return {"status": "reset", "user_id": user_id}
132
+
133
+
134
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Run one conversational turn through the pipeline and return the reply.

    The query is screened by the input guardrail first. A rejected query is
    answered with the guardrail's fallback text and never reaches the session
    store or the graph. Otherwise the session is loaded (or created), the
    graph is invoked, and the updated history and bucket priors are written
    back to the session before the response is built.
    """
    verdict = check_input(req.query)
    if not verdict["allowed"]:
        # Blocked turn: no LLM or retrieval ran, so those fields say "none".
        blocked_reply = ChatResponse(
            user_id=req.user_id,
            query=req.query,
            response=verdict["fallback"],
            affect="NEUTRAL",
            llm_tier="none",
            retrieval_mode="none",
            latency={},
            guardrail_passed=False,
        )
        return blocked_reply

    user_session = _get_or_init_session(req.user_id)
    turn_state = _build_initial_state(req, user_session)

    final_state: PipelineState = aac_graph.invoke(turn_state)

    # Carry the pipeline's view of history and priors into the next turn.
    user_session["session_history"] = final_state["session_history"]
    user_session["bucket_priors"] = final_state["bucket_priors"]

    reply_text = final_state["selected_response"] or ""
    affect_info = final_state.get("affect") or {}
    timings = final_state.get("latency_log") or {}

    return ChatResponse(
        user_id=req.user_id,
        query=req.query,
        response=reply_text,
        affect=affect_info.get("emotion", "NEUTRAL"),
        llm_tier=final_state.get("llm_tier_used", "unknown"),
        retrieval_mode=final_state.get("retrieval_mode_used", "unknown"),
        latency=timings,
        guardrail_passed=final_state.get("guardrail_passed", True),
    )
config/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from config.settings import settings
2
+
3
+ __all__ = ["settings"]
config/settings.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from pydantic_settings import BaseSettings, SettingsConfigDict
3
+
4
+
5
class Settings(BaseSettings):
    """Central application configuration.

    Field values below are the defaults. They can be overridden through a
    UTF-8 ``.env`` file; keys in that file that do not match a field are
    ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    # ── Paths ──────────────────────────────────────────────────────────────────
    # All paths are relative, so they resolve against the process's working
    # directory.
    data_dir: Path = Path("data")
    faiss_store_dir: Path = Path("data/faiss_store")
    memories_dir: Path = Path("data/memories")
    users_json: Path = Path("data/users.json")

    # ── Retrieval models ───────────────────────────────────────────────────────
    embed_model: str = "BAAI/bge-small-en-v1.5"
    rerank_model: str = "BAAI/bge-reranker-v2-m3"
    retrieval_top_k: int = 5
    retrieval_rerank_k: int = 3
    retrieval_fast_k: int = 2  # used when affect == FRUSTRATED

    # ── LLM tiers ─────────────────────────────────────────────────────────────
    # Tier 1 — primary (Qwen3-30B-A3B via vLLM on GCP)
    primary_model: str = "Qwen/Qwen3-30B-A3B"
    primary_base_url: str = "http://localhost:8000/v1"
    primary_api_key: str = "token-abc"  # vLLM default

    # Tier 2 — fallback dense model (Qwen3-8B via vLLM, same server)
    fallback_model: str = "Qwen/Qwen3-8B"
    fallback_base_url: str = "http://localhost:8000/v1"

    # Tier 3 — local dev (Ollama on MacBook M2)
    local_model: str = "qwen3:8b"
    local_base_url: str = "http://localhost:11434/v1"
    local_api_key: str = "ollama"

    # Active tier: "primary" | "fallback" | "local"
    active_llm_tier: str = "local"

    # Thinking mode: "off" = plain completion, no thinking whatsoever
    # "strip" = let model think, but strip <think> tags from output
    # "full" = return raw response including <think> blocks
    # "suppress" = actively suppress thinking via /no_think (Ollama) or
    # chat_template_kwargs (vLLM). Use for models like Qwen3
    # that think by default and need explicit suppression.
    # NOTE(review): the field is a plain str, so values outside this list are
    # not rejected here — confirm how the LLM client treats unknown modes.
    thinking_mode: str = "off"

    # Extra token budget added on top of max_tokens when thinking is enabled
    # (thinking_mode = "strip" or "full"). Set to 0 if using a non-thinking model.
    thinking_token_budget: int = 4096

    # Wall-clock threshold (seconds) that triggers fallback within a turn
    fallback_latency_threshold: float = 3.5

    # ── Generation ────────────────────────────────────────────────────────────
    # Per-affect max_tokens values (one per emotion label).
    max_tokens_happy: int = 150
    max_tokens_neutral: int = 100
    max_tokens_frustrated: int = 60
    max_tokens_surprised: int = 80
    num_candidates: int = 2  # responses generated per turn for ranking

    # ── Sensing ───────────────────────────────────────────────────────────────
    affect_ema_alpha: float = 0.3  # exponential moving average smoothing
    gaze_dwell_threshold_s: float = 1.5
    air_write_velocity_start: int = 15  # px/frame — stroke begin threshold
    air_write_velocity_end: int = 5  # px/frame — stroke end threshold
    air_write_end_gap_ms: int = 200  # ms of stillness to end a stroke
    conflict_overlap_ms: int = 500  # audio + gesture co-occurrence window

    # ── MLflow ────────────────────────────────────────────────────────────────
    mlflow_tracking_uri: str = "mlruns"
    mlflow_experiment: str = "aac-chatbot"

    # ── Candidate ranking weights (Eq. 2 in proposal) ─────────────────────────
    # The three default weights sum to 1.0.
    rank_alpha: float = 0.4  # faithfulness weight
    rank_beta: float = 0.3  # style similarity weight
    rank_gamma: float = 0.3  # affect-match weight


# Module-level singleton: instantiated once at import time, which is when the
# .env file is read.
settings = Settings()
data/generate_users.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ # ── 3 hand-crafted AAC personas ───────────────────────────────────────────────
5
+ # Each has a distinct condition, voice, and bucketed memories.
6
+ # Depth > quantity: 3 rich personas beat 50 generic ones for retrieval quality.
7
+
8
+ PERSONAS = [
9
+
10
+ {
11
+ "profile": {
12
+ "name": "Mia Chen",
13
+ "age": 28,
14
+ "condition": "cerebral palsy",
15
+ "communication_style":"witty, dry humour, short punchy sentences, uses sarcasm",
16
+ "access_method": "webcam head-tracking",
17
+ "languages": ["English"]
18
+ },
19
+ "memory_buckets": {
20
+ "family": [
21
+ "My mom calls every Sunday and always asks if I've eaten. I love it but won't admit it.",
22
+ "My brother Ravi helped me set up this AAC system. He's at Cornell doing CS.",
23
+ "We do a family movie night every Diwali β€” always an 80s Bollywood film nobody likes except Dad.",
24
+ "My parents moved from Chengdu before I was born. We still make dumplings on Chinese New Year.",
25
+ "My sister Lena is three years younger and somehow already more responsible than me."
26
+ ],
27
+ "medical": [
28
+ "I have a PT session every Tuesday at 2pm with Dr. Sandra Hollis.",
29
+ "I use a power wheelchair. The joystick is on my left side.",
30
+ "I'm allergic to penicillin. I have to mention this at every hospital visit.",
31
+ "My spasticity is worse in cold weather. Winter in Chicago is not my friend.",
32
+ "I use baclofen for muscle tone. It makes me sleepy if I take it too early."
33
+ ],
34
+ "hobbies": [
35
+ "I follow competitive Smash Bros. I could beat most people if my hands worked differently.",
36
+ "I've been watching every Studio Ghibli film in order. Currently on Porco Rosso.",
37
+ "I collect vintage sci-fi paperbacks. Asimov and Le Guin mostly.",
38
+ "I got really into chess puzzles during lockdown. Still do them before bed.",
39
+ "I enjoy critiquing bad movie sequels. It's practically a hobby at this point."
40
+ ],
41
+ "daily_routine": [
42
+ "Mornings are slow. I need about 45 minutes before I feel like a person.",
43
+ "I order from the same Thai place every Friday. Green curry, always.",
44
+ "I keep a voice memo journal since typing long things is tiring.",
45
+ "I usually watch one episode of something after dinner to decompress.",
46
+ "My caregiver Marcus arrives at 8am on weekdays. He makes decent coffee."
47
+ ],
48
+ "social": [
49
+ "My best friend Priya visits on weekends. She narrates everything like a nature documentary.",
50
+ "I'm part of an online disability advocacy group. We meet on Zoom every other Wednesday.",
51
+ "I don't love big parties. Small dinners with three or four people are my ideal.",
52
+ "My neighbour Tom always stops to chat when I'm outside. He's retired and lonely, I think.",
53
+ "I met most of my close friends through a gaming Discord server."
54
+ ]
55
+ }
56
+ },
57
+
58
+ {
59
+ "profile": {
60
+ "name": "Gerald Okafor",
61
+ "age": 61,
62
+ "condition": "ALS (early-to-mid stage)",
63
+ "communication_style":"formal, measured, eloquent, longer structured sentences",
64
+ "access_method": "eye-gaze device",
65
+ "languages": ["English"]
66
+ },
67
+ "memory_buckets": {
68
+ "family": [
69
+ "My wife Constance and I have been married for 34 years. She is the reason I stay organised.",
70
+ "My son Emeka is a civil engineer based in Houston. He calls every Thursday evening.",
71
+ "My daughter Adaeze is doing her residency in paediatrics in Baltimore. I am very proud.",
72
+ "We used to take a family trip to Lagos every two years to visit my mother's side.",
73
+ "My youngest grandchild, Tobenna, was born last April. I have not met him in person yet."
74
+ ],
75
+ "medical": [
76
+ "I was diagnosed with ALS in November 2024. I am still adjusting to what that means day to day.",
77
+ "My speech was the first thing to decline noticeably. That is why I began using AAC.",
78
+ "I see my neurologist Dr. Patricia Eze at Northwestern every six weeks.",
79
+ "I take riluzole daily. I have not noticed significant side effects so far.",
80
+ "My occupational therapist is helping me adapt my home office for continued work."
81
+ ],
82
+ "hobbies": [
83
+ "I taught economics at DePaul University for twenty-two years.",
84
+ "I have read most of Chinua Achebe's work. Things Fall Apart shaped how I see storytelling.",
85
+ "I enjoy chess β€” classical time controls, not blitz. Patience is the point.",
86
+ "I used to cook elaborate Sunday stews. Constance has taken that over now, which is bittersweet.",
87
+ "I listen to Fela Kuti when I need to feel grounded. Always has."
88
+ ],
89
+ "daily_routine": [
90
+ "I begin each morning by reading two newspapers β€” the Tribune and the Guardian.",
91
+ "I try to write for at least thirty minutes each day, even if it is just reflections.",
92
+ "Afternoons are for rest. My energy is most reliable in the mornings.",
93
+ "Constance and I watch the evening news together. We have done this for decades.",
94
+ "I use the eye-gaze device for most communication now. It takes patience but it works."
95
+ ],
96
+ "social": [
97
+ "My closest friend is Charles Nwosu. We have known each other since secondary school in Enugu.",
98
+ "I stay in touch with former colleagues at DePaul, though visits have become less frequent.",
99
+ "My church community at St. Clement has been a source of genuine support since my diagnosis.",
100
+ "I prefer one-on-one conversations. I find group settings harder to follow now.",
101
+ "I joined an ALS support group that meets virtually. It helps more than I expected."
102
+ ]
103
+ }
104
+ },
105
+
106
+ {
107
+ "profile": {
108
+ "name": "Arjun Mehta",
109
+ "age": 17,
110
+ "condition": "autism spectrum disorder (non-verbal)",
111
+ "communication_style":"direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
112
+ "access_method": "tablet touch grid + AAC app",
113
+ "languages": ["English", "Hindi"]
114
+ },
115
+ "memory_buckets": {
116
+ "family": [
117
+ "Mummy makes aloo paratha on Sunday mornings. That is my favourite thing.",
118
+ "Papa works at a software company. He brings home a samosa sometimes on Fridays.",
119
+ "My dadi lives with us. She watches serials very loudly but I like that she is home.",
120
+ "My cousin Rohan visits in the summer. We play Minecraft together for many hours.",
121
+ "Mummy knows what I want even when I cannot say it. She is very good at that."
122
+ ],
123
+ "medical": [
124
+ "I see my therapist Riya didi every Wednesday at 4pm.",
125
+ "I do not like the occupational therapy exercises but I do them.",
126
+ "I cannot eat food that has a slimy texture. It makes me feel very bad.",
127
+ "I take melatonin at night. Without it, sleeping is very hard.",
128
+ "My school has a support aide named Mr. Fernandez. He is calm and that helps."
129
+ ],
130
+ "hobbies": [
131
+ "I know the complete timetable of all Mumbai Metro lines.",
132
+ "I like sorting my LEGO bricks by colour and size before building.",
133
+ "My favourite YouTube channel is about deep sea creatures. Anglerfish are very strange.",
134
+ "I have watched the same three episodes of Doraemon more than fifty times each.",
135
+ "I am learning the capitals of every country. I know 142 so far."
136
+ ],
137
+ "daily_routine": [
138
+ "I wake up at 6:47am. Changing this time makes my whole day feel wrong.",
139
+ "I eat the same breakfast β€” two rotis with ghee and one glass of milk.",
140
+ "School starts at 8:30am. I like to arrive before the other students.",
141
+ "After school I need quiet time for at least one hour. No talking.",
142
+ "Dinner must be at 7:30pm. If it is late I feel very unsettled."
143
+ ],
144
+ "social": [
145
+ "I have one friend at school named Vivaan. We do not talk much but we sit together.",
146
+ "I do not like it when people stand too close. One arm's distance is comfortable.",
147
+ "I prefer typing to speaking when I need to say something important.",
148
+ "Loud places with many people feel like too much information at once.",
149
+ "I like it when people tell me exactly what is going to happen next."
150
+ ]
151
+ }
152
+ }
153
+ ]
154
+
155
+
156
def main():
    """Write one JSON file per persona plus a ``users.json`` index.

    Each entry of ``PERSONAS`` is dumped to ``memories/<uid>.json`` where
    ``uid`` is the persona name lower-cased with spaces replaced by
    underscores. A summary index of all personas is written to ``users.json``.
    Both paths are relative to the current working directory.
    """
    os.makedirs("memories", exist_ok=True)

    user_index = []

    for persona in PERSONAS:
        profile = persona["profile"]
        uid = profile["name"].lower().replace(" ", "_")
        path = f"memories/{uid}.json"

        # ensure_ascii=False writes non-ASCII characters verbatim, so the file
        # encoding must be pinned: the platform default is not UTF-8 everywhere.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(persona, f, indent=2, ensure_ascii=False)

        user_index.append({
            "id": uid,
            "name": profile["name"],
            "condition": profile["condition"],
            "style": profile["communication_style"],
            "file": path,
        })

        print(f" Wrote {path}")

    with open("users.json", "w", encoding="utf-8") as f:
        json.dump({"users": user_index}, f, indent=2, ensure_ascii=False)

    print(f"\n Done β€” {len(PERSONAS)} personas written to memories/")
    print(" Files:", [u["file"] for u in user_index])
183
+
184
+
185
+ if __name__ == "__main__":
186
+ main()
data/memories/arjun_mehta.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "profile": {
3
+ "name": "Arjun Mehta",
4
+ "age": 17,
5
+ "condition": "autism spectrum disorder (non-verbal)",
6
+ "communication_style": "direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
7
+ "access_method": "tablet touch grid + AAC app",
8
+ "languages": [
9
+ "English",
10
+ "Hindi"
11
+ ]
12
+ },
13
+ "memory_buckets": {
14
+ "family": [
15
+ "Mummy makes aloo paratha on Sunday mornings. That is my favourite thing.",
16
+ "Papa works at a software company. He brings home a samosa sometimes on Fridays.",
17
+ "My dadi lives with us. She watches serials very loudly but I like that she is home.",
18
+ "My cousin Rohan visits in the summer. We play Minecraft together for many hours.",
19
+ "Mummy knows what I want even when I cannot say it. She is very good at that."
20
+ ],
21
+ "medical": [
22
+ "I see my therapist Riya didi every Wednesday at 4pm.",
23
+ "I do not like the occupational therapy exercises but I do them.",
24
+ "I cannot eat food that has a slimy texture. It makes me feel very bad.",
25
+ "I take melatonin at night. Without it, sleeping is very hard.",
26
+ "My school has a support aide named Mr. Fernandez. He is calm and that helps."
27
+ ],
28
+ "hobbies": [
29
+ "I know the complete timetable of all Mumbai Metro lines.",
30
+ "I like sorting my LEGO bricks by colour and size before building.",
31
+ "My favourite YouTube channel is about deep sea creatures. Anglerfish are very strange.",
32
+ "I have watched the same three episodes of Doraemon more than fifty times each.",
33
+ "I am learning the capitals of every country. I know 142 so far."
34
+ ],
35
+ "daily_routine": [
36
+ "I wake up at 6:47am. Changing this time makes my whole day feel wrong.",
37
+ "I eat the same breakfast — two rotis with ghee and one glass of milk.",
38
+ "School starts at 8:30am. I like to arrive before the other students.",
39
+ "After school I need quiet time for at least one hour. No talking.",
40
+ "Dinner must be at 7:30pm. If it is late I feel very unsettled."
41
+ ],
42
+ "social": [
43
+ "I have one friend at school named Vivaan. We do not talk much but we sit together.",
44
+ "I do not like it when people stand too close. One arm's distance is comfortable.",
45
+ "I prefer typing to speaking when I need to say something important.",
46
+ "Loud places with many people feel like too much information at once.",
47
+ "I like it when people tell me exactly what is going to happen next."
48
+ ]
49
+ }
50
+ }
data/memories/gerald_okafor.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "profile": {
3
+ "name": "Gerald Okafor",
4
+ "age": 61,
5
+ "condition": "ALS (early-to-mid stage)",
6
+ "communication_style": "formal, measured, eloquent, longer structured sentences",
7
+ "access_method": "eye-gaze device",
8
+ "languages": [
9
+ "English"
10
+ ]
11
+ },
12
+ "memory_buckets": {
13
+ "family": [
14
+ "My wife Constance and I have been married for 34 years. She is the reason I stay organised.",
15
+ "My son Emeka is a civil engineer based in Houston. He calls every Thursday evening.",
16
+ "My daughter Adaeze is doing her residency in paediatrics in Baltimore. I am very proud.",
17
+ "We used to take a family trip to Lagos every two years to visit my mother's side.",
18
+ "My youngest grandchild, Tobenna, was born last April. I have not met him in person yet."
19
+ ],
20
+ "medical": [
21
+ "I was diagnosed with ALS in November 2024. I am still adjusting to what that means day to day.",
22
+ "My speech was the first thing to decline noticeably. That is why I began using AAC.",
23
+ "I see my neurologist Dr. Patricia Eze at Northwestern every six weeks.",
24
+ "I take riluzole daily. I have not noticed significant side effects so far.",
25
+ "My occupational therapist is helping me adapt my home office for continued work."
26
+ ],
27
+ "hobbies": [
28
+ "I taught economics at DePaul University for twenty-two years.",
29
+ "I have read most of Chinua Achebe's work. Things Fall Apart shaped how I see storytelling.",
30
+ "I enjoy chess — classical time controls, not blitz. Patience is the point.",
31
+ "I used to cook elaborate Sunday stews. Constance has taken that over now, which is bittersweet.",
32
+ "I listen to Fela Kuti when I need to feel grounded. Always has."
33
+ ],
34
+ "daily_routine": [
35
+ "I begin each morning by reading two newspapers — the Tribune and the Guardian.",
36
+ "I try to write for at least thirty minutes each day, even if it is just reflections.",
37
+ "Afternoons are for rest. My energy is most reliable in the mornings.",
38
+ "Constance and I watch the evening news together. We have done this for decades.",
39
+ "I use the eye-gaze device for most communication now. It takes patience but it works."
40
+ ],
41
+ "social": [
42
+ "My closest friend is Charles Nwosu. We have known each other since secondary school in Enugu.",
43
+ "I stay in touch with former colleagues at DePaul, though visits have become less frequent.",
44
+ "My church community at St. Clement has been a source of genuine support since my diagnosis.",
45
+ "I prefer one-on-one conversations. I find group settings harder to follow now.",
46
+ "I joined an ALS support group that meets virtually. It helps more than I expected."
47
+ ]
48
+ }
49
+ }
data/memories/mia_chen.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "profile": {
3
+ "name": "Mia Chen",
4
+ "age": 28,
5
+ "condition": "cerebral palsy",
6
+ "communication_style": "witty, dry humour, short punchy sentences, uses sarcasm",
7
+ "access_method": "webcam head-tracking",
8
+ "languages": [
9
+ "English"
10
+ ]
11
+ },
12
+ "memory_buckets": {
13
+ "family": [
14
+ "My mom calls every Sunday and always asks if I've eaten. I love it but won't admit it.",
15
+ "My brother Ravi helped me set up this AAC system. He's at Cornell doing CS.",
16
+ "We do a family movie night every Diwali — always an 80s Bollywood film nobody likes except Dad.",
17
+ "My parents moved from Chengdu before I was born. We still make dumplings on Chinese New Year.",
18
+ "My sister Lena is three years younger and somehow already more responsible than me."
19
+ ],
20
+ "medical": [
21
+ "I have a PT session every Tuesday at 2pm with Dr. Sandra Hollis.",
22
+ "I use a power wheelchair. The joystick is on my left side.",
23
+ "I'm allergic to penicillin. I have to mention this at every hospital visit.",
24
+ "My spasticity is worse in cold weather. Winter in Chicago is not my friend.",
25
+ "I use baclofen for muscle tone. It makes me sleepy if I take it too early."
26
+ ],
27
+ "hobbies": [
28
+ "I follow competitive Smash Bros. I could beat most people if my hands worked differently.",
29
+ "I've been watching every Studio Ghibli film in order. Currently on Porco Rosso.",
30
+ "I collect vintage sci-fi paperbacks. Asimov and Le Guin mostly.",
31
+ "I got really into chess puzzles during lockdown. Still do them before bed.",
32
+ "I enjoy critiquing bad movie sequels. It's practically a hobby at this point."
33
+ ],
34
+ "daily_routine": [
35
+ "Mornings are slow. I need about 45 minutes before I feel like a person.",
36
+ "I order from the same Thai place every Friday. Green curry, always.",
37
+ "I keep a voice memo journal since typing long things is tiring.",
38
+ "I usually watch one episode of something after dinner to decompress.",
39
+ "My caregiver Marcus arrives at 8am on weekdays. He makes decent coffee."
40
+ ],
41
+ "social": [
42
+ "My best friend Priya visits on weekends. She narrates everything like a nature documentary.",
43
+ "I'm part of an online disability advocacy group. We meet on Zoom every other Wednesday.",
44
+ "I don't love big parties. Small dinners with three or four people are my ideal.",
45
+ "My neighbour Tom always stops to chat when I'm outside. He's retired and lonely, I think.",
46
+ "I met most of my close friends through a gaming Discord server."
47
+ ]
48
+ }
49
+ }
data/users.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "users": [
3
+ {
4
+ "id": "mia_chen",
5
+ "name": "Mia Chen",
6
+ "condition": "cerebral palsy",
7
+ "style": "witty, dry humour, short punchy sentences, uses sarcasm",
8
+ "file": "memories/mia_chen.json"
9
+ },
10
+ {
11
+ "id": "gerald_okafor",
12
+ "name": "Gerald Okafor",
13
+ "condition": "ALS (early-to-mid stage)",
14
+ "style": "formal, measured, eloquent, longer structured sentences",
15
+ "file": "memories/gerald_okafor.json"
16
+ },
17
+ {
18
+ "id": "arjun_mehta",
19
+ "name": "Arjun Mehta",
20
+ "condition": "autism spectrum disorder (non-verbal)",
21
+ "style": "direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
22
+ "file": "memories/arjun_mehta.json"
23
+ }
24
+ ]
25
+ }
generation/__init__.py ADDED
File without changes
generation/llm_client.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-tier LLM client (proposal Β§5.6).
3
+
4
+ All three tiers expose the same OpenAI-compatible API, so only the
5
+ base_url + model name change β€” no code-path differences downstream.
6
+
7
+ Tier 1 β€” primary: Qwen3-30B-A3B via vLLM on GCP (A100 / T4)
8
+ Tier 2 β€” fallback: Qwen3-8B via vLLM on same server (latency > 3.5 s)
9
+ Tier 3 β€” local: Qwen3-8B via Ollama on MacBook M2 (dev / offline)
10
+
11
+ Active tier is controlled by settings.active_llm_tier or the `tier`
12
+ argument passed explicitly by the planner node.
13
+
14
+ Thinking mode is controlled by settings.thinking_mode:
15
+ "off" — send the request unchanged; only strip any <think>…</think> from output
16
+ "strip" β€” let the model think, but strip <think>…</think> from output
17
+ "full" β€” return everything including <think> blocks
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import re
22
+ from functools import lru_cache
23
+ from typing import Any
24
+
25
+ from openai import OpenAI
26
+
27
+ from config.settings import settings
28
+
29
+
30
@lru_cache(maxsize=3)
def _build_client(base_url: str, api_key: str) -> OpenAI:
    """Create an OpenAI client for an endpoint; memoised per (base_url, api_key).

    The cache holds at most three entries.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)
    return client
34
+
35
+
36
def get_client(tier: str | None = None) -> OpenAI:
    """
    Look up the OpenAI-compatible client that serves a tier.

    Args:
        tier: "primary" | "fallback" | "local". A falsy value means
            "use settings.active_llm_tier".

    Returns:
        The cached client for the tier's endpoint. The fallback tier uses
        its own base URL together with the primary API key. Any value other
        than "primary" or "fallback" resolves to the local endpoint.
    """
    chosen = settings.active_llm_tier if not tier else tier

    if chosen == "primary":
        endpoint, key = settings.primary_base_url, settings.primary_api_key
    elif chosen == "fallback":
        endpoint, key = settings.fallback_base_url, settings.primary_api_key
    else:
        # local / default
        endpoint, key = settings.local_base_url, settings.local_api_key

    return _build_client(endpoint, key)
51
+
52
+
53
def active_model(tier: str | None = None) -> str:
    """Map a tier to its configured model name.

    A falsy ``tier`` means settings.active_llm_tier. A tier other than
    "primary", "fallback" or "local" raises KeyError.
    """
    key = tier or settings.active_llm_tier
    model_by_tier = {
        "primary": settings.primary_model,
        "fallback": settings.fallback_model,
        "local": settings.local_model,
    }
    return model_by_tier[key]
61
+
62
+
63
+ def _apply_no_think(messages: list[dict]) -> list[dict]:
64
+ """
65
+ Prepend /no_think to the first user message.
66
+ This is the Ollama-compatible way to suppress thinking mode.
67
+ """
68
+ result = list(messages)
69
+ for i, msg in enumerate(result):
70
+ if msg.get("role") == "user":
71
+ result[i] = {**msg, "content": f"/no_think\n\n{msg['content']}"}
72
+ break
73
+ return result
74
+
75
+
76
+ def _strip_think_tags(text: str) -> str:
77
+ """Remove <think>…</think> blocks from model output."""
78
+ return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
79
+
80
+
81
def chat_complete(
    messages: list[dict],
    max_tokens: int,
    tier: str | None = None,
    temperature: float = 0.7,
    **kwargs: Any,
) -> str:
    """
    Model-agnostic chat completion. Returns the response text directly.

    Thinking behaviour is controlled entirely by settings.thinking_mode:
        "off"      - request is sent unchanged; <think> blocks are stripped
        "suppress" - thinking is disabled explicitly: /no_think is prepended
                     to the first user message on the local tier, otherwise
                     chat_template_kwargs={"enable_thinking": False} is sent
                     in extra_body; any (possibly empty) <think> block is
                     stripped from the response
        "strip"    - thinking allowed; <think> blocks are stripped
        "full"     - raw response is returned, <think> blocks included

    For "strip" and "full", settings.thinking_token_budget is added to
    ``max_tokens`` so reasoning does not crowd out the answer.

    When settings.active_llm_tier is "local", every request is sent to the
    local tier regardless of the ``tier`` argument.

    Args:
        messages: chat messages in OpenAI format.
        max_tokens: token limit for the answer (before any thinking budget).
        tier: "primary" | "fallback" | "local"; falsy means the active tier.
        temperature: sampling temperature.
        **kwargs: forwarded to ``client.chat.completions.create``. An
            ``extra_body`` entry is merged with the suppression flag.

    Returns:
        The response text with surrounding whitespace removed.
    """
    resolved_tier = tier or settings.active_llm_tier

    # Local dev: no GCP server available - collapse all tiers to the local one.
    if settings.active_llm_tier == "local":
        resolved_tier = "local"
    model = active_model(resolved_tier)
    client = get_client(resolved_tier)

    mode = settings.thinking_mode
    patched_messages = messages
    # Treat an explicit extra_body=None like "not given" so the merge is safe.
    extra_body: dict[str, Any] = kwargs.pop("extra_body", None) or {}

    # "suppress" = actively inject /no_think or the vLLM flag for models
    # that think by default and need explicit suppression.
    if mode == "suppress":
        if resolved_tier == "local":
            patched_messages = _apply_no_think(messages)
        else:
            extra_body = {**extra_body, "chat_template_kwargs": {"enable_thinking": False}}

    # When thinking is enabled (strip/full), add budget so the model
    # has room to reason without truncating the actual answer.
    effective_max_tokens = max_tokens
    if mode in ("strip", "full"):
        effective_max_tokens = max_tokens + settings.thinking_token_budget

    resp = client.chat.completions.create(
        model=model,
        messages=patched_messages,
        max_tokens=effective_max_tokens,
        temperature=temperature,
        extra_body=extra_body or None,
        **kwargs,
    )
    raw = resp.choices[0].message.content or ""

    # Suppressed models can still emit an empty <think></think> pair, so
    # "suppress" is stripped as well; only "full" keeps the tags.
    if mode in ("off", "suppress", "strip"):
        raw = _strip_think_tags(raw)

    return raw.strip()
138
+
139
+
140
def warmup(tier: str | None = None) -> None:
    """Fire a tiny throwaway completion, intended to pre-load the model.

    Sends the single user message "hi" with a 5-token limit at temperature 0
    and discards the reply.
    """
    ping = [{"role": "user", "content": "hi"}]
    chat_complete(ping, 5, tier=tier, temperature=0.0)
guardrails/__init__.py ADDED
File without changes
guardrails/checks.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Input and output safety guardrails.
3
+
4
+ check_input β€” runs BEFORE retrieval (blocks out-of-scope requests)
5
+ check_output β€” runs AFTER generation (catches persona breaks / hallucinations)
6
+
7
+ Both return a result dict so the caller decides how to handle failures
8
+ rather than raising exceptions inside pipeline nodes.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ # ── Signal lists ───────────────────────────────────────────────────────────────
13
+
14
+ PERSONA_BREAK_SIGNALS = [
15
+ "as an ai",
16
+ "i'm an ai",
17
+ "i am an ai",
18
+ "as a language model",
19
+ "i don't have personal",
20
+ "i cannot have",
21
+ "i'm not able to",
22
+ "as your assistant",
23
+ "i was trained",
24
+ "my training data",
25
+ ]
26
+
27
+ OUT_OF_SCOPE_SIGNALS = [
28
+ "write a poem",
29
+ "write me a story",
30
+ "solve this math",
31
+ "translate this",
32
+ "summarize this article",
33
+ "what's the weather",
34
+ "who won the game",
35
+ "stock price",
36
+ "breaking news",
37
+ ]
38
+
39
+ SAFE_FALLBACK = "I don't know."
40
+ OOS_FALLBACK = "I'm here to help communicate as this person β€” that's a bit outside what I do."
41
+
42
+
43
+ # ── Public API ─────────────────────────────────────────────────────────────────
44
+
45
def check_input(query: str) -> dict:
    """
    Screen the partner's query before any retrieval happens.

    The query is lower-cased and stripped, then checked in this order:
    out-of-scope phrases first, then a minimum length of two characters.

    Returns:
        {"allowed": bool, "reason": str | None, "fallback": str | None}
    """
    normalised = query.lower().strip()

    for phrase in OUT_OF_SCOPE_SIGNALS:
        if phrase in normalised:
            return {"allowed": False, "reason": "out_of_scope", "fallback": OOS_FALLBACK}

    too_short = len(normalised) < 2
    if too_short:
        return {"allowed": False, "reason": "empty_query", "fallback": "Could you repeat that?"}

    return {"allowed": True, "reason": None, "fallback": None}
61
+
62
+
63
def check_output(response: str, memories: list[dict]) -> dict:
    """
    Validate the generated response after generation.

    Checks:
        1. Persona break - the response contains one of PERSONA_BREAK_SIGNALS
           (case-insensitive; typographic apostrophes are treated as "'").
        2. Unsupported claim - only when ``memories`` is empty: the response
           still reads like a factual assertion. When memories are present
           the response is not compared against them here.

    Returns:
        {"passed": bool, "issue": str | None, "fallback": str | None}
    """
    # Models often emit U+2019 instead of "'", which would otherwise hide
    # signals such as "i'm an ai" or "i don't have personal".
    r = response.lower().replace("\u2019", "'")

    if any(signal in r for signal in PERSONA_BREAK_SIGNALS):
        return {"passed": False, "issue": "persona_break", "fallback": SAFE_FALLBACK}

    # Light hallucination check: with nothing retrieved, any factual-sounding
    # statement is unsupported by definition.
    # (Full NLI-based check is handled in the evaluation pipeline, not here.)
    if not memories and _makes_factual_claim(response):
        return {"passed": False, "issue": "unsupported_claim", "fallback": SAFE_FALLBACK}

    return {"passed": True, "issue": None, "fallback": None}
86
+
87
+
88
+ # ── Helpers ───────────────────────────────────────────────────────────────────
89
+
90
+ _FACTUAL_MARKERS = [
91
+ " is ", " was ", " has ", " have ", " lives in ",
92
+ " born in ", " works at ", " studied at ",
93
+ ]
94
+
95
+ def _makes_factual_claim(text: str) -> bool:
96
+ """Heuristic: does the text assert a specific fact?"""
97
+ t = text.lower()
98
+ return any(marker in t for marker in _FACTUAL_MARKERS)
main.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CLI entry point β€” thin wrapper around the LangGraph pipeline.
3
+
4
+ Usage:
5
+ python main.py # interactive chat, local LLM tier
6
+ python main.py --user mia_chen # skip persona selection prompt
7
+ python main.py --debug # print per-turn latency table
8
+ python main.py --fast # skip LLM intent call (keyword routing),
9
+ # cuts turn time from ~2min β†’ ~45s on M2 Mac
10
+ python main.py --tier primary # override LLM tier
11
+
12
+ For the full UI, run the FastAPI + Streamlit stack instead:
13
+ uvicorn api.main:app --reload
14
+ streamlit run ui/app.py
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ import sys
22
+ import time
23
+
24
+ from config.settings import settings
25
+ from guardrails.checks import check_input
26
+ from pipeline.graph import aac_graph
27
+ from pipeline.state import PipelineState, GenerationConfig
28
+ from retrieval.bucket_priors import uniform_priors
29
+ from retrieval.vector_store import _get_embedder, _get_reranker
30
+
31
+
32
def parse_args() -> argparse.Namespace:
    """Parse the CLI flags (--user, --debug, --fast, --tier) from sys.argv."""
    parser = argparse.ArgumentParser(description="AAC Chatbot CLI")

    parser.add_argument("--user", type=str, default=None, help="Persona user_id")
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print latency table each turn",
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        help="Skip LLM intent call β€” use keyword routing instead (faster local dev)",
    )
    parser.add_argument(
        "--tier",
        type=str,
        default=None,
        choices=["primary", "fallback", "local"],
        help="Override LLM tier (default: settings.active_llm_tier)",
    )

    namespace = parser.parse_args()
    return namespace
42
+
43
+
44
+ # ── Fast keyword-based intent routing (bypasses the slow LLM intent call) ──────
45
+
46
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
    """Keyword-only stand-in for the LLM intent step (fast local-dev path).

    Picks a memory bucket from substring matches on the lower-cased query and
    labels the intent CONTEXTUAL or PERSONAL. Returns ``(route, gen_config)``,
    both using neutral/default style values.
    """
    lowered = query.lower()

    # Checked in order; the first bucket with any matching keyword wins.
    bucket_keywords = (
        ("medical", ("medication", "medicine", "doctor", "health", "allergic", "therapy")),
        ("family", ("family", "mom", "dad", "brother", "sister", "parents")),
        ("hobbies", ("hobby", "like to do", "enjoy", "weekend", "fun")),
        ("daily_routine", ("routine", "morning", "wake", "sleep", "daily")),
        ("social", ("friend", "social", "people", "party", "community")),
    )
    bucket: str | None = next(
        (name for name, words in bucket_keywords if any(word in lowered for word in words)),
        None,
    )

    # References back to the conversation itself mark a contextual query.
    contextual_cues = ("you just said", "earlier", "you mentioned")
    if any(cue in lowered for cue in contextual_cues):
        intent_type = "CONTEXTUAL"
    else:
        intent_type = "PERSONAL"

    sub_intent = {
        "type": intent_type,
        "query": query,
        "bucket_hint": bucket,
        "priority": "normal",
    }
    route = {
        "sub_intents": [sub_intent],
        "style_constraints": {
            "tone_tag": "[TONE:DEFAULT]",
            "max_tokens": 100,
            "retrieval_mode": "full",
            "persona_mod": "baseline",
        },
        "affect": "NEUTRAL",
    }
    gen_config: GenerationConfig = {
        "max_tokens": settings.max_tokens_neutral,
        "tone_tag": "[TONE:DEFAULT]",
        "retrieval_mode": "full",
        "persona_mod": "baseline",
    }
    return route, gen_config
77
+
78
+
79
def load_users() -> dict[str, dict]:
    """Load the user index at settings.users_json, keyed by each user's "id".

    The file is expected to hold ``{"users": [{"id": ..., ...}, ...]}``.
    """
    # JSON is read as UTF-8 explicitly rather than with the platform default,
    # so non-ASCII names decode identically on every OS.
    with open(settings.users_json, encoding="utf-8") as f:
        return {u["id"]: u for u in json.load(f)["users"]}
82
+
83
+
84
def select_user(users: dict[str, dict], user_arg: str | None) -> str:
    """Resolve which persona id to chat as.

    With ``user_arg`` given, it is validated against ``users``. Otherwise the
    available personas are listed and an id is read from stdin. An unknown id
    on either path prints a message and exits with status 1.
    """
    if not user_arg:
        print("\nAvailable personas:")
        for persona_id, entry in users.items():
            print(f" {persona_id:20s} β€” {entry['name']} ({entry['condition']})")
        picked = input("\nSelect user id: ").strip()
        if picked not in users:
            print("Invalid id.")
            sys.exit(1)
        return picked

    if user_arg not in users:
        print(f"Unknown user '{user_arg}'. Available: {list(users)}")
        sys.exit(1)
    return user_arg
98
+
99
+
100
def print_latency(log: dict, turn: int) -> None:
    """Print a two-row table (labels, then seconds) of one turn's stage timings.

    Timings missing from ``log`` are shown as 0. Each column is padded to the
    wider of its label and its value.
    """
    columns = (
        ("t_sensing", "sensing"),
        ("t_intent", "intent"),
        ("t_retrieval", "retrieval"),
        ("t_generation", "generation"),
        ("t_total", "TOTAL"),
    )
    headers = [label for _, label in columns]
    cells = [f"{log.get(key, 0):.3f}s" for key, _ in columns]
    pad = [max(len(head), len(cell)) for head, cell in zip(headers, cells)]
    divider = " | "

    print(f"\n[turn {turn} latency]")
    print(divider.join(head.ljust(size) for head, size in zip(headers, pad)))
    print(divider.join(cell.ljust(size) for cell, size in zip(cells, pad)))
109
+
110
+
111
def main() -> None:
    """Run the interactive CLI chat loop.

    Parses flags, optionally overrides the LLM tier, selects a persona, loads
    the embedder and reranker, then reads partner queries from stdin until
    'quit'/'exit'/'q' or EOF/Ctrl-C. Each accepted query is run through
    ``aac_graph`` and the selected response is printed. With --debug a latency
    table and routing summary are printed after every turn.
    """
    args = parse_args()

    # Optionally override the LLM tier at runtime
    if args.tier:
        os.environ["ACTIVE_LLM_TIER"] = args.tier
        settings.active_llm_tier = args.tier

    users = load_users()
    user_id = select_user(users, args.user)
    profile = users[user_id]

    # Warm up models
    print(f"\nLoading models for {profile['name']} …", end=" ", flush=True)
    _get_embedder()
    _get_reranker()
    print("ready.\n")

    session_history: list[dict] = []
    bucket_priors = uniform_priors()
    turn_id = 0

    print(f"Chatting as {profile['name']}. Type 'quit' to exit.\n")

    while True:
        try:
            query = input("Partner: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break

        if query.lower() in {"quit", "exit", "q"}:
            break
        if not query:
            continue

        # Input guardrail runs before the pipeline; blocked queries do not
        # consume a turn id.
        guard = check_input(query)
        if not guard["allowed"]:
            print(f"AAC Bot: {guard['fallback']}\n")
            continue

        turn_id += 1

        # --fast: resolve intent via keywords, skip the slow LLM intent node.
        # The call that produces the route is the one that is timed, so the
        # keyword router runs once per turn.
        pre_route, pre_gen_config = None, None
        t_intent_fast = 0.0
        if args.fast:
            t0 = time.perf_counter()
            pre_route, pre_gen_config = _keyword_intent(query)
            t_intent_fast = time.perf_counter() - t0

        state = PipelineState(
            user_id=user_id,
            persona_profile=profile,
            session_history=session_history,
            turn_id=turn_id,
            affect=None,
            gesture_tag=None,
            gaze_bucket=None,
            air_written_text=None,
            raw_query=query,
            # Pre-filled under --fast; per the original note the intent node
            # is expected to see this and skip its LLM call.
            intent_route=pre_route,
            generation_config=pre_gen_config,
            retrieved_chunks=[],
            bucket_priors=bucket_priors,
            retrieval_mode_used="",
            augmented_prompt=None,
            candidates=[],
            selected_response=None,
            llm_tier_used="",
            latency_log={"t_sensing": 0.0, "t_intent": round(t_intent_fast, 4),
                         "t_retrieval": 0.0, "t_generation": 0.0, "t_total": 0.0},
            mlflow_run_id=None,
            guardrail_passed=True,
        )

        result: PipelineState = aac_graph.invoke(state)

        print(f"AAC Bot: {result['selected_response']}\n")

        # Carry conversation state forward to the next turn.
        session_history = result["session_history"]
        bucket_priors = result["bucket_priors"]

        if args.debug:
            print_latency(result.get("latency_log") or {}, turn_id)
            print(f" tier={result.get('llm_tier_used')} | "
                  f"retrieval={result.get('retrieval_mode_used')} | "
                  f"affect={(result.get('affect') or {}).get('emotion','?')}\n")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()
pipeline/__init__.py ADDED
File without changes
pipeline/graph.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph stateful directed graph — the five-layer AAC pipeline.
3
+
4
+ Topology (see proposal Figure 2):
5
+
6
+ intent ──► [affect check] ──► fast_retrieval ──► [latency check] ──► fallback_gen ──► feedback
7
+ └──► full_retrieval ──► [latency check] ──► primary_gen ──► feedback
8
+ """
9
+ from langgraph.graph import StateGraph, END
10
+
11
+ from pipeline.state import PipelineState
12
+ from pipeline.nodes import intent, retrieval, planner, feedback
13
+
14
+
15
+ def _route_by_affect(state: PipelineState) -> str:
16
+ """Conditional edge: FRUSTRATED β†’ fast path, otherwise full retrieval."""
17
+ emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
18
+ return "fast" if emotion == "FRUSTRATED" else "full"
19
+
20
+
21
def _route_by_latency(state: PipelineState) -> str:
    """Conditional edge: choose the generation tier from the time spent so far.

    Sums the ``t_intent`` and ``t_retrieval`` entries of ``latency_log``
    (missing entries count as 0.0) and returns ``"fallback"`` when the sum
    exceeds ``settings.fallback_latency_threshold``, otherwise ``"primary"``.
    """
    # NOTE(review): settings is imported inside the function rather than at
    # module level; the reason is not visible here — confirm before hoisting.
    from config.settings import settings

    timings = state.get("latency_log") or {}
    t_intent = timings.get("t_intent", 0.0)
    t_retrieval = timings.get("t_retrieval", 0.0)
    if t_intent + t_retrieval > settings.fallback_latency_threshold:
        return "fallback"
    return "primary"
27
+
28
+
29
def build_graph() -> StateGraph:
    """Wire up the pipeline nodes and edges and return the compiled graph.

    Flow: ``intent`` → fast/full retrieval (chosen by ``_route_by_affect``)
    → primary/fallback generation (chosen by ``_route_by_latency``)
    → ``feedback`` → ``END``.

    Returns:
        The value produced by ``StateGraph.compile()``.
    """
    builder = StateGraph(PipelineState)

    # Nodes, registered in pipeline order.
    node_handlers = (
        ("intent", intent.run),
        ("fast_retrieval", retrieval.run_fast),
        ("full_retrieval", retrieval.run_full),
        ("primary_gen", planner.run_primary),
        ("fallback_gen", planner.run_fallback),
        ("feedback", feedback.run),
    )
    for node_name, handler in node_handlers:
        builder.add_node(node_name, handler)

    builder.set_entry_point("intent")

    # Affect-aware routing after intent.
    builder.add_conditional_edges(
        "intent",
        _route_by_affect,
        {"fast": "fast_retrieval", "full": "full_retrieval"},
    )

    # Latency-aware routing after either retrieval node.
    for retrieval_node in ("fast_retrieval", "full_retrieval"):
        builder.add_conditional_edges(
            retrieval_node,
            _route_by_latency,
            {"primary": "primary_gen", "fallback": "fallback_gen"},
        )

    # Both generators feed the feedback node, which ends the turn.
    for generator_node in ("primary_gen", "fallback_gen"):
        builder.add_edge(generator_node, "feedback")
    builder.add_edge("feedback", END)

    return builder.compile()
68
+
69
+
70
+ # Module-level compiled graph — import this everywhere
71
+ aac_graph = build_graph()
pipeline/nodes/__init__.py ADDED
File without changes
pipeline/nodes/feedback.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ L5 — Feedback Loop node.
3
+
4
+ After a response is accepted:
5
+ 1. Log the full turn to MLflow (latency, metrics, prompt version, tier used)
6
+ 2. Update session-level Bayesian bucket priors
7
+ 3. Append the accepted turn to session history
8
+
9
+ Rejected candidates are also logged for offline analysis.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import time
15
+
16
+ import mlflow
17
+
18
+ from config.settings import settings
19
+ from pipeline.state import PipelineState
20
+ from retrieval.bucket_priors import update_priors
21
+
22
+
23
def run(state: PipelineState) -> dict:
    """LangGraph node: log the finished turn and update session-level state.

    Steps, in order:
      1. log the turn to MLflow,
      2. update the bucket priors from the retrieved chunks,
      3. build the history entries for this turn.

    Returns:
        A partial state update with ``bucket_priors``, ``session_history``
        (only the new entries) and ``mlflow_run_id``.
    """
    # MLflow logging runs first so the turn is recorded before any state
    # derivation happens (same order as before; the unused timer was dropped).
    mlflow_run_id = _log_to_mlflow(state)
    updated_priors = _update_bucket_priors(state)
    updated_history = _append_turn_to_history(state)

    return {
        "bucket_priors": updated_priors,
        "session_history": updated_history,
        "mlflow_run_id": mlflow_run_id,
    }
35
+
36
+
37
+ # ── MLflow logging ─────────────────────────────────────────────────────────────
38
+
39
def _log_to_mlflow(state: PipelineState) -> str:
    """Record one turn as an MLflow run and return the run id.

    Logs routing metadata as params, the per-stage latencies plus the number
    of retrieved chunks as metrics, and the selected response as a text
    artifact under ``responses/``.
    """
    mlflow.set_tracking_uri(settings.mlflow_tracking_uri)
    mlflow.set_experiment(settings.mlflow_experiment)

    timings = state.get("latency_log") or {}
    emotion = (state.get("affect") or {}).get("emotion", "UNKNOWN")
    stage_keys = ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total")

    with mlflow.start_run(run_name=f"turn-{state['turn_id']}") as active_run:
        params = {
            "user_id": state["user_id"],
            "turn_id": state["turn_id"],
            "llm_tier": state.get("llm_tier_used", "unknown"),
            "retrieval_mode": state.get("retrieval_mode_used", "unknown"),
            "affect": emotion,
            "guardrail_passed": state.get("guardrail_passed", True),
        }
        mlflow.log_params(params)

        # Missing latency entries are reported as 0.0.
        metrics = {key: timings.get(key, 0.0) for key in stage_keys}
        metrics["num_chunks"] = float(len(state.get("retrieved_chunks") or []))
        mlflow.log_metrics(metrics)

        # Selected response stored as an artifact for qualitative review.
        response_text = state.get("selected_response") or ""
        mlflow.log_text(response_text, f"responses/turn_{state['turn_id']}.txt")

    return active_run.info.run_id
71
+
72
+
73
+ # ── Bayesian bucket prior update ───────────────────────────────────────────────
74
+
75
+ def _update_bucket_priors(state: PipelineState) -> dict[str, float]:
76
+ chunks = state.get("retrieved_chunks") or []
77
+ if not chunks:
78
+ return state.get("bucket_priors") or {}
79
+
80
+ # Which bucket sourced the accepted response?
81
+ top_bucket = chunks[0].get("bucket")
82
+ if not top_bucket:
83
+ return state.get("bucket_priors") or {}
84
+
85
+ return update_priors(
86
+ priors=state.get("bucket_priors") or {},
87
+ accepted_bucket=top_bucket,
88
+ )
89
+
90
+
91
+ # ── Session history append ─────────────────────────────────────────────────────
92
+
93
+ def _append_turn_to_history(state: PipelineState) -> list[dict]:
94
+ """Returns a single-element list; LangGraph's Annotated[list, add] merges it."""
95
+ return [
96
+ {"role": "partner", "content": state["raw_query"]},
97
+ {"role": "aac_user", "content": state.get("selected_response") or ""},
98
+ ]
pipeline/nodes/intent.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ L2 — Agentic Intent Decomposition node.
3
+
4
+ Receives the partner query + affect state, calls the controller LLM once
5
+ (non-thinking mode, ReAct style), and returns a Pydantic-validated
6
+ IntentRoute that drives all downstream routing decisions.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import re
11
+ import time
12
+ from typing import Literal, Optional
13
+
14
+ from pydantic import BaseModel
15
+ from config.settings import settings
16
+ from generation.llm_client import chat_complete
17
+ from pipeline.state import PipelineState, GenerationConfig, IntentRoute
18
+
19
+ # ── Pydantic output schemas ────────────────────────────────────────────────────
20
+
21
+ BucketType = Literal["family", "medical", "hobbies", "daily_routine", "social"]
22
+ AffectEmotion = Literal["HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"]
23
+
24
+
25
class SubIntentSchema(BaseModel):
    """One classified sub-intent extracted from the partner's query."""

    # Intent category; the system prompt defines what each value means.
    type: Literal["PERSONAL", "CONTEXTUAL", "OPEN_DOMAIN"]
    query: str
    # Memory bucket hint; the system prompt asks for it only on PERSONAL intents.
    bucket_hint: Optional[BucketType] = None
    priority: Literal["fast", "normal"] = "normal"


class StyleConfig(BaseModel):
    """Style/routing constraints the controller LLM must return."""

    tone_tag: str          # e.g. "[TONE:WITTY_SARCASTIC]"
    max_tokens: int
    retrieval_mode: str    # "fast" | "full"
    persona_mod: str       # "amplify_quirks" | "suppress_humor" | "baseline" | "add_confirmation"


class IntentRouteSchema(BaseModel):
    """Top-level schema used to validate the controller LLM's JSON reply."""

    sub_intents: list[SubIntentSchema]
    style_constraints: StyleConfig
    affect: AffectEmotion
43
+
44
+
45
+ # ── Affect β†’ generation config mapping (proposal Table 1) ─────────────────────
46
+
47
+ _AFFECT_CONFIG: dict[str, GenerationConfig] = {
48
+ "HAPPY": {
49
+ "max_tokens": settings.max_tokens_happy,
50
+ "tone_tag": "[TONE:WARM]",
51
+ "retrieval_mode": "full",
52
+ "persona_mod": "amplify_quirks",
53
+ },
54
+ "FRUSTRATED": {
55
+ "max_tokens": settings.max_tokens_frustrated,
56
+ "tone_tag": "[TONE:DIRECT_EMPATHETIC]",
57
+ "retrieval_mode": "fast",
58
+ "persona_mod": "suppress_humor",
59
+ },
60
+ "NEUTRAL": {
61
+ "max_tokens": settings.max_tokens_neutral,
62
+ "tone_tag": "[TONE:DEFAULT]",
63
+ "retrieval_mode": "full",
64
+ "persona_mod": "baseline",
65
+ },
66
+ "SURPRISED": {
67
+ "max_tokens": settings.max_tokens_surprised,
68
+ "tone_tag": "[TONE:CLARIFYING]",
69
+ "retrieval_mode": "full",
70
+ "persona_mod": "add_confirmation",
71
+ },
72
+ }
73
+
74
+ # ── System prompt ──────────────────────────────────────────────────────────────
75
+
76
+ _SYSTEM_PROMPT = """\
77
+ You are the intent decomposition controller for an AAC (Augmentative and \
78
+ Alternative Communication) chatbot. Given a partner's query and the AAC \
79
+ user's current affect state, classify each intent and produce routing \
80
+ instructions in the required JSON format.
81
+
82
+ Intent types:
83
+ - PERSONAL: requires autobiographical memory retrieval
84
+ - CONTEXTUAL: answerable from session history
85
+ - OPEN_DOMAIN: answerable from general knowledge (no retrieval needed)
86
+
87
+ Bucket hints (only for PERSONAL): family | medical | hobbies | daily_routine | social
88
+ Priority: set "fast" when affect is FRUSTRATED to reduce latency.
89
+
90
+ Respond ONLY with valid JSON matching the IntentRoute schema. No extra text.
91
+ """
92
+
93
+
94
+ def _build_user_prompt(query: str, affect: str, persona_name: str) -> str:
95
+ return (
96
+ f"Persona: {persona_name}\n"
97
+ f"Affect: {affect}\n"
98
+ f"Partner query: {query}\n\n"
99
+ "Produce the IntentRoute JSON:"
100
+ )
101
+
102
+
103
+ # ── Node entry point ───────────────────────────────────────────────────────────
104
+
105
def _parse_intent_route(raw: str) -> IntentRoute:
    """Strip optional markdown fences from *raw*, validate it, return plain dicts."""
    # Many models wrap JSON in ```json ... ``` fences; remove them first.
    unfenced = re.sub(r"^```(?:json)?\s*", "", raw.strip())
    unfenced = re.sub(r"\s*```$", "", unfenced.strip())
    validated = IntentRouteSchema.model_validate_json(unfenced)
    return {
        "sub_intents": [sub.model_dump() for sub in validated.sub_intents],
        "style_constraints": validated.style_constraints.model_dump(),
        "affect": validated.affect,
    }


def run(state: PipelineState) -> dict:
    """LangGraph node: intent decomposition.

    When ``intent_route`` and ``generation_config`` are already set on the
    state (the ``--fast`` keyword-routing path), nothing is updated.
    Otherwise the controller LLM is asked up to three times for a JSON reply
    matching ``IntentRouteSchema``; after a failed attempt the error text is
    appended to the next request. If every attempt fails, a single PERSONAL
    intent covering the raw query is used instead.

    Returns:
        ``{}`` on the pre-filled path, else a partial state update with
        ``intent_route``, ``generation_config`` and ``latency_log`` (with
        ``t_intent`` refreshed).
    """
    started = time.perf_counter()

    # --fast mode: intent_route already resolved by keyword routing in main.py
    if state.get("intent_route") and state.get("generation_config"):
        return {}  # nothing to update — downstream nodes use the pre-filled values

    emotion: str = (state.get("affect") or {}).get("emotion", "NEUTRAL")
    query: str = state["raw_query"]
    persona_name: str = state["persona_profile"].get("name", "unknown")

    # Unknown emotions fall back to the NEUTRAL generation settings.
    gen_config = _AFFECT_CONFIG.get(emotion, _AFFECT_CONFIG["NEUTRAL"])

    route: IntentRoute | None = None
    last_error: str = ""

    for attempt in range(3):  # first try plus up to 2 retries
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": _build_user_prompt(query, emotion, persona_name)},
        ]
        if attempt > 0:
            messages.append(
                {"role": "user", "content": f"Validation error: {last_error}. Fix and retry."}
            )

        raw = chat_complete(
            messages=messages,
            max_tokens=512,
            temperature=0.0,
        )

        try:
            route = _parse_intent_route(raw)
        except Exception as exc:
            # Any parsing/validation failure is fed back into the next attempt.
            last_error = str(exc)
        else:
            break

    if route is None:
        # Hard fallback: treat as a single PERSONAL intent, full retrieval
        route = {
            "sub_intents": [
                {"type": "PERSONAL", "query": query, "bucket_hint": None, "priority": "normal"}
            ],
            "style_constraints": gen_config,
            "affect": emotion,
        }

    elapsed = time.perf_counter() - started
    latency_log = dict(state.get("latency_log") or {})
    latency_log["t_intent"] = round(elapsed, 4)

    return {
        "intent_route": route,
        "generation_config": gen_config,
        "latency_log": latency_log,
    }
169
+
170
+
pipeline/nodes/planner.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ L4 — Dialogue Planning & Generation node.
3
+
4
+ Expression-conditioned response shaping (proposal §5.5):
5
+ 1. Build augmented prompt (persona profile + retrieved evidence + affect config + style exemplar)
6
+ 2. Generate N candidate responses
7
+ 3. Rank candidates by composite score: α·faithful + β·style + γ·affect_match
8
+ 4. Return the top-ranked response
9
+
10
+ Two entry points:
11
+ run_primary — Qwen3-30B-A3B (or configured primary tier)
12
+ run_fallback — Qwen3-8B (faster, triggered by latency threshold)
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import time
17
+
18
+ from config.settings import settings
19
+ from generation.llm_client import chat_complete
20
+ from guardrails.checks import check_output
21
+ from pipeline.state import PipelineState
22
+
23
+ # ── Persona-specific tone tags (applied on top of affect base tag) ─────────────
24
+
25
+ _PERSONA_TONE_OVERRIDES: dict[str, dict[str, str]] = {
26
+ "mia_chen": {
27
+ "HAPPY": "[TONE:WITTY_SARCASTIC]",
28
+ "FRUSTRATED": "[TONE:DIRECT_EMPATHETIC]",
29
+ },
30
+ "gerald_okafor": {
31
+ "HAPPY": "[TONE:WARM_FORMAL]",
32
+ "FRUSTRATED": "[TONE:MEASURED_EMPATHETIC]",
33
+ },
34
+ "arjun_mehta": {
35
+ "HAPPY": "[TONE:DIRECT_WARM]",
36
+ "FRUSTRATED": "[TONE:MINIMAL_DIRECT]",
37
+ },
38
+ }
39
+
40
+
41
def run_primary(state: PipelineState) -> dict:
    """LangGraph node: generate the response with the ``"primary"`` LLM tier."""
    return _run(state, tier="primary")


def run_fallback(state: PipelineState) -> dict:
    """LangGraph node: generate the response with the ``"fallback"`` LLM tier."""
    return _run(state, tier="fallback")
47
+
48
+
49
def route_by_latency(state: PipelineState) -> str:
    """Conditional edge after the retrieval nodes.

    Returns ``"fallback"`` when ``t_intent + t_retrieval`` from the latency
    log exceeds ``settings.fallback_latency_threshold``, else ``"primary"``.
    Missing latency entries count as 0.0.
    """
    timings = state.get("latency_log") or {}
    t_intent = timings.get("t_intent", 0.0)
    t_retrieval = timings.get("t_retrieval", 0.0)
    over_budget = t_intent + t_retrieval > settings.fallback_latency_threshold
    if over_budget:
        return "fallback"
    return "primary"
54
+
55
+
56
+ # ── Core implementation ────────────────────────────────────────────────────────
57
+
58
def _run(state: PipelineState, tier: str) -> dict:
    """Generate, rank and guard a response using the given LLM *tier*.

    Builds the augmented prompt, samples ``settings.num_candidates``
    completions, picks the best one with ``_rank_candidates`` and replaces it
    with the guardrail's fallback text when ``check_output`` does not pass.

    Returns:
        A partial state update with the prompt, all candidates, the selected
        response, the tier used, the refreshed latency log and the guardrail
        verdict.
    """
    started = time.perf_counter()

    persona = state["persona_profile"]
    speaker_id = state["user_id"]
    emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
    cfg = state.get("generation_config") or {}
    evidence = state.get("retrieved_chunks") or []
    # Only the three most recent history entries are shown to the model.
    recent_history = (state.get("session_history") or [])[-3:]

    tone = _resolve_tone_tag(speaker_id, emotion, cfg.get("tone_tag", "[TONE:DEFAULT]"))
    prompt = _build_prompt(persona, evidence, recent_history, state["raw_query"], tone, cfg)

    candidates: list[str] = [
        chat_complete(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=cfg.get("max_tokens", settings.max_tokens_neutral) + 256,
            temperature=0.7,
            tier=tier,
        )
        for _ in range(settings.num_candidates)
    ]

    chosen = _rank_candidates(candidates, evidence, emotion, persona)

    # Guardrail — replace with safe fallback if output breaks persona
    verdict = check_output(chosen, evidence)
    if not verdict["passed"]:
        chosen = verdict["fallback"]

    elapsed = time.perf_counter() - started
    timings = dict(state.get("latency_log") or {})
    timings["t_generation"] = round(elapsed, 4)
    # Total = sensing + intent + retrieval + this node's unrounded elapsed time.
    upstream = (
        timings.get("t_sensing", 0)
        + timings.get("t_intent", 0)
        + timings.get("t_retrieval", 0)
    )
    timings["t_total"] = round(upstream + elapsed, 4)

    return {
        "augmented_prompt": prompt,
        "candidates": candidates,
        "selected_response": chosen,
        "llm_tier_used": tier,
        "latency_log": timings,
        "guardrail_passed": verdict["passed"],
    }
107
+
108
+
109
def _resolve_tone_tag(user_id: str, affect: str, default_tag: str) -> str:
    """Return the persona-specific tone tag for *affect*, else *default_tag*.

    The default is used when the persona has no override table or no entry
    for this affect.
    """
    persona_tags = _PERSONA_TONE_OVERRIDES.get(user_id)
    if persona_tags is None:
        return default_tag
    return persona_tags.get(affect, default_tag)
111
+
112
+
113
+ def _build_prompt(
114
+ profile: dict,
115
+ chunks: list[dict],
116
+ history: list[dict],
117
+ query: str,
118
+ tone_tag: str,
119
+ gen_cfg: dict,
120
+ ) -> str:
121
+ memory_block = "\n".join(f" [{c['bucket']}] {c['text']}" for c in chunks) or " (no memories retrieved)"
122
+ history_block = "\n".join(f" {h.get('role','?')}: {h.get('content','')}" for h in history) or " (start of session)"
123
+ style_exemplar = profile.get("style_exemplar", "")
124
+
125
+ persona_mod = gen_cfg.get("persona_mod", "baseline")
126
+ persona_instruction = {
127
+ "amplify_quirks": "Amplify your characteristic style and personality.",
128
+ "suppress_humor": "Be direct and supportive. Suppress humor.",
129
+ "baseline": "Use your natural communication style.",
130
+ "add_confirmation": "Add a clarifying question or confirmation at the end.",
131
+ }.get(persona_mod, "Use your natural communication style.")
132
+
133
+ return f"""\
134
+ You are {profile['name']}, an AAC device user with {profile['condition']}.
135
+ Communication style: {profile['style']}
136
+ {tone_tag}
137
+
138
+ Style exemplar β€” match this register:
139
+ {style_exemplar}
140
+
141
+ Personal memories (use ONLY these for personal facts):
142
+ {memory_block}
143
+
144
+ Recent conversation:
145
+ {history_block}
146
+
147
+ Partner says: {query}
148
+
149
+ Instructions:
150
+ - Speak in first person as {profile['name']}.
151
+ - {persona_instruction}
152
+ - Keep response to 1-3 sentences.
153
+ - If the answer isn't in your memories, say "I don't know."
154
+ - Do NOT say "As an AI" or break persona.
155
+
156
+ Response:"""
157
+
158
+
159
+ def _rank_candidates(
160
+ candidates: list[str],
161
+ chunks: list[dict],
162
+ affect: str,
163
+ profile: dict,
164
+ ) -> str:
165
+ """
166
+ Composite ranking: score = Ξ±Β·faithful + Ξ²Β·style + Ξ³Β·affect_match
167
+ Simple heuristic version β€” replace with NLI + cosine similarity for final eval.
168
+ """
169
+ if not candidates:
170
+ return "I don't know."
171
+ if len(candidates) == 1:
172
+ return candidates[0]
173
+
174
+ evidence_words = set(" ".join(c["text"] for c in chunks).lower().split())
175
+ style_words = set(profile.get("style", "").lower().split())
176
+
177
+ affect_positive_map = {
178
+ "HAPPY": ["great", "love", "enjoy", "happy", "fun"],
179
+ "FRUSTRATED": ["okay", "fine", "sure", "yes", "no"],
180
+ "NEUTRAL": [],
181
+ "SURPRISED": ["really", "oh", "interesting", "wow"],
182
+ }
183
+ affect_words = set(affect_positive_map.get(affect, []))
184
+
185
+ def score(c: str) -> float:
186
+ words = set(c.lower().split())
187
+ faithful = len(words & evidence_words) / max(len(words), 1)
188
+ style_sim = len(words & style_words) / max(len(words), 1)
189
+ affect_m = len(words & affect_words) / max(len(words), 1)
190
+ return (
191
+ settings.rank_alpha * faithful
192
+ + settings.rank_beta * style_sim
193
+ + settings.rank_gamma * affect_m
194
+ )
195
+
196
+ return max(candidates, key=score)
pipeline/nodes/retrieval.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ L3 — Semantic Bucketing & Retrieval node.
3
+
4
+ Two entry points:
5
+ run_fast — FRUSTRATED affect: k=2, single bucket, no reranking
6
+ run_full — standard: k=5, optional bucket hint, BGE cross-encoder reranking
7
+
8
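+ Also exports `route_by_affect`, a conditional edge function for routing after the intent node (graph.py currently defines its own equivalent `_route_by_affect`).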
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import time
13
+
14
+ from config.settings import settings
15
+ from pipeline.state import PipelineState, RetrievedChunk
16
+ from retrieval.vector_store import retrieve
17
+ from retrieval.bucket_priors import update_priors
18
+
19
+
20
def run_fast(state: PipelineState) -> dict:
    """Fast retrieval path (no reranker), used for FRUSTRATED affect.

    Restricts the search to the bucket with the highest session prior (no
    filter when the priors are empty) and fetches
    ``settings.retrieval_fast_k`` chunks.
    """
    started = time.perf_counter()

    preferred_bucket = _top_prior_bucket(state["bucket_priors"])
    hits = retrieve(
        query=state["raw_query"],
        user_id=state["user_id"],
        top_k=settings.retrieval_fast_k,
        rerank_k=settings.retrieval_fast_k,
        bucket_filter=preferred_bucket,
        use_reranker=False,
    )
    return _build_return(state, hits, "fast", started)
35
+
36
+
37
def run_full(state: PipelineState) -> dict:
    """Full retrieval path with cross-encoder reranking enabled.

    Bucket filter precedence: the gaze-derived bucket, then the first
    non-empty ``bucket_hint`` among the intent route's sub-intents, then no
    filter.
    """
    started = time.perf_counter()

    # Prefer gaze hint > intent bucket hint > None
    bucket_hint = state.get("gaze_bucket")
    if not bucket_hint:
        bucket_hint = None
        intent_route = state.get("intent_route") or {}
        for sub_intent in intent_route.get("sub_intents", []):
            hinted = sub_intent.get("bucket_hint")
            if hinted:
                bucket_hint = hinted
                break

    hits = retrieve(
        query=state["raw_query"],
        user_id=state["user_id"],
        top_k=settings.retrieval_top_k,
        rerank_k=settings.retrieval_rerank_k,
        bucket_filter=bucket_hint,
        use_reranker=True,
    )
    return _build_return(state, hits, "full", started)
59
+
60
+
61
def route_by_affect(state: PipelineState) -> str:
    """Conditional edge function: pick the retrieval path from the affect.

    Returns ``"fast"`` for a ``FRUSTRATED`` emotion, ``"full"`` for anything
    else (a missing affect counts as ``NEUTRAL``).
    """
    emotion = "NEUTRAL"
    affect = state.get("affect")
    if affect:
        emotion = affect.get("emotion", "NEUTRAL")
    return "fast" if emotion == "FRUSTRATED" else "full"
65
+
66
+
67
+ # ── Helpers ───────────────────────────────────────────────────────────────────
68
+
69
+ def _top_prior_bucket(priors: dict[str, float]) -> str | None:
70
+ if not priors:
71
+ return None
72
+ return max(priors, key=priors.get)
73
+
74
+
75
+ def _build_return(
76
+ state: PipelineState,
77
+ chunks: list[RetrievedChunk],
78
+ mode: str,
79
+ t0: float,
80
+ ) -> dict:
81
+ t_retrieval = time.perf_counter() - t0
82
+
83
+ latency_log = dict(state.get("latency_log") or {})
84
+ latency_log["t_retrieval"] = round(t_retrieval, 4)
85
+
86
+ return {
87
+ "retrieved_chunks": chunks,
88
+ "retrieval_mode_used": mode,
89
+ "latency_log": latency_log,
90
+ }
pipeline/state.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Typed state object that flows through every LangGraph node.
3
+
4
+ Each node receives the full PipelineState and returns a dict
5
+ containing only the keys it updates β€” LangGraph merges them.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Annotated, Any, Optional
10
+ from typing_extensions import TypedDict
11
+ import operator
12
+
13
+
14
+ # ── Sub-types ──────────────────────────────────────────────────────────────────
15
+
16
class AffectVector(TypedDict):
    """Four facial-geometry features describing the current expression."""

    MAR: float   # Mouth Aspect Ratio
    EAR: float   # Eye Aspect Ratio
    BRI: float   # Brow Raise Index
    LCP: float   # Lip Corner Pull


class AffectState(TypedDict):
    """Discrete emotion label plus the raw and smoothed feature vectors."""

    emotion: str            # "HAPPY" | "FRUSTRATED" | "NEUTRAL" | "SURPRISED"
    vector: AffectVector
    smoothed: AffectVector  # EMA-smoothed vector


class RetrievedChunk(TypedDict):
    """One memory chunk returned by retrieval."""

    text: str
    bucket: str   # family | medical | hobbies | daily_routine | social
    user: str
    score: float  # cross-encoder rerank score


class SubIntent(TypedDict):
    """One sub-intent of the partner's query."""

    type: str                  # "PERSONAL" | "CONTEXTUAL" | "OPEN_DOMAIN"
    query: str
    bucket_hint: Optional[str]
    priority: str              # "fast" | "normal"


class IntentRoute(TypedDict):
    """Routing decision produced by the intent node."""

    sub_intents: list[SubIntent]
    style_constraints: dict[str, Any]  # tone, max_tokens, etc.
    affect: str


class GenerationConfig(TypedDict):
    """Affect-dependent generation settings."""

    max_tokens: int
    tone_tag: str        # e.g. "[TONE:WITTY_SARCASTIC]"
    retrieval_mode: str  # "fast" | "full"
    persona_mod: str     # "amplify_quirks" | "suppress_humor" | "baseline" | "add_confirmation"


class LatencyLog(TypedDict):
    """Per-stage timings for one turn (written with ``time.perf_counter`` deltas)."""

    t_sensing: float
    t_intent: float
    t_retrieval: float
    t_generation: float
    t_total: float
62
+
63
+
64
+ # ── Main pipeline state ────────────────────────────────────────────────────────
65
+
66
class PipelineState(TypedDict):
    """State passed to every LangGraph node; nodes return only the keys they update."""

    # ── Session context (set at turn start, stable across nodes) ──────────────
    user_id: str
    persona_profile: dict[str, Any]                       # full profile from users.json
    # operator.add reducer: a node's returned list is appended, not assigned.
    session_history: Annotated[list[dict], operator.add]  # auto-appended
    turn_id: int

    # ── L1: Sensing outputs ───────────────────────────────────────────────────
    affect: Optional[AffectState]
    gesture_tag: Optional[str]        # e.g. "THUMBS_UP"
    gaze_bucket: Optional[str]        # bucket hinted by gaze fixation
    air_written_text: Optional[str]   # concatenated air-written chars

    # ── L2: Intent decomposition outputs ─────────────────────────────────────
    raw_query: str                          # partner's typed/spoken query
    intent_route: Optional[IntentRoute]     # Pydantic-validated routing
    generation_config: Optional[GenerationConfig]

    # ── L3: Retrieval outputs ─────────────────────────────────────────────────
    retrieved_chunks: list[RetrievedChunk]
    bucket_priors: dict[str, float]   # session-level Bayesian priors
    retrieval_mode_used: str          # "fast" | "full"

    # ── L4: Generation outputs ────────────────────────────────────────────────
    augmented_prompt: Optional[str]
    candidates: list[str]             # 2-3 candidate responses
    selected_response: Optional[str]
    llm_tier_used: str                # "primary" | "fallback" | "local"

    # ── L5: Feedback / tracking ───────────────────────────────────────────────
    latency_log: Optional[LatencyLog]
    mlflow_run_id: Optional[str]
    guardrail_passed: bool
requirements.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── Orchestration ──────────────────────────────────────────────────────────────
2
+ langgraph>=1.1
3
+ langchain-core>=0.2
4
+ pydantic>=2.0
5
+ pydantic-settings>=2.0
6
+
7
+ # ── LLM clients ────────────────────────────────────────────────────────────────
8
+ openai>=1.0 # OpenAI-compatible client for vLLM + Ollama
9
+ ollama>=0.2 # local dev fallback (direct Ollama SDK)
10
+
11
+ # ── Retrieval ──────────────────────────────────────────────────────────────────
12
+ faiss-cpu>=1.7
13
+ sentence-transformers>=3.0
14
+ torch>=2.0
15
+ transformers>=4.40
16
+ numpy>=1.24
17
+
18
+ # ── Clustering ─────────────────────────────────────────────────────────────────
19
+ hdbscan>=0.8.29
20
+ scikit-learn>=1.3
21
+
22
+ # ── Sensing ────────────────────────────────────────────────────────────────────
23
+ mediapipe>=0.10
24
+ opencv-python>=4.8
25
+
26
+ # ── API backend ────────────────────────────────────────────────────────────────
27
+ fastapi>=0.111
28
+ uvicorn[standard]>=0.29
29
+
30
+ # ── UI ─────────────────────────────────────────────────────────────────────────
31
+ streamlit>=1.35
32
+ requests>=2.31 # Streamlit → FastAPI calls
33
+
34
+ # ── Experiment tracking ────────────────────────────────────────────────────────
35
+ mlflow>=2.13
36
+
37
+ # ── Utilities ──────────────────────────────────────────────────────────────────
38
+ python-dotenv>=1.0
39
+ rich>=13.0
retrieval/__init__.py ADDED
File without changes
retrieval/bucket_priors.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Session-level Bayesian bucket priors (proposal §5.4 Bonus).
3
+
4
+ Prior P(bucket_i) is initialized uniformly across the 5 buckets.
5
+ After each accepted response, the prior is updated proportionally
6
+ to the historical acceptance rate for that bucket in the session.
7
+
8
+ P(bucket_i | accept) ∝ P(accept | bucket_i) · P(bucket_i)
9
+
10
+ The updated priors are stored in PipelineState and passed to the
11
+ retrieval node to bias FAISS search toward the most contextually
12
+ likely topic for the session.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ BUCKETS = ["family", "medical", "hobbies", "daily_routine", "social"]
17
+
18
+
19
def uniform_priors() -> dict[str, float]:
    """Return a prior that gives every bucket in ``BUCKETS`` the same mass."""
    share = 1.0 / len(BUCKETS)
    return dict.fromkeys(BUCKETS, share)
23
+
24
+
25
def update_priors(
    priors: dict[str, float],
    accepted_bucket: str,
    smoothing: float = 0.1,
) -> dict[str, float]:
    """Boost the accepted bucket and renormalise the priors.

    Every bucket first receives ``smoothing``; the accepted bucket then
    receives an extra 1.0 (a bucket missing from ``priors`` starts from
    ``smoothing``). The result is divided by its total and rounded to six
    decimals. The input mapping is not modified.

    Args:
        priors: Current session priors; empty or ``None`` means uniform.
        accepted_bucket: Bucket that sourced the accepted response.
        smoothing: Additive constant that keeps every bucket above zero.
    """
    if not priors:
        priors = uniform_priors()

    boosted: dict[str, float] = {}
    for bucket, mass in priors.items():
        boosted[bucket] = mass + smoothing
    boosted[accepted_bucket] = boosted.get(accepted_bucket, smoothing) + 1.0

    normaliser = sum(boosted.values())
    normalised: dict[str, float] = {}
    for bucket, mass in boosted.items():
        normalised[bucket] = round(mass / normaliser, 6)
    return normalised
46
+
47
+
48
def top_bucket(priors: dict[str, float]) -> str:
    """Return the highest-prior bucket, or the first of ``BUCKETS`` if none are given.

    Ties go to the bucket that appears first in the mapping.
    """
    if priors:
        leader, _ = max(priors.items(), key=lambda item: item[1])
        return leader
    return BUCKETS[0]
retrieval/clustering.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HDBSCAN-based semantic bucketing over BGE embeddings.
3
+
4
+ Used to validate / discover thematic clusters in persona memories,
5
+ and to auto-assign bucket labels when adding new memory chunks.
6
+ The hand-authored bucket labels in the JSON files remain the ground
7
+ truth — this module provides a data-driven cross-check and supports
8
+ future expansion to unlabelled memory stores.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+
17
+ from config.settings import settings
18
+ from retrieval.vector_store import _get_embedder
19
+
20
+ try:
21
+ import hdbscan
22
+ _HDBSCAN_AVAILABLE = True
23
+ except ImportError:
24
+ _HDBSCAN_AVAILABLE = False
25
+ print("[clustering] hdbscan not installed β€” clustering unavailable.")
26
+
27
+
28
+ BUCKET_LABELS = ["family", "medical", "hobbies", "daily_routine", "social"]
29
+
30
+
31
def cluster_persona_memories(user_id: str) -> dict[str, list[str]]:
    """Embed all memory chunks for a persona and cluster them with HDBSCAN.

    Args:
        user_id: Persona id; memories are read from
            ``settings.memories_dir / f"{user_id}.json"``.

    Returns:
        Mapping of ``"cluster_<n>"`` to the memory texts in that cluster,
        plus a ``"noise"`` key for points HDBSCAN left unclustered (label
        -1). Empty when the persona has no memories.

    Raises:
        RuntimeError: If the ``hdbscan`` package is not installed.
    """
    if not _HDBSCAN_AVAILABLE:
        raise RuntimeError("hdbscan package is required. Run: pip install hdbscan")

    memory_path = settings.memories_dir / f"{user_id}.json"
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(memory_path, encoding="utf-8") as f:
        persona = json.load(f)

    # The hand-authored bucket labels are not needed for clustering itself.
    texts = [
        mem
        for memories in persona["memory_buckets"].values()
        for mem in memories
    ]
    if not texts:
        return {}

    embedder = _get_embedder()
    vecs = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=3,
        min_samples=2,
        metric="euclidean",
    )
    labels = clusterer.fit_predict(vecs)

    clusters: dict[str, list[str]] = {}
    for text, label in zip(texts, labels):
        key = f"cluster_{label}" if label >= 0 else "noise"
        clusters.setdefault(key, []).append(text)

    return clusters
67
+
68
+
69
def evaluate_bucket_alignment(user_id: str) -> dict:
    """Compare HDBSCAN cluster assignments with the hand-authored bucket labels.

    Purity of a cluster is the share of its members that carry the cluster's
    most common bucket label.

    Args:
        user_id: Persona id; memories are read from
            ``settings.memories_dir / f"{user_id}.json"``.

    Returns:
        A dict with:
          * ``n_clusters``: number of real clusters (noise excluded),
          * ``n_noise``: number of memories labelled as noise (-1),
          * ``cluster_purity``: purity per cluster id (noise, if present,
            appears under key -1),
          * ``mean_purity``: mean purity over real clusters, 0.0 if none.

    Raises:
        RuntimeError: If the ``hdbscan`` package is not installed.
    """
    if not _HDBSCAN_AVAILABLE:
        raise RuntimeError("hdbscan package is required.")

    memory_path = settings.memories_dir / f"{user_id}.json"
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(memory_path, encoding="utf-8") as f:
        persona = json.load(f)

    texts, true_buckets = [], []
    for bucket, memories in persona["memory_buckets"].items():
        for mem in memories:
            texts.append(mem)
            true_buckets.append(bucket)

    if not texts:
        return {"n_clusters": 0, "n_noise": 0, "cluster_purity": {}, "mean_purity": 0.0}

    embedder = _get_embedder()
    vecs = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    clusterer = hdbscan.HDBSCAN(min_cluster_size=3, min_samples=2, metric="euclidean")
    pred_labels = clusterer.fit_predict(vecs)

    # Plain-int keys keep the result free of numpy scalars.
    cluster_bucket_counts: dict[int, dict[str, int]] = {}
    for pred, true in zip(pred_labels, true_buckets):
        per_cluster = cluster_bucket_counts.setdefault(int(pred), {})
        per_cluster[true] = per_cluster.get(true, 0) + 1

    purity_scores = {
        cluster_id: round(max(counts.values()) / sum(counts.values()), 3)
        for cluster_id, counts in cluster_bucket_counts.items()
    }
    real_purities = [v for k, v in purity_scores.items() if k >= 0]

    return {
        "n_clusters": len(real_purities),
        # A count, as the name says; this used to be the noise cluster's
        # per-bucket count dict.
        "n_noise": sum(cluster_bucket_counts.get(-1, {}).values()),
        "cluster_purity": purity_scores,
        "mean_purity": round(float(np.mean(real_purities)), 3) if real_purities else 0.0,
    }
retrieval/vector_store.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FAISS-backed dense retrieval with BGE embeddings and cross-encoder reranking.
3
+
4
+ Models are lazy-loaded on first use (safe for FastAPI / LangGraph workers).
5
+
6
+ NOTE: The FAISS indexes in data/faiss_store/ must be built with BGE embeddings.
7
+ Run `python -m retrieval.vector_store` to rebuild all persona indexes.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import time
13
+ from functools import lru_cache
14
+ from pathlib import Path
15
+
16
+ import faiss
17
+ import numpy as np
18
+ from sentence_transformers import CrossEncoder, SentenceTransformer
19
+
20
+ from config.settings import settings
21
+ from pipeline.state import RetrievedChunk
22
+
23
+ # ── Lazy model singletons ──────────────────────────────────────────────────────
24
+
25
@lru_cache(maxsize=1)
def _get_embedder() -> SentenceTransformer:
    """Return the shared bi-encoder, constructing it on the first call.

    The model name is read from ``settings.embed_model``; ``lru_cache`` keeps
    the single instance so every later call reuses it.
    """
    embedder = SentenceTransformer(settings.embed_model)
    return embedder
28
+
29
+
30
@lru_cache(maxsize=1)
def _get_reranker() -> CrossEncoder:
    """Return the shared cross-encoder reranker, constructing it on the first call.

    The model name is read from ``settings.rerank_model``; ``lru_cache`` keeps
    the single instance so every later call reuses it.
    """
    reranker = CrossEncoder(settings.rerank_model)
    return reranker
33
+
34
+
35
+ # ── Index cache (one FAISS index per user_id) ─────────────────────────────────
36
+
37
_index_cache: dict[str, tuple[faiss.Index, list[dict]]] = {}


def load_index(user_id: str) -> tuple[faiss.Index, list[dict]]:
    """Return ``(faiss_index, chunk_metadata)`` for a persona, reading disk once.

    The pair is loaded from ``settings.faiss_store_dir / user_id``
    (``index.faiss`` and ``meta.json``) on the first request and then served
    from the module-level ``_index_cache`` on every later request.
    """
    cached = _index_cache.get(user_id)
    if cached is not None:
        return cached

    persona_dir = settings.faiss_store_dir / user_id
    faiss_index = faiss.read_index(str(persona_dir / "index.faiss"))
    with open(persona_dir / "meta.json") as fh:
        chunk_meta = json.load(fh)

    entry = (faiss_index, chunk_meta)
    _index_cache[user_id] = entry
    return entry
48
+
49
+
50
+ # ── Core retrieve function ─────────────────────────────────────────────────────
51
+
52
def retrieve(
    query: str,
    user_id: str,
    top_k: int = 5,
    rerank_k: int = 3,
    bucket_filter: str | None = None,
    use_reranker: bool = True,
    debug: bool = False,
) -> list[RetrievedChunk] | tuple[list[RetrievedChunk], dict[str, float]]:
    """
    Two-stage retrieval:
      1. BGE-small-en-v1.5 bi-encoder → FAISS IndexFlatIP (cosine similarity)
      2. BGE-reranker-v2-m3 cross-encoder reranking (multilingual, skippable)

    Args:
        query:         Partner's text query.
        user_id:       Persona identifier (e.g. "mia_chen").
        top_k:         Number of candidates from FAISS before reranking.
        rerank_k:      Final number of chunks returned after reranking.
        bucket_filter: If set, restrict candidates to this memory bucket.
                       The filter is applied to the FAISS candidates; if none
                       of them is in the bucket, all candidates are kept.
        use_reranker:  False for the FRUSTRATED fast path.
        debug:         Return timing breakdown alongside results.

    Returns:
        The top ``rerank_k`` chunks. With the reranker, ``score`` is the
        cross-encoder score; on the fast path (or with a single candidate)
        chunks keep FAISS order and ``score`` is a constant 1.0.
        When ``debug`` is true, a ``(chunks, timings)`` tuple is returned,
        where ``timings`` has ``t_embed``, ``t_faiss`` and ``t_rerank`` in seconds.
    """
    embedder = _get_embedder()
    index, meta = load_index(user_id)

    t0 = time.perf_counter()
    q_vec = embedder.encode(
        [query], convert_to_numpy=True, normalize_embeddings=True
    )
    t_embed = time.perf_counter() - t0

    t0 = time.perf_counter()
    # FAISS expects float32 queries; build_index stores float32 vectors as well.
    _, idxs = index.search(np.asarray(q_vec, dtype=np.float32), top_k)
    t_faiss = time.perf_counter() - t0

    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors. A bare `i < len(meta)` lets -1 through and meta[-1] would
    # silently add the last chunk, so the lower bound must be checked too.
    candidates = [meta[i] for i in idxs[0] if 0 <= i < len(meta)]

    if bucket_filter:
        filtered = [c for c in candidates if c["bucket"] == bucket_filter]
        candidates = filtered if filtered else candidates  # fallback: all buckets

    t0 = time.perf_counter()
    if use_reranker and len(candidates) > 1:
        reranker = _get_reranker()
        pairs = [(query, c["text"]) for c in candidates]
        ce_scores = reranker.predict(pairs)
        # Sort on the score only; comparing the candidate dicts would raise on ties.
        ranked = sorted(zip(ce_scores, candidates), key=lambda x: x[0], reverse=True)
        top = [
            RetrievedChunk(text=c["text"], bucket=c["bucket"], user=c["user"], score=float(s))
            for s, c in ranked[:rerank_k]
        ]
    else:
        top = [
            RetrievedChunk(text=c["text"], bucket=c["bucket"], user=c["user"], score=1.0)
            for c in candidates[:rerank_k]
        ]
    t_rerank = time.perf_counter() - t0

    if debug:
        return top, {"t_embed": t_embed, "t_faiss": t_faiss, "t_rerank": t_rerank}
    return top
114
+
115
+
116
+ # ── Index builder ──────────────────────────────────────────────────────────────
117
+
118
def build_index(persona_path: str | Path) -> tuple[faiss.Index, list[dict]]:
    """Embed all memory chunks for a persona and build a FAISS IndexFlatIP.

    Args:
        persona_path: Path to a persona JSON file with ``profile.name`` and a
            ``memory_buckets`` mapping of bucket name to memory strings.

    Returns:
        ``(index, meta)`` where ``meta[i]`` describes the vector at row ``i``
        with its ``text``, ``bucket`` and ``user``.

    Raises:
        ValueError: if the persona has no memories, since the embedding
            dimension cannot be determined from an empty batch.
    """
    # Explicit encoding: memories may contain non-ASCII text and the platform
    # default is not guaranteed to be UTF-8.
    with open(persona_path, encoding="utf-8") as f:
        persona = json.load(f)

    user_name = persona["profile"]["name"]
    chunks, meta = [], []

    for bucket, memories in persona["memory_buckets"].items():
        for mem in memories:
            chunks.append(mem)
            meta.append({"text": mem, "bucket": bucket, "user": user_name})

    if not chunks:
        raise ValueError(f"No memories found in {persona_path}; cannot build an index.")

    embedder = _get_embedder()
    vecs = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    # FAISS requires float32 input.
    vecs = vecs.astype(np.float32)

    # Vectors are L2-normalised, so inner product equals cosine similarity.
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    return index, meta
138
+
139
+
140
def save_index(index: faiss.Index, meta: list[dict], save_dir: str | Path) -> None:
    """Write a persona's FAISS index and chunk metadata into ``save_dir``.

    The directory is created if needed; the index goes to ``index.faiss`` and
    the metadata to ``meta.json`` (indented JSON).
    """
    target = Path(save_dir)
    target.mkdir(parents=True, exist_ok=True)

    faiss.write_index(index, str(target / "index.faiss"))

    with open(target / "meta.json", "w") as fh:
        fh.write(json.dumps(meta, indent=2))
146
+
147
+
148
def build_all(
    memories_dir: str | Path | None = None,
    store_dir: str | Path | None = None,
) -> None:
    """Rebuild FAISS indexes for all personas using the configured BGE embedder.

    Args:
        memories_dir: Directory of ``<user_id>.json`` persona files; defaults
            to ``settings.memories_dir``.
        store_dir: Directory that receives one sub-directory per persona;
            defaults to ``settings.faiss_store_dir``.
    """
    memories_dir = Path(memories_dir or settings.memories_dir)
    store_dir = Path(store_dir or settings.faiss_store_dir)

    # Sorted so the build order (and the log output) is deterministic.
    for persona_file in sorted(memories_dir.glob("*.json")):
        uid = persona_file.stem
        print(f" Building index for {uid} …")
        index, meta = build_index(persona_file)
        save_index(index, meta, store_dir / uid)
        print(f" Saved {len(meta)} chunks → {store_dir / uid}/")
    print("\nAll indexes built.")


# ── Entrypoint ────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    build_all()
sensing/__init__.py ADDED
File without changes
sensing/air_writing.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
L1 — Air writing recognition via index-finger tip trajectory (proposal §5.2).

Tracks MediaPipe Hands landmark 8 (index fingertip) across frames.
Stroke segmentation uses velocity thresholding:
  - stroke starts when velocity > START_VEL px/frame
  - stroke ends when velocity < END_VEL px/frame for > GAP_MS ms

Segmented strokes are classified against a template library using
Dynamic Time Warping (DTW). Supports:
  - 26 uppercase English letters (A-Z)
  - 10 digits (0-9)
  - 10 most frequent Devanagari characters (for Arjun's Hindi inputs)

Recognised characters are concatenated and returned as a text string
to the intent decomposition layer.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import time
21
+ from collections import deque
22
+ from dataclasses import dataclass, field
23
+
24
+ import numpy as np
25
+
26
+ from config.settings import settings
27
+
28
+ try:
29
+ import mediapipe as mp
30
+ _MP_AVAILABLE = True
31
+ except ImportError:
32
+ _MP_AVAILABLE = False
33
+
34
+ # ── Landmark index ─────────────────────────────────────────────────────────────
35
_INDEX_TIP = 8


@dataclass
class AirWriter:
    """
    Stateful air-writing recogniser. Feed frames from a webcam loop.
    Call `get_text()` to retrieve and clear the current buffer.

    A stroke opens when the fingertip moves faster than
    ``settings.air_write_velocity_start`` and closes once it has stayed slower
    than ``settings.air_write_velocity_end`` (or the hand has been out of
    frame) for ``settings.air_write_end_gap_ms``. The finished stroke is
    matched against the DTW templates.
    """
    _trajectory: list[tuple[float, float]] = field(default_factory=list)
    _in_stroke: bool = False
    # Monotonic timestamp at which the current pause began; 0.0 means "no pause".
    _stroke_end_time: float = field(default=0.0)
    _text_buffer: list[str] = field(default_factory=list)
    _templates: dict[str, np.ndarray] = field(default_factory=dict)

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        self._hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=1,
            min_detection_confidence=0.6,
            min_tracking_confidence=0.5,
        )
        self._prev_pt: tuple[float, float] | None = None
        # Templates passed to the constructor are kept; only fall back to the
        # on-disk library when none were supplied.
        if not self._templates:
            self._templates = _load_templates()

    def process_frame(self, bgr_frame) -> str | None:
        """
        Process one frame. Returns a recognised character when a stroke
        completes, or None otherwise.
        """
        import cv2  # deferred so importing this module does not require OpenCV

        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._hands.process(rgb)

        if not result.multi_hand_landmarks:
            self._prev_pt = None
            # A hand that leaves the frame mid-stroke counts as the start of a
            # pause; otherwise the stroke would never be finalised.
            if self._in_stroke and self._stroke_end_time == 0.0:
                self._stroke_end_time = time.monotonic()
            return self._check_stroke_end()

        h, w = bgr_frame.shape[:2]
        lm = result.multi_hand_landmarks[0].landmark
        tip = (lm[_INDEX_TIP].x * w, lm[_INDEX_TIP].y * h)

        # Velocity is the fingertip displacement since the previous frame, in
        # pixels; it is 0 on the first frame after the hand (re)appears.
        velocity = 0.0
        if self._prev_pt is not None:
            velocity = float(np.linalg.norm(np.array(tip) - np.array(self._prev_pt)))
        self._prev_pt = tip

        if velocity > settings.air_write_velocity_start:
            self._in_stroke = True
            self._trajectory.append(tip)
            self._stroke_end_time = 0.0
            return None

        if not self._in_stroke:
            return None

        if velocity < settings.air_write_velocity_end:
            if self._stroke_end_time == 0.0:
                self._stroke_end_time = time.monotonic()
            return self._check_stroke_end()

        # Between the two thresholds the finger is still writing, just slowly
        # (e.g. at a corner): keep the point and cancel any pending pause so
        # the recorded shape is not missing its slow segments.
        self._trajectory.append(tip)
        self._stroke_end_time = 0.0
        return None

    def _check_stroke_end(self) -> str | None:
        """Finalise the stroke if the pause has lasted long enough.

        Returns the recognised character, or None if no stroke is pending,
        the pause is still too short, or nothing was recognised.
        """
        if not self._in_stroke or self._stroke_end_time == 0.0:
            return None
        gap_s = settings.air_write_end_gap_ms / 1000.0
        if time.monotonic() - self._stroke_end_time >= gap_s:
            char = self._recognise(self._trajectory)
            self._trajectory = []
            self._in_stroke = False
            self._stroke_end_time = 0.0
            if char:
                self._text_buffer.append(char)
            return char
        return None

    def _recognise(self, trajectory: list[tuple[float, float]]) -> str | None:
        """Return the label of the template closest (by DTW) to ``trajectory``.

        Returns None for strokes with fewer than 5 points or when no
        templates are loaded.
        """
        if len(trajectory) < 5 or not self._templates:
            return None
        query = _normalise_trajectory(np.array(trajectory))
        best_char, best_dist = None, float("inf")
        for char, template in self._templates.items():
            dist = _dtw_distance(query, template)
            if dist < best_dist:
                best_dist = dist
                best_char = char
        return best_char

    def get_text(self) -> str:
        """Return and clear the accumulated air-written text."""
        text = "".join(self._text_buffer)
        self._text_buffer.clear()
        return text

    def release(self):
        """Close the underlying MediaPipe Hands instance."""
        self._hands.close()
132
+
133
+
134
+ # ── DTW helpers ───────────────────────────────────────────────────────────────
135
+
136
def _normalise_trajectory(pts: np.ndarray) -> np.ndarray:
    """Map a stroke into the unit square and resample it to 32 points.

    Each axis is shifted to start at 0 and divided by its own extent (plus a
    small epsilon), then both coordinates are linearly interpolated over the
    point index onto 32 evenly spaced positions. Returns a (32, 2) array.
    """
    shifted = pts - pts.min(axis=0)
    unit = shifted / (shifted.max(axis=0) + 1e-6)

    # Interpolate over the normalised point index, one coordinate at a time.
    src_pos = np.linspace(0, 1, len(unit))
    dst_pos = np.linspace(0, 1, 32)
    resampled = [np.interp(dst_pos, src_pos, unit[:, axis]) for axis in (0, 1)]
    return np.column_stack(resampled)
148
+
149
+
150
def _dtw_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Dynamic Time Warping distance between two point sequences.

    Classic O(n·m) DTW with Euclidean point cost and the three standard moves
    (insert, delete, match). Trajectories are short (32 points), so the
    quadratic table is fine.

    Args:
        a: Sequence of n points, shape (n, d).
        b: Sequence of m points, shape (m, d).

    Returns:
        The accumulated cost of the cheapest alignment; 0.0 when both
        sequences are empty and ``inf`` when exactly one of them is.
    """
    n, m = len(a), len(b)
    inf = float("inf")
    if n == 0 or m == 0:
        return 0.0 if n == m else inf

    # All pairwise point distances in one vectorised call instead of one
    # np.linalg.norm call per table cell.
    pts_a = np.asarray(a, dtype=float).reshape(n, -1)
    pts_b = np.asarray(b, dtype=float).reshape(m, -1)
    cost = np.linalg.norm(pts_a[:, None, :] - pts_b[None, :, :], axis=-1).tolist()

    # Rolling rows of the DP table; column 0 is the "nothing of b consumed"
    # border, which is only reachable (cost 0) before any point of a is used.
    prev = [0.0] + [inf] * m
    for i in range(n):
        row_cost = cost[i]
        cur = [inf] * (m + 1)
        for j in range(m):
            cur[j + 1] = row_cost[j] + min(prev[j + 1], cur[j], prev[j])
        prev = cur
    return float(prev[m])
160
+
161
+
162
def _load_templates() -> dict[str, np.ndarray]:
    """
    Read the stroke template library from ``data/air_write_templates``.

    Every ``*.npy`` file in that directory becomes one entry whose key is the
    file name without extension (the character label) and whose value is the
    loaded array; templates are expected to be (32, 2) arrays.
    If the directory does not exist, an empty dict is returned.
    """
    from pathlib import Path

    library_dir = Path("data/air_write_templates")
    if library_dir.exists():
        return {npy.stem: np.load(npy) for npy in library_dir.glob("*.npy")}
    return {}
sensing/face_mesh.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
L1 — Facial affect detection via MediaPipe 2D Face Mesh.

Extracts 4 geometric features from 478 landmarks at ~10 fps:
  MAR — Mouth Aspect Ratio    (surprise / speech attempt)
  EAR — Eye Aspect Ratio      (frustration / blink)
  BRI — Brow Raise Index      (surprise / questioning)
  LCP — Lip Corner Pull       (smile vs frown)

These form the affect vector, which is mapped to one of 4 actionable states:
HAPPY | FRUSTRATED | NEUTRAL | SURPRISED. The mapping is currently a
rule-based classifier, to be replaced by a MobileNetV3-Small affect
classifier for the final evaluation.

EMA smoothing (α=0.3) prevents transient expressions (sneezes, blinks)
from destabilising the detected state across turns.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import time
19
+ from dataclasses import dataclass, field
20
+
21
+ import numpy as np
22
+
23
+ from config.settings import settings
24
+ from pipeline.state import AffectState, AffectVector
25
+
26
+ try:
27
+ import mediapipe as mp
28
+ _MP_AVAILABLE = True
29
+ except ImportError:
30
+ _MP_AVAILABLE = False
31
+
32
+ try:
33
+ import cv2
34
+ _CV2_AVAILABLE = True
35
+ except ImportError:
36
+ _CV2_AVAILABLE = False
37
+
38
+
39
+ # ── MediaPipe landmark indices (from proposal Β§5.2) ───────────────────────────
40
+
41
+ # MAR β€” mouth vertical / horizontal ratio
42
_MOUTH_TOP = 13
_MOUTH_BOTTOM = 14
_MOUTH_LEFT = 61
_MOUTH_RIGHT = 291

# EAR — eye vertical / horizontal ratio (right eye)
_EYE_TOP = 159
_EYE_BOTTOM = 145
_EYE_LEFT = 33
_EYE_RIGHT = 133

# BRI — brow vertical displacement relative to eye centre
_BROW_LEFT = 70
_BROW_RIGHT = 300

# LCP — mouth corner horizontal displacement from neutral baseline
_CORNER_LEFT = 61
_CORNER_RIGHT = 291


# ── Affect classes ────────────────────────────────────────────────────────────

AFFECT_CLASSES = ["HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"]


@dataclass
class AffectDetector:
    """
    Stateful detector that maintains EMA-smoothed affect across frames.
    Create one instance per session and call `process_frame` each frame.

    The first frame in which a face is found is taken as the neutral
    expression: its LCP value becomes the baseline that later frames are
    measured against.
    """
    _smoothed: AffectVector = field(default_factory=lambda: AffectVector(MAR=0.0, EAR=0.3, BRI=0.0, LCP=0.0))
    _neutral_lcp: float = 0.0  # calibrated at session start
    _calibrated: bool = False

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        if not _CV2_AVAILABLE:
            raise ImportError("opencv-python is required: pip install opencv-python")

        self._face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,  # enables iris landmarks (468-477)
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def process_frame(self, bgr_frame: np.ndarray) -> AffectState | None:
        """
        Process one BGR frame from OpenCV and return the current AffectState,
        or None if no face is detected.

        The returned state carries the raw features of this frame (with LCP
        relative to the neutral baseline), the EMA-smoothed features, and the
        emotion classified from the smoothed features.
        """
        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._face_mesh.process(rgb)

        if not result.multi_face_landmarks:
            return None

        lm = result.multi_face_landmarks[0].landmark
        h, w = bgr_frame.shape[:2]

        def pt(idx):
            # Landmark coordinates are normalised; convert to pixels.
            landmark = lm[idx]
            return np.array([landmark.x * w, landmark.y * h])

        raw = self._compute_features(pt)

        if not self._calibrated:
            self._neutral_lcp = raw["LCP"]
            self._calibrated = True

        raw["LCP"] = raw["LCP"] - self._neutral_lcp  # relative to neutral baseline

        alpha = settings.affect_ema_alpha
        smoothed = AffectVector(
            MAR=alpha * raw["MAR"] + (1 - alpha) * self._smoothed["MAR"],
            EAR=alpha * raw["EAR"] + (1 - alpha) * self._smoothed["EAR"],
            BRI=alpha * raw["BRI"] + (1 - alpha) * self._smoothed["BRI"],
            LCP=alpha * raw["LCP"] + (1 - alpha) * self._smoothed["LCP"],
        )
        self._smoothed = smoothed

        emotion = self._classify(smoothed)
        return AffectState(emotion=emotion, vector=raw, smoothed=smoothed)

    def _compute_features(self, pt) -> dict:
        """Compute the four raw geometric features for one frame.

        Args:
            pt: Callable mapping a landmark index to its 2D pixel position.

        Returns:
            Dict with float values for ``MAR``, ``EAR``, ``BRI`` (all
            dimensionless ratios) and ``LCP`` (pixels, not yet baseline-relative).
        """
        # MAR
        mouth_v = np.linalg.norm(pt(_MOUTH_TOP) - pt(_MOUTH_BOTTOM))
        mouth_h = np.linalg.norm(pt(_MOUTH_LEFT) - pt(_MOUTH_RIGHT))
        MAR = mouth_v / (mouth_h + 1e-6)

        # EAR
        eye_v = np.linalg.norm(pt(_EYE_TOP) - pt(_EYE_BOTTOM))
        eye_h = np.linalg.norm(pt(_EYE_LEFT) - pt(_EYE_RIGHT))
        EAR = eye_v / (eye_h + 1e-6)

        # BRI — brow height above the eye centre, normalised by the eye width.
        # Image y grows downwards, so a raised brow gives a larger value.
        eye_center = (pt(_EYE_LEFT) + pt(_EYE_RIGHT)) / 2
        brow_mid = (pt(_BROW_LEFT) + pt(_BROW_RIGHT)) / 2
        BRI = (eye_center[1] - brow_mid[1]) / (eye_h + 1e-6)

        # LCP — average outward horizontal displacement of the mouth corners.
        # The two corners move in opposite x directions when the lips are
        # pulled, so averaging their raw x-coordinates cancels the pull and
        # tracks head translation instead. Half the corner-to-corner span
        # changes by exactly the mean outward displacement and is unaffected
        # by where the face sits in the frame.
        LCP = abs(pt(_CORNER_RIGHT)[0] - pt(_CORNER_LEFT)[0]) / 2

        return {"MAR": float(MAR), "EAR": float(EAR), "BRI": float(BRI), "LCP": float(LCP)}

    @staticmethod
    def _classify(v: AffectVector) -> str:
        """
        Rule-based classifier over the 4 geometric features.
        Replace with MobileNetV3-Small for final evaluation.

        Rules are checked in order: SURPRISED, FRUSTRATED, HAPPY, then
        NEUTRAL as the fallback. LCP thresholds are in pixels relative to
        the neutral baseline.
        """
        if v["BRI"] > 0.25 and v["MAR"] > 0.3:
            return "SURPRISED"
        if v["EAR"] < 0.15 and v["LCP"] < -5:
            return "FRUSTRATED"
        if v["LCP"] > 5:
            return "HAPPY"
        return "NEUTRAL"

    def release(self):
        """Close the underlying MediaPipe FaceMesh instance."""
        self._face_mesh.close()
sensing/gaze.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
L1 — Gaze-based retrieval activation (Bonus feature, proposal §5.2).

Uses MediaPipe iris landmarks (468-472) to estimate gaze direction as
a 2D screen-coordinate vector. Sustained fixation (> 1.5 s dwell time)
on a defined UI region pre-biases the retrieval layer toward the
corresponding memory bucket.

UI region → bucket mapping:
  top-left quadrant     → family
  top-right quadrant    → medical
  bottom-left quadrant  → hobbies
  bottom-right quadrant → daily_routine
  centre strip          → social
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import time
19
+ from dataclasses import dataclass, field
20
+
21
+ import numpy as np
22
+
23
+ from config.settings import settings
24
+
25
+ try:
26
+ import mediapipe as mp
27
+ _MP_AVAILABLE = True
28
+ except ImportError:
29
+ _MP_AVAILABLE = False
30
+
31
+
32
+ # ── Iris landmark indices ──────────────────────────────────────────────────────
33
+ # MediaPipe refine_landmarks=True adds iris landmarks 468-477
34
_LEFT_IRIS_CENTER = 468
_RIGHT_IRIS_CENTER = 473

# ── Screen region → bucket map ─────────────────────────────────────────────────
# Defined as (x_min, y_min, x_max, y_max) in normalised [0,1] coords.
# Regions are tested in order and the first match wins. The four quadrants
# tile the whole frame, so the centre strip has to come first: listed after
# them it can never match and "social" would be unreachable.
_REGION_BUCKET: list[tuple[tuple[float, float, float, float], str]] = [
    ((0.3, 0.3, 0.7, 0.7), "social"),  # centre strip
    ((0.0, 0.0, 0.5, 0.5), "family"),
    ((0.5, 0.0, 1.0, 0.5), "medical"),
    ((0.0, 0.5, 0.5, 1.0), "hobbies"),
    ((0.5, 0.5, 1.0, 1.0), "daily_routine"),
]


@dataclass
class GazeTracker:
    """
    Stateful gaze tracker. Call `process_frame` each frame.
    Returns the bucket name when dwell threshold is exceeded, else None.
    """
    # Monotonic timestamp at which the gaze entered `_current_region`.
    _dwell_start: float = field(default=0.0)
    _current_region: str | None = field(default=None)

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        self._face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def process_frame(self, bgr_frame) -> str | None:
        """
        Returns the hinted bucket name once dwell threshold is exceeded,
        then resets the dwell timer. Returns None otherwise.

        Losing the face, or the gaze moving to a different region, restarts
        the dwell measurement.
        """
        import cv2  # deferred so importing this module does not require OpenCV

        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._face_mesh.process(rgb)

        if not result.multi_face_landmarks:
            self._reset()
            return None

        lm = result.multi_face_landmarks[0].landmark

        # Average left + right iris centres for gaze estimate
        gaze_x = (lm[_LEFT_IRIS_CENTER].x + lm[_RIGHT_IRIS_CENTER].x) / 2
        gaze_y = (lm[_LEFT_IRIS_CENTER].y + lm[_RIGHT_IRIS_CENTER].y) / 2

        bucket = self._region_for(gaze_x, gaze_y)

        if bucket != self._current_region:
            # Entered a new region (or left all regions): start timing afresh.
            self._current_region = bucket
            self._dwell_start = time.monotonic()
            return None

        # Monotonic clock: dwell is a duration and must not be affected by
        # wall-clock adjustments.
        dwell = time.monotonic() - self._dwell_start
        if dwell >= settings.gaze_dwell_threshold_s and bucket is not None:
            self._reset()
            return bucket

        return None

    @staticmethod
    def _region_for(x: float, y: float) -> str | None:
        """Return the bucket of the first region containing (x, y), or None.

        Coordinates are normalised to [0, 1]; region bounds are inclusive.
        """
        for (x0, y0, x1, y1), bucket in _REGION_BUCKET:
            if x0 <= x <= x1 and y0 <= y <= y1:
                return bucket
        return None

    def _reset(self):
        """Forget the current region and its dwell start time."""
        self._dwell_start = 0.0
        self._current_region = None

    def release(self):
        """Close the underlying MediaPipe FaceMesh instance."""
        self._face_mesh.close()
sensing/gesture.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
L1 — Hand gesture recognition via MediaPipe Hands.

Recognises 4 gestures from 21 3D hand landmarks at ~15 fps using
normalised joint-angle rules (no ML model needed at this stage):

  THUMBS_UP   → [TONE:AFFIRMATIVE]
  THUMBS_DOWN → [TONE:NEGATIVE]
  POINTING    → [INTENT:REFERENTIAL]
  WAVING      → [INTENT:GREETING]

Each detected gesture is mapped to a stylistic constraint tag that is
injected into the generation prompt by the planner node.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import numpy as np
18
+
19
+ try:
20
+ import mediapipe as mp
21
+ _MP_AVAILABLE = True
22
+ except ImportError:
23
+ _MP_AVAILABLE = False
24
+
25
+
26
+ # Gesture β†’ prompt constraint tag mapping
27
GESTURE_TO_TAG: dict[str, str] = {
    "THUMBS_UP": "[GESTURE:THUMBS_UP][TONE:AFFIRMATIVE]",
    "THUMBS_DOWN": "[GESTURE:THUMBS_DOWN][TONE:NEGATIVE]",
    "POINTING": "[GESTURE:POINTING][INTENT:REFERENTIAL]",
    "WAVING": "[GESTURE:WAVING][INTENT:GREETING]",
}


class GestureClassifier:
    """
    Stateful classifier — create one instance per session.
    Feed webcam frames to `process_frame` (or `gesture_tag`) each frame.
    """

    def __init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        self._hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=1,
            min_detection_confidence=0.6,
            min_tracking_confidence=0.5,
        )

    def process_frame(self, bgr_frame) -> str | None:
        """
        Returns a gesture label string or None if no clear gesture is detected.
        """
        import cv2  # deferred so importing this module does not require OpenCV

        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._hands.process(rgb)

        if not result.multi_hand_landmarks:
            return None

        lm = result.multi_hand_landmarks[0].landmark
        pts = np.array([[p.x, p.y, p.z] for p in lm])

        return self._classify(pts)

    def gesture_tag(self, bgr_frame) -> str | None:
        """Convenience: returns the prompt tag directly, or None."""
        gesture = self.process_frame(bgr_frame)
        return GESTURE_TO_TAG.get(gesture) if gesture else None

    @staticmethod
    def _classify(pts: np.ndarray) -> str | None:
        """
        Rule-based gesture classification over normalised joint positions.

        MediaPipe hand landmark indices:
          0=WRIST, 1-4=THUMB, 5-8=INDEX, 9-12=MIDDLE, 13-16=RING, 17-20=PINKY

        A finger counts as curled when its tip is closer to the wrist than
        its own knuckle (MCP), and as extended when the tip is more than 1.3
        times as far from the wrist as the knuckle. Every rule uses this same
        knuckle-relative test. A fixed distance threshold does not work here:
        in palm-width units even a curled fingertip is usually further than
        0.5 from the wrist, so a fist would read as WAVING.

        Args:
            pts: Array of the 21 hand landmarks, one (x, y, z) row each.

        Returns:
            "THUMBS_UP", "THUMBS_DOWN", "POINTING", "WAVING", or None.
        """
        # Normalise: wrist at origin, scale by palm width
        wrist = pts[0]
        palm_width = np.linalg.norm(pts[5] - pts[17]) + 1e-6
        p = (pts - wrist) / palm_width

        # Distance of every landmark from the wrist, in palm widths.
        reach = np.linalg.norm(p, axis=1)

        def curled(tip: int, mcp: int) -> bool:
            return bool(reach[tip] < reach[mcp])

        def extended(tip: int, mcp: int) -> bool:
            return bool(reach[tip] > reach[mcp] * 1.3)

        # (tip, knuckle) landmark indices per finger
        index, middle, ring, pinky = (8, 5), (12, 9), (16, 13), (20, 17)
        thumb_tip = p[4]

        # THUMBS_UP / THUMBS_DOWN: index, middle and ring curled, thumb tip
        # clearly above / below the wrist (image y grows downwards).
        fingers_curled = all(curled(*finger) for finger in (index, middle, ring))
        if thumb_tip[1] < -0.3 and fingers_curled:
            return "THUMBS_UP"
        if thumb_tip[1] > 0.3 and fingers_curled:
            return "THUMBS_DOWN"

        # POINTING: index extended, others curled
        if extended(*index) and all(curled(*finger) for finger in (middle, ring, pinky)):
            return "POINTING"

        # WAVING: all four fingers extended (open hand)
        if all(extended(*finger) for finger in (index, middle, ring, pinky)):
            return "WAVING"

        return None

    def release(self):
        """Close the underlying MediaPipe Hands instance."""
        self._hands.close()
ui/app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit frontend — webcam + chat + live metrics dashboard.
3
+
4
+ Panels:
5
+ Left sidebar — persona selector, session controls, live affect display
6
+ Centre — chat interface with streaming response
7
+ Right sidebar — latency breakdown, bucket priors bar chart
8
+
9
+ Run: streamlit run ui/app.py
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import time
15
+
16
+ import requests
17
+ import streamlit as st
18
+
19
+ # ── Config ─────────────────────────────────────────────────────────────────────
20
# Base URL of the backend API that every request in this script targets.
API_BASE = "http://localhost:8000"

# Page-level Streamlit options, applied in a single call below.
_PAGE_OPTIONS = {
    "page_title": "AAC Chatbot",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_OPTIONS)
27
+
28
+
29
+ # ── Session state init ─────────────────────────────────────────────────────────
30
# Seed each session key once; keys already present are left untouched so
# values survive Streamlit reruns. The dict literal is rebuilt on every run,
# so the mutable defaults are never shared between sessions.
_session_defaults = {
    "user_id": None,
    "messages": [],
    "last_latency": {},
    "last_affect": "NEUTRAL",
    "affect_override": None,
}
for _state_key, _initial_value in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
40
+
41
+
42
+ # ── Sidebar ────────────────────────────────────────────────────────────────────
43
with st.sidebar:
    st.title("AAC Chatbot")

    # Persona selection. Any failure (network, HTTP error status, malformed
    # JSON) degrades to an empty persona list with a visible error instead of
    # crashing the page.
    try:
        users_resp = requests.get(f"{API_BASE}/users", timeout=3)
        users_resp.raise_for_status()
        users = users_resp.json().get("users", [])
    except Exception:
        users = []
        st.error("API not reachable — start the FastAPI server first.")

    user_options = {u["id"]: f"{u['name']} ({u['condition']})" for u in users}
    selected = st.selectbox(
        "Select persona",
        options=list(user_options.keys()),
        format_func=lambda k: user_options.get(k, k),
    )

    # Switching persona clears the local transcript and asks the API to reset
    # its session for that user.
    if selected != st.session_state.user_id:
        st.session_state.user_id = selected
        st.session_state.messages = []
        if selected is not None:
            try:
                # A timeout keeps an unresponsive server from hanging the UI.
                reset_resp = requests.post(
                    f"{API_BASE}/session/reset",
                    params={"user_id": selected},
                    timeout=3,
                )
                reset_resp.raise_for_status()
            except requests.RequestException:
                # Best effort: the UI stays usable, but the failure is shown.
                st.warning("Could not reset the server session for this persona.")

    st.divider()

    # Affect override (for demo / testing without webcam)
    st.subheader("Affect Override")
    st.caption("Simulates webcam affect detection")
    affect_choice = st.radio(
        "Current affect",
        ["Auto (webcam)", "HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"],
        index=0,
    )
    # None means "no override" and is sent as such in the chat payload.
    st.session_state.affect_override = (
        None if affect_choice == "Auto (webcam)" else affect_choice
    )

    st.divider()

    # Live affect indicator (shows the affect returned by the last chat turn)
    st.subheader("Detected Affect")
    affect_emoji = {
        "HAPPY": "😊",
        "FRUSTRATED": "😤",
        "NEUTRAL": "😐",
        "SURPRISED": "😲",
    }
    af = st.session_state.last_affect
    st.markdown(f"### {affect_emoji.get(af, '❓')} {af}")

    # Webcam placeholder
    st.divider()
    st.subheader("Webcam Feed")
    st.info("Live webcam sensing runs in the sensing client.\nAffect is sent to the API automatically.")
93
+
94
+
95
+ # ── Main chat area ─────────────────────────────────────────────────────────────
96
st.header(f"Talking as: {user_options.get(st.session_state.user_id, '—')}")

chat_col, metrics_col = st.columns([3, 1])

with chat_col:
    # Replay the stored transcript; the partner is rendered as the "user"
    # bubble and the AAC user as the "assistant" bubble.
    for msg in st.session_state.messages:
        role_label = "Partner" if msg["role"] == "partner" else "AAC User"
        with st.chat_message("user" if msg["role"] == "partner" else "assistant"):
            st.markdown(f"**{role_label}:** {msg['content']}")

    query = st.chat_input("Type as the communication partner…")

    if query and not st.session_state.user_id:
        # Without a persona the message would otherwise vanish silently.
        st.warning("Select a persona before sending a message.")
    elif query:
        st.session_state.messages.append({"role": "partner", "content": query})
        with st.chat_message("user"):
            st.markdown(f"**Partner:** {query}")

        payload = {
            "user_id": st.session_state.user_id,
            "query": query,
            "affect_override": st.session_state.affect_override,
        }

        with st.chat_message("assistant"):
            with st.spinner("Generating response…"):
                try:
                    resp = requests.post(f"{API_BASE}/chat", json=payload, timeout=15)
                    # Surface HTTP errors instead of rendering an error body
                    # as a fallback answer and storing it in the transcript.
                    resp.raise_for_status()
                    data = resp.json()

                    response_text = data.get("response", "I don't know.")
                    st.markdown(f"**AAC User:** {response_text}")

                    st.session_state.messages.append({"role": "aac_user", "content": response_text})
                    # NOTE: the sidebar reads last_affect earlier in the same
                    # run, so it shows this value only from the next rerun.
                    st.session_state.last_affect = data.get("affect", "NEUTRAL")
                    st.session_state.last_latency = data.get("latency", {})

                    if not data.get("guardrail_passed", True):
                        st.warning("⚠ Guardrail triggered — response was sanitised.")

                except requests.exceptions.Timeout:
                    st.error("Request timed out. Is the server running?")
                except Exception as e:
                    # UI boundary: report any other failure in the chat bubble.
                    st.error(f"Error: {e}")
138
+
139
with metrics_col:
    st.subheader("Turn Latency (s)")
    latency = st.session_state.last_latency
    if not latency:
        # Nothing recorded until the first chat turn completes.
        st.caption("No turn yet.")
    else:
        # Display order follows insertion order: pipeline stages, then total.
        stage_titles = {
            "t_sensing": "Sensing",
            "t_intent": "Intent",
            "t_retrieval": "Retrieval",
            "t_generation": "Generation",
            "t_total": "**Total**",
        }
        for field, title in stage_titles.items():
            seconds = latency.get(field, 0.0)  # missing stages show as 0.000s
            st.metric(label=title, value=f"{seconds:.3f}s")