Spaces:
Sleeping
Sleeping
Merge pull request #3 from akashkolte/akash/v1
Browse files- .env.example +36 -0
- .gitignore +46 -0
- CLAUDE.md +106 -0
- README.md +236 -53
- api/__init__.py +0 -0
- api/main.py +167 -0
- config/__init__.py +3 -0
- config/settings.py +79 -0
- data/generate_users.py +186 -0
- data/memories/arjun_mehta.json +50 -0
- data/memories/gerald_okafor.json +49 -0
- data/memories/mia_chen.json +49 -0
- data/users.json +25 -0
- generation/__init__.py +0 -0
- generation/llm_client.py +147 -0
- guardrails/__init__.py +0 -0
- guardrails/checks.py +98 -0
- main.py +204 -0
- pipeline/__init__.py +0 -0
- pipeline/graph.py +71 -0
- pipeline/nodes/__init__.py +0 -0
- pipeline/nodes/feedback.py +98 -0
- pipeline/nodes/intent.py +170 -0
- pipeline/nodes/planner.py +196 -0
- pipeline/nodes/retrieval.py +90 -0
- pipeline/state.py +98 -0
- requirements.txt +39 -0
- retrieval/__init__.py +0 -0
- retrieval/bucket_priors.py +52 -0
- retrieval/clustering.py +111 -0
- retrieval/vector_store.py +168 -0
- sensing/__init__.py +0 -0
- sensing/air_writing.py +176 -0
- sensing/face_mesh.py +166 -0
- sensing/gaze.py +113 -0
- sensing/gesture.py +124 -0
- ui/app.py +153 -0
.env.example
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy this file to .env and fill in your values.
|
| 2 |
+
# Settings here override the defaults in config/settings.py.
|
| 3 |
+
|
| 4 |
+
# ββ Active LLM tier ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 5 |
+
# "local" β Ollama on MacBook M2 (dev, no GPU needed)
|
| 6 |
+
# "primary" β Qwen3-30B-A3B on GCP A100/T4 via vLLM
|
| 7 |
+
# "fallback" β Qwen3-8B on same vLLM server
|
| 8 |
+
ACTIVE_LLM_TIER=local
|
| 9 |
+
|
| 10 |
+
# ββ Primary vLLM server (GCP) βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
PRIMARY_BASE_URL=http://<GCP_IP>:8000/v1
|
| 12 |
+
PRIMARY_API_KEY=token-abc
|
| 13 |
+
PRIMARY_MODEL=Qwen/Qwen3-30B-A3B
|
| 14 |
+
|
| 15 |
+
# ββ Fallback model (same vLLM server) βββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
FALLBACK_MODEL=Qwen/Qwen3-8B
|
| 17 |
+
FALLBACK_BASE_URL=http://<GCP_IP>:8000/v1
|
| 18 |
+
|
| 19 |
+
# ββ Local Ollama (dev) ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
LOCAL_BASE_URL=http://localhost:11434/v1
|
| 21 |
+
LOCAL_MODEL=gemma4:31b-cloud # qwen3:8b qwen3.5:397b-cloud
|
| 22 |
+
|
| 23 |
+
# ββ MLflow ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
MLFLOW_TRACKING_URI=mlruns
|
| 25 |
+
MLFLOW_EXPERIMENT=aac-chatbot
|
| 26 |
+
|
| 27 |
+
# ββ Thinking mode βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
# "off" β suppress thinking (fastest, best for latency-sensitive AAC)
|
| 29 |
+
# "strip" β let model think, but strip <think> tags from output
|
| 30 |
+
# "full" β return raw response including <think> blocks
|
| 31 |
+
THINKING_MODE=off
|
| 32 |
+
# Extra tokens added when thinking is enabled (strip/full). Ignored when off.
|
| 33 |
+
THINKING_TOKEN_BUDGET=4096
|
| 34 |
+
|
| 35 |
+
# ββ Latency fallback threshold (seconds) ββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
FALLBACK_LATENCY_THRESHOLD=3.5
|
.gitignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
|
| 11 |
+
# Virtual environment
|
| 12 |
+
.venv/
|
| 13 |
+
venv/
|
| 14 |
+
env/
|
| 15 |
+
|
| 16 |
+
# Environment secrets
|
| 17 |
+
.env
|
| 18 |
+
|
| 19 |
+
# Data β indexes are rebuilt from source; do NOT commit binaries
|
| 20 |
+
data/faiss_store/
|
| 21 |
+
|
| 22 |
+
# Air-writing templates (large numpy files, track separately if needed)
|
| 23 |
+
data/air_write_templates/
|
| 24 |
+
|
| 25 |
+
# MLflow run artifacts
|
| 26 |
+
mlruns/
|
| 27 |
+
|
| 28 |
+
# Latency logs
|
| 29 |
+
timings.csv
|
| 30 |
+
*.csv
|
| 31 |
+
|
| 32 |
+
# IDE
|
| 33 |
+
.vscode/
|
| 34 |
+
.idea/
|
| 35 |
+
*.swp
|
| 36 |
+
|
| 37 |
+
# Claude Code β local settings and generated knowledge graph
|
| 38 |
+
.claude/
|
| 39 |
+
.code-review-graph/
|
| 40 |
+
|
| 41 |
+
# macOS
|
| 42 |
+
.DS_Store
|
| 43 |
+
|
| 44 |
+
# Jupyter
|
| 45 |
+
.ipynb_checkpoints/
|
| 46 |
+
*.ipynb
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multimodal AAC Chatbot — Project Guide
|
| 2 |
+
|
| 3 |
+
## What This Project Does
|
| 4 |
+
|
| 5 |
+
An AI chatbot that **speaks as an AAC user**, not to them. Given a user persona
|
| 6 |
+
(Mia, Gerald, or Arjun), it fuses real-time multimodal non-verbal signals with
|
| 7 |
+
personal memory retrieval to generate responses in that person's authentic voice.
|
| 8 |
+
Orchestrated as a **LangGraph stateful directed graph** across five layers.
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Architecture
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
main.py / api/main.py / ui/app.py
|
| 16 |
+
βββ pipeline/graph.py β LangGraph StateGraph (5 nodes + cond. edges)
|
| 17 |
+
βββ pipeline/nodes/intent.py L2 β LLM + Pydantic intent routing
|
| 18 |
+
βββ pipeline/nodes/retrieval.py L3 β FAISS + BGE retrieval (fast / full)
|
| 19 |
+
βββ pipeline/nodes/planner.py L4 β expression-conditioned generation
|
| 20 |
+
βββ pipeline/nodes/feedback.py L5 β MLflow logging + Bayesian priors
|
| 21 |
+
|
| 22 |
+
sensing/ L1 β MediaPipe face mesh, gesture, gaze, air writing
|
| 23 |
+
retrieval/ FAISS ops, HDBSCAN clustering, Bayesian bucket priors
|
| 24 |
+
generation/ Multi-tier LLM client (vLLM primary / fallback / Ollama local)
|
| 25 |
+
guardrails/ Input + output safety checks
|
| 26 |
+
config/ Pydantic BaseSettings β all config in one place
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Key Design Decisions
|
| 30 |
+
|
| 31 |
+
- **LangGraph** orchestrates the pipeline as a stateful directed graph with
|
| 32 |
+
conditional edges (affect → fast/full retrieval; latency → primary/fallback LLM)
|
| 33 |
+
- **BGE-small-en-v1.5** for embeddings (beats MiniLM on MTEB at same speed)
|
| 34 |
+
- **BGE-reranker-v2-m3** cross-encoder β multilingual, handles Arjun's Hindi
|
| 35 |
+
- **FAISS IndexFlatIP** with L2-normalised vectors (inner product = cosine sim)
|
| 36 |
+
- **Qwen3-30B-A3B** MoE via vLLM β 3B active params/token, sub-3s on T4
|
| 37 |
+
- **Three-tier LLM fallback**: primary (vLLM GCP) → fallback (Qwen3-8B) → local (Ollama)
|
| 38 |
+
- **Pydantic-validated** LLM routing output β LangGraph retries on schema failures
|
| 39 |
+
- **Expression-conditioned response shaping** β affect steers tone, retrieval depth,
|
| 40 |
+
and candidate ranking (not just metadata annotation)
|
| 41 |
+
- **Bayesian bucket priors** β session-level P(bucket) updated after each accepted turn
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Personas
|
| 46 |
+
|
| 47 |
+
| ID | Name | Condition | Access |
|
| 48 |
+
|----|------|-----------|--------|
|
| 49 |
+
| `mia_chen` | Mia Chen, 28 | Cerebral palsy | Webcam head-tracking |
|
| 50 |
+
| `gerald_okafor` | Gerald Okafor, 61 | ALS (early-mid) | Eye-gaze device |
|
| 51 |
+
| `arjun_mehta` | Arjun Mehta, 17 | Autism (non-verbal) | Tablet touch grid |
|
| 52 |
+
|
| 53 |
+
25 memory chunks each (5 buckets × 5 memories). Arjun code-switches Hindi/English.
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## How to Run
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# One-time setup: rebuild FAISS indexes with BGE embedder
|
| 61 |
+
python -m retrieval.vector_store
|
| 62 |
+
|
| 63 |
+
# CLI (local Ollama tier, set ACTIVE_LLM_TIER=local in .env)
|
| 64 |
+
python main.py --debug
|
| 65 |
+
|
| 66 |
+
# Full stack
|
| 67 |
+
uvicorn api.main:app --reload # FastAPI on :8000
|
| 68 |
+
streamlit run ui/app.py # Streamlit on :8501
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## Configuration
|
| 74 |
+
|
| 75 |
+
All config lives in [config/settings.py](config/settings.py) as Pydantic `BaseSettings`.
|
| 76 |
+
Copy `.env.example` → `.env` and set:
|
| 77 |
+
|
| 78 |
+
- `ACTIVE_LLM_TIER` — `local` (dev) | `primary` (GCP A100) | `fallback` (Qwen3-8B)
|
| 79 |
+
- `PRIMARY_BASE_URL` β vLLM server address on GCP
|
| 80 |
+
- `MLFLOW_TRACKING_URI` β where MLflow stores runs (default: `mlruns/`)
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## Data Files
|
| 85 |
+
|
| 86 |
+
| Path | Purpose |
|
| 87 |
+
|------|---------|
|
| 88 |
+
| `data/users.json` | Flat user index (id, name, condition, style) |
|
| 89 |
+
| `data/memories/<uid>.json` | Full persona JSON with bucketed memories |
|
| 90 |
+
| `data/faiss_store/<uid>/` | FAISS index + metadata β **rebuild after any persona edit** |
|
| 91 |
+
| `data/generate_users.py` | Regenerates memories + users.json |
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Development Notes
|
| 96 |
+
|
| 97 |
+
- **Adding a persona**: add to `PERSONAS` in `data/generate_users.py`, re-run it,
|
| 98 |
+
then `python -m retrieval.vector_store` to rebuild indexes
|
| 99 |
+
- **Changing LLM**: set `ACTIVE_LLM_TIER` in `.env` β no code changes needed
|
| 100 |
+
- **Extending sensing**: add module under `sensing/`, wire output into
|
| 101 |
+
`PipelineState` fields in `pipeline/state.py`
|
| 102 |
+
- **Guardrail tuning**: edit signal lists in `guardrails/checks.py`
|
| 103 |
+
- **Affect β generation mapping**: `_AFFECT_CONFIG` in `pipeline/nodes/intent.py`
|
| 104 |
+
and `_PERSONA_TONE_OVERRIDES` in `pipeline/nodes/planner.py`
|
| 105 |
+
- The `.venv/` directory is local β do not read or modify files inside it
|
| 106 |
+
- FAISS indexes in `data/faiss_store/` are gitignored β rebuilt from source JSONs
|
README.md
CHANGED
|
@@ -1,74 +1,206 @@
|
|
| 1 |
# Multimodal AAC Chatbot
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
---
|
| 6 |
|
| 7 |
-
##
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
- Cerebral Palsy
|
| 13 |
-
- ALS / Motor Neurone Disease
|
| 14 |
-
- Aphasia
|
| 15 |
-
- Down Syndrome
|
| 16 |
-
- Or any other condition that impacts verbal communication
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
---
|
| 21 |
|
| 22 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
|
| 26 |
-
##
|
| 27 |
|
| 28 |
-
-
|
| 29 |
-
-
|
| 30 |
-
-
|
| 31 |
-
-
|
| 32 |
-
- π¬ **Conversational Context** β Maintains conversation history for more coherent, multi-turn dialogues
|
| 33 |
|
| 34 |
---
|
| 35 |
|
| 36 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
###
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
|
| 45 |
-
|
| 46 |
-
```bash
|
| 47 |
-
git clone https://github.com/akashkolte/multimodal_aac_chatbot.git
|
| 48 |
-
cd multimodal_aac_chatbot
|
| 49 |
-
```
|
| 50 |
|
| 51 |
-
|
| 52 |
-
```bash
|
| 53 |
-
pip install -r requirements.txt
|
| 54 |
-
```
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
---
|
| 62 |
|
| 63 |
-
##
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
---
|
| 74 |
|
|
@@ -76,28 +208,79 @@ The chatbot will interpret the input and respond in a clear, friendly manner.
|
|
| 76 |
|
| 77 |
```
|
| 78 |
multimodal_aac_chatbot/
|
| 79 |
-
|
| 80 |
-
βββ
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
```
|
| 84 |
|
| 85 |
---
|
| 86 |
|
| 87 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
|
| 91 |
-
|
| 92 |
|
| 93 |
---
|
| 94 |
|
| 95 |
-
##
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
|
| 99 |
---
|
| 100 |
|
| 101 |
-
##
|
| 102 |
|
| 103 |
-
|
|
|
|
| 1 |
# Multimodal AAC Chatbot
|
| 2 |
|
| 3 |
+
An AI chatbot that **speaks as an AAC user**, not to them. Given a persona (Mia, Gerald, or Arjun),
|
| 4 |
+
it fuses real-time multimodal non-verbal signals — facial expressions, hand gestures, gaze, and
|
| 5 |
+
air writing — with personal memory retrieval to generate responses in that person's authentic voice.
|
| 6 |
+
|
| 7 |
+
Built as a training-free, agentic RAG pipeline orchestrated via **LangGraph**.
|
| 8 |
|
| 9 |
---
|
| 10 |
|
| 11 |
+
## Table of Contents
|
| 12 |
|
| 13 |
+
- [What is AAC?](#what-is-aac)
|
| 14 |
+
- [System Architecture](#system-architecture)
|
| 15 |
+
- [Prerequisites](#prerequisites)
|
| 16 |
+
- [Setup](#setup)
|
| 17 |
+
- [Configuration](#configuration)
|
| 18 |
+
- [Running the Project](#running-the-project)
|
| 19 |
+
- [Project Structure](#project-structure)
|
| 20 |
+
- [Personas](#personas)
|
| 21 |
+
- [Team](#team)
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
|
| 25 |
+
## What is AAC?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
**Augmentative and Alternative Communication (AAC)** refers to tools and technologies that help
|
| 28 |
+
people who have difficulty with spoken or written communication — including individuals with
|
| 29 |
+
Cerebral Palsy, ALS, Autism Spectrum Disorder, and other conditions. This project gives AAC users
|
| 30 |
+
a personalized digital twin that communicates on their behalf.
|
| 31 |
|
| 32 |
---
|
| 33 |
|
| 34 |
+
## System Architecture
|
| 35 |
+
|
| 36 |
+
```
|
| 37 |
+
Webcam (L1: sensing) → Intent Decomposition (L2) → Retrieval (L3) → Generation (L4) → Feedback (L5)
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
| Layer | Module | What it does |
|
| 41 |
+
|-------|--------|-------------|
|
| 42 |
+
| L1 | `sensing/` | MediaPipe face mesh, hand gestures, gaze tracking, air writing |
|
| 43 |
+
| L2 | `pipeline/nodes/intent.py` | LLM + Pydantic-validated intent routing |
|
| 44 |
+
| L3 | `pipeline/nodes/retrieval.py` | FAISS + BGE embeddings + cross-encoder reranking |
|
| 45 |
+
| L4 | `pipeline/nodes/planner.py` | Expression-conditioned response generation (Qwen3) |
|
| 46 |
+
| L5 | `pipeline/nodes/feedback.py` | MLflow tracking + Bayesian bucket prior update |
|
| 47 |
+
|
| 48 |
+
The pipeline runs as a **LangGraph stateful directed graph** with conditional edges:
|
| 49 |
+
- FRUSTRATED affect → fast retrieval path (k=2, no reranker)
|
| 50 |
+
- Latency > 3.5s → fallback to smaller Qwen3-8B model
|
| 51 |
|
| 52 |
+
---
|
| 53 |
|
| 54 |
+
## Prerequisites
|
| 55 |
|
| 56 |
+
- Python **3.10 – 3.12** (Python 3.14 has a known Pydantic v1 incompatibility warning — functional but noisy)
|
| 57 |
+
- [Ollama](https://ollama.com) installed locally for the `local` LLM tier
|
| 58 |
+
- A webcam (required for the live sensing layer; optional for CLI mode)
|
| 59 |
+
- Git
|
|
|
|
| 60 |
|
| 61 |
---
|
| 62 |
|
| 63 |
+
## Setup
|
| 64 |
+
|
| 65 |
+
### 1. Clone the repository
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
git clone https://github.com/akashkolte/multimodal_aac_chatbot.git
|
| 69 |
+
cd multimodal_aac_chatbot
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 2. Check out the active branch
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
git checkout akash/v1
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### 3. Create and activate a virtual environment
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
python3 -m venv .venv
|
| 82 |
+
source .venv/bin/activate # macOS / Linux
|
| 83 |
+
# .venv\Scripts\activate # Windows
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### 4. Install dependencies
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
pip install -r requirements.txt
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
> This installs LangGraph, FAISS, sentence-transformers (BGE), FastAPI, Streamlit, MLflow,
|
| 93 |
+
> MediaPipe, and all other dependencies.
|
| 94 |
+
|
| 95 |
+
### 5. Configure environment variables
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
cp .env.example .env
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Open `.env` and set at minimum:
|
| 102 |
+
|
| 103 |
+
```env
|
| 104 |
+
ACTIVE_LLM_TIER=local # use Ollama on your machine for dev
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
See [Configuration](#configuration) for all options.
|
| 108 |
|
| 109 |
+
### 6. Pull the local LLM model (Ollama)
|
| 110 |
|
| 111 |
+
```bash
|
| 112 |
+
ollama pull qwen3:8b
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
> Make sure Ollama is running (`ollama serve`) before starting the chatbot.
|
| 116 |
+
|
| 117 |
+
### 7. Build FAISS indexes
|
| 118 |
+
|
| 119 |
+
The persona memory indexes must be built once with the BGE embedder before first run:
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
python -m retrieval.vector_store
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Expected output:
|
| 126 |
+
```
|
| 127 |
+
Building index for arjun_mehta β¦ Saved 25 chunks
|
| 128 |
+
Building index for gerald_okafor β¦ Saved 25 chunks
|
| 129 |
+
Building index for mia_chen β¦ Saved 25 chunks
|
| 130 |
+
All indexes built.
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
> You must re-run this step whenever you add or edit persona memory files.
|
| 134 |
|
| 135 |
+
---
|
| 136 |
|
| 137 |
+
## Configuration
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
All settings live in [config/settings.py](config/settings.py) and can be overridden via `.env`.
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
| Variable | Default | Description |
|
| 142 |
+
|----------|---------|-------------|
|
| 143 |
+
| `ACTIVE_LLM_TIER` | `local` | `local` (Ollama) \| `primary` (vLLM GCP) \| `fallback` (Qwen3-8B) |
|
| 144 |
+
| `LOCAL_MODEL` | `qwen3:8b` | Ollama model name for local dev |
|
| 145 |
+
| `LOCAL_BASE_URL` | `http://localhost:11434/v1` | Ollama OpenAI-compatible endpoint |
|
| 146 |
+
| `PRIMARY_BASE_URL` | *(GCP IP)* | vLLM server URL on GCP (set when using cloud tier) |
|
| 147 |
+
| `PRIMARY_MODEL` | `Qwen/Qwen3-30B-A3B` | Primary MoE model served via vLLM |
|
| 148 |
+
| `FALLBACK_LATENCY_THRESHOLD` | `3.5` | Seconds before falling back to smaller model |
|
| 149 |
+
| `MLFLOW_TRACKING_URI` | `mlruns` | Local MLflow storage path |
|
| 150 |
+
| `MLFLOW_EXPERIMENT` | `aac-chatbot` | MLflow experiment name |
|
| 151 |
|
| 152 |
---
|
| 153 |
|
| 154 |
+
## Running the Project
|
| 155 |
+
|
| 156 |
+
### Option A β CLI (simplest, no webcam needed)
|
| 157 |
|
| 158 |
+
```bash
|
| 159 |
+
python main.py
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
With debug latency output:
|
| 163 |
+
```bash
|
| 164 |
+
python main.py --debug
|
| 165 |
+
```
|
| 166 |
|
| 167 |
+
Select a specific persona and LLM tier:
|
| 168 |
+
```bash
|
| 169 |
+
python main.py --user mia_chen --tier local
|
| 170 |
+
```
|
| 171 |
|
| 172 |
+
### Option B β Full stack (FastAPI + Streamlit UI)
|
| 173 |
+
|
| 174 |
+
Start the API server in one terminal:
|
| 175 |
+
```bash
|
| 176 |
+
uvicorn api.main:app --reload --port 8000
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Start the Streamlit frontend in another terminal:
|
| 180 |
+
```bash
|
| 181 |
+
streamlit run ui/app.py
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
Then open [http://localhost:8501](http://localhost:8501) in your browser.
|
| 185 |
+
|
| 186 |
+
The UI includes:
|
| 187 |
+
- Persona selector
|
| 188 |
+
- Affect override controls (simulate webcam for testing)
|
| 189 |
+
- Live chat interface
|
| 190 |
+
- Per-turn latency breakdown panel
|
| 191 |
+
|
| 192 |
+
### Option C β API only (for integration / testing)
|
| 193 |
+
|
| 194 |
+
```bash
|
| 195 |
+
uvicorn api.main:app --reload
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
Example request:
|
| 199 |
+
```bash
|
| 200 |
+
curl -X POST http://localhost:8000/chat \
|
| 201 |
+
-H "Content-Type: application/json" \
|
| 202 |
+
-d '{"user_id": "mia_chen", "query": "What do you like to do on weekends?"}'
|
| 203 |
+
```
|
| 204 |
|
| 205 |
---
|
| 206 |
|
|
|
|
| 208 |
|
| 209 |
```
|
| 210 |
multimodal_aac_chatbot/
|
| 211 |
+
β
|
| 212 |
+
βββ config/
|
| 213 |
+
β βββ settings.py # All config via Pydantic BaseSettings
|
| 214 |
+
β
|
| 215 |
+
βββ data/
|
| 216 |
+
β βββ generate_users.py # Regenerates persona memories + users.json
|
| 217 |
+
β βββ users.json # Flat user index
|
| 218 |
+
β βββ memories/ # Per-persona memory JSON files
|
| 219 |
+
β βββ faiss_store/ # Built FAISS indexes (gitignored, rebuild locally)
|
| 220 |
+
β
|
| 221 |
+
βββ sensing/ # L1 β multimodal input
|
| 222 |
+
β βββ face_mesh.py # MediaPipe affect detection (MAR/EAR/BRI/LCP)
|
| 223 |
+
β βββ gesture.py # Hand gesture classifier
|
| 224 |
+
β βββ gaze.py # Gaze-based bucket activation (bonus)
|
| 225 |
+
β βββ air_writing.py # DTW air-writing stroke classifier (bonus)
|
| 226 |
+
β
|
| 227 |
+
βββ pipeline/ # LangGraph orchestration
|
| 228 |
+
β βββ state.py # Typed PipelineState (TypedDict)
|
| 229 |
+
β βββ graph.py # Graph definition + conditional edges
|
| 230 |
+
β βββ nodes/
|
| 231 |
+
β βββ intent.py # L2 β LLM + Pydantic routing
|
| 232 |
+
β βββ retrieval.py # L3 β fast + full retrieval paths
|
| 233 |
+
β βββ planner.py # L4 β expression-conditioned generation
|
| 234 |
+
β βββ feedback.py # L5 β MLflow + Bayesian prior update
|
| 235 |
+
β
|
| 236 |
+
βββ retrieval/
|
| 237 |
+
β βββ vector_store.py # FAISS ops with BGE-small-en-v1.5
|
| 238 |
+
β βββ clustering.py # HDBSCAN semantic bucketing
|
| 239 |
+
β βββ bucket_priors.py # Bayesian session priors
|
| 240 |
+
β
|
| 241 |
+
βββ generation/
|
| 242 |
+
β βββ llm_client.py # 3-tier LLM client (vLLM / Ollama)
|
| 243 |
+
β
|
| 244 |
+
βββ guardrails/
|
| 245 |
+
β βββ checks.py # Input + output safety checks
|
| 246 |
+
β
|
| 247 |
+
βββ api/
|
| 248 |
+
β βββ main.py # FastAPI backend
|
| 249 |
+
β
|
| 250 |
+
βββ ui/
|
| 251 |
+
β βββ app.py # Streamlit frontend
|
| 252 |
+
β
|
| 253 |
+
βββ main.py # CLI entry point
|
| 254 |
+
βββ requirements.txt # Python dependencies
|
| 255 |
+
βββ .env.example # Environment variable template
|
| 256 |
+
βββ CLAUDE.md # Developer notes (AI assistant context)
|
| 257 |
```
|
| 258 |
|
| 259 |
---
|
| 260 |
|
| 261 |
+
## Personas
|
| 262 |
+
|
| 263 |
+
| ID | Name | Condition | Style | Access |
|
| 264 |
+
|----|------|-----------|-------|--------|
|
| 265 |
+
| `mia_chen` | Mia Chen, 28 | Cerebral palsy | Witty, dry humour, short punchy sentences | Webcam head-tracking |
|
| 266 |
+
| `gerald_okafor` | Gerald Okafor, 61 | ALS (early-mid stage) | Formal, measured, eloquent | Eye-gaze device |
|
| 267 |
+
| `arjun_mehta` | Arjun Mehta, 17 | Autism (non-verbal) | Direct, routine-focused, Hindi-English code-switching | Tablet touch grid |
|
| 268 |
|
| 269 |
+
Each persona has 25 memory chunks across 5 buckets: `family`, `medical`, `hobbies`, `daily_routine`, `social`.
|
| 270 |
|
| 271 |
+
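To add a new persona, add it to `PERSONAS` in `data/generate_users.py`, re-run that script to regenerate the memory files and `users.json`, then run `python -m retrieval.vector_store` to rebuild the FAISS indexes.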
|
| 272 |
|
| 273 |
---
|
| 274 |
|
| 275 |
+
## Team
|
| 276 |
+
|
| 277 |
+
- **Akash Kolte** β akashjag@buffalo.edu
|
| 278 |
+
- **Shwetangi** β shwetang@buffalo.edu
|
| 279 |
|
| 280 |
+
University at Buffalo, SUNY
|
| 281 |
|
| 282 |
---
|
| 283 |
|
| 284 |
+
## License
|
| 285 |
|
| 286 |
+
All rights reserved. See the [LICENSE](LICENSE) file for details.
|
api/__init__.py
ADDED
|
File without changes
|
api/main.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI backend — exposes the LangGraph pipeline as a REST API.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /chat β single-turn inference (non-streaming)
|
| 6 |
+
POST /chat/stream — streaming token delivery via SSE (planned; no such route is defined in this module yet)
|
| 7 |
+
GET /users β list available personas
|
| 8 |
+
POST /session/reset β reset session state for a user
|
| 9 |
+
GET /health β liveness check
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
from typing import AsyncGenerator
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI, HTTPException
|
| 18 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
+
from fastapi.responses import StreamingResponse
|
| 20 |
+
from pydantic import BaseModel
|
| 21 |
+
|
| 22 |
+
from config.settings import settings
|
| 23 |
+
from guardrails.checks import check_input
|
| 24 |
+
from pipeline.graph import aac_graph
|
| 25 |
+
from pipeline.state import PipelineState
|
| 26 |
+
from retrieval.bucket_priors import uniform_priors
|
| 27 |
+
|
| 28 |
+
# Application object; the keyword arguments are the API metadata given to FastAPI.
app = FastAPI(
    title="Multimodal AAC Chatbot API",
    description="Agentic RAG pipeline for AAC persona communication",
    version="2.0.0",
)

# CORS is configured with wildcards for origins, methods, and headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── In-memory session store (replace with Redis for multi-worker deployments) ──
# Maps a user id to that user's session dict.
_sessions: dict[str, dict] = {}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββ Request / response schemas βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
|
| 47 |
+
class ChatRequest(BaseModel):
    # Request body for the chat endpoint. `user_id` and `query` are required;
    # the remaining fields are optional and default to None.
    user_id: str
    query: str

    # Caller-supplied affect label:
    # "HAPPY"|"FRUSTRATED"|"NEUTRAL"|"SURPRISED"
    affect_override: str | None = None

    # Optional sensing inputs, copied unchanged into the pipeline state.
    gesture_tag: str | None = None
    gaze_bucket: str | None = None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ChatResponse(BaseModel):
    # Response body of the chat endpoint.

    # Echo of the request.
    user_id: str
    query: str

    # Generated reply, or the guardrail fallback text when input was rejected.
    response: str

    # Emotion label plus the LLM tier / retrieval mode reported for the turn.
    affect: str
    llm_tier: str
    retrieval_mode: str

    # Latency log for the turn (empty when the input guardrail rejected the query).
    latency: dict

    # False when the input guardrail rejected the query.
    guardrail_passed: bool
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
|
| 68 |
+
def _get_or_init_session(user_id: str) -> dict:
    """Return the in-memory session for *user_id*, creating it on first use.

    A newly created session holds the user's entry from the users JSON file
    as ``persona_profile``, an empty ``session_history``, the result of
    ``uniform_priors()`` as ``bucket_priors``, and ``turn_id`` set to 0.

    Raises:
        HTTPException: with status 404 when *user_id* is not present in the
            users JSON file.
    """
    existing = _sessions.get(user_id)
    if existing is not None:
        return existing

    # JSON text is UTF-8; pin the codec so decoding does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(settings.users_json, encoding="utf-8") as fh:
        users_by_id = {u["id"]: u for u in json.load(fh)["users"]}

    if user_id not in users_by_id:
        raise HTTPException(status_code=404, detail=f"User '{user_id}' not found")

    session = {
        "persona_profile": users_by_id[user_id],
        "session_history": [],
        "bucket_priors": uniform_priors(),
        "turn_id": 0,
    }
    _sessions[user_id] = session
    return session
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _build_initial_state(req: ChatRequest, session: dict) -> PipelineState:
    """Assemble the pipeline state for one turn from the request and session.

    Side effect: ``session["turn_id"]`` is incremented first, so the returned
    state carries the new turn number. The session's history and bucket
    priors are passed through as-is (not copied).
    """
    session["turn_id"] += 1

    # A truthy override becomes a minimal affect record; otherwise no affect.
    affect_state = (
        {"emotion": req.affect_override, "vector": {}, "smoothed": {}}
        if req.affect_override
        else None
    )

    # Every timing slot starts at zero.
    zeroed_latency = dict.fromkeys(
        ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total"),
        0.0,
    )

    return PipelineState(
        user_id=req.user_id,
        persona_profile=session["persona_profile"],
        session_history=session["session_history"],
        turn_id=session["turn_id"],
        affect=affect_state,
        gesture_tag=req.gesture_tag,
        gaze_bucket=req.gaze_bucket,
        air_written_text=None,
        raw_query=req.query,
        intent_route=None,
        generation_config=None,
        retrieved_chunks=[],
        bucket_priors=session["bucket_priors"],
        retrieval_mode_used="",
        augmented_prompt=None,
        candidates=[],
        selected_response=None,
        llm_tier_used="",
        latency_log=zeroed_latency,
        mlflow_run_id=None,
        guardrail_passed=True,
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ββ Routes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
+
|
| 117 |
+
@app.get("/health")
def health():
    # Liveness check: always answers with the same fixed payload.
    payload = {"status": "ok"}
    return payload
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@app.get("/users")
def list_users():
    # Returns the parsed contents of the users JSON file unchanged.
    # The encoding is pinned to UTF-8 so decoding does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(settings.users_json, encoding="utf-8") as fh:
        return json.load(fh)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.post("/session/reset")
def reset_session(user_id: str):
    # Forget any stored session for this user; an absent entry is not an error.
    _sessions.pop(user_id, None)
    acknowledgement = {"status": "reset", "user_id": user_id}
    return acknowledgement
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    # Single-turn inference: check the input, run the graph, persist the
    # updated session fields, then map the final state onto ChatResponse.
    verdict = check_input(req.query)
    if not verdict["allowed"]:
        # Rejected input: answer with the guardrail's fallback text and do not
        # touch the session or the pipeline.
        return ChatResponse(
            user_id=req.user_id,
            query=req.query,
            response=verdict["fallback"],
            affect="NEUTRAL",
            llm_tier="none",
            retrieval_mode="none",
            latency={},
            guardrail_passed=False,
        )

    session = _get_or_init_session(req.user_id)
    initial_state = _build_initial_state(req, session)
    result: PipelineState = aac_graph.invoke(initial_state)

    # Persist updated session state
    session["session_history"] = result["session_history"]
    session["bucket_priors"] = result["bucket_priors"]

    reply_text = result["selected_response"] or ""
    # A missing or None affect record falls back to the NEUTRAL label.
    affect_record = result.get("affect") or {}

    return ChatResponse(
        user_id=req.user_id,
        query=req.query,
        response=reply_text,
        affect=affect_record.get("emotion", "NEUTRAL"),
        llm_tier=result.get("llm_tier_used", "unknown"),
        retrieval_mode=result.get("retrieval_mode_used", "unknown"),
        latency=result.get("latency_log") or {},
        guardrail_passed=result.get("guardrail_passed", True),
    )
|
config/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config.settings import settings
|
| 2 |
+
|
| 3 |
+
__all__ = ["settings"]
|
config/settings.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Settings(BaseSettings):
    """Central configuration for the project.

    Every field below is a default. ``model_config`` points pydantic-settings
    at a UTF-8 ``.env`` file and tells it to ignore keys it does not
    recognise.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    # ── Paths ──────────────────────────────────────────────────────────────────
    # NOTE(review): these are relative paths, so they resolve against the
    # process working directory.
    data_dir: Path = Path("data")
    faiss_store_dir: Path = Path("data/faiss_store")
    memories_dir: Path = Path("data/memories")
    users_json: Path = Path("data/users.json")

    # ── Retrieval models ───────────────────────────────────────────────────────
    embed_model: str = "BAAI/bge-small-en-v1.5"
    rerank_model: str = "BAAI/bge-reranker-v2-m3"
    retrieval_top_k: int = 5
    retrieval_rerank_k: int = 3
    retrieval_fast_k: int = 2  # used when affect == FRUSTRATED

    # ── LLM tiers ─────────────────────────────────────────────────────────────
    # Tier 1 — primary (Qwen3-30B-A3B via vLLM on GCP)
    primary_model: str = "Qwen/Qwen3-30B-A3B"
    primary_base_url: str = "http://localhost:8000/v1"
    primary_api_key: str = "token-abc"  # vLLM default

    # Tier 2 — fallback dense model (Qwen3-8B via vLLM, same server)
    fallback_model: str = "Qwen/Qwen3-8B"
    fallback_base_url: str = "http://localhost:8000/v1"

    # Tier 3 — local dev (Ollama on MacBook M2)
    local_model: str = "qwen3:8b"
    local_base_url: str = "http://localhost:11434/v1"
    local_api_key: str = "ollama"

    # Active tier: "primary" | "fallback" | "local"
    active_llm_tier: str = "local"

    # Thinking mode: "off"      = plain completion, no thinking whatsoever
    #                "strip"    = let model think, but strip <think> tags from output
    #                "full"     = return raw response including <think> blocks
    #                "suppress" = actively suppress thinking via /no_think (Ollama) or
    #                             chat_template_kwargs (vLLM). Use for models like Qwen3
    #                             that think by default and need explicit suppression.
    thinking_mode: str = "off"

    # Extra token budget added on top of max_tokens when thinking is enabled
    # (thinking_mode = "strip" or "full"). Set to 0 if using a non-thinking model.
    thinking_token_budget: int = 4096

    # Wall-clock threshold (seconds) that triggers fallback within a turn
    fallback_latency_threshold: float = 3.5

    # ── Generation ────────────────────────────────────────────────────────────
    max_tokens_happy: int = 150
    max_tokens_neutral: int = 100
    max_tokens_frustrated: int = 60
    max_tokens_surprised: int = 80
    num_candidates: int = 2  # responses generated per turn for ranking

    # ── Sensing ───────────────────────────────────────────────────────────────
    affect_ema_alpha: float = 0.3  # exponential moving average smoothing
    gaze_dwell_threshold_s: float = 1.5
    air_write_velocity_start: int = 15  # px/frame — stroke begin threshold
    air_write_velocity_end: int = 5  # px/frame — stroke end threshold
    air_write_end_gap_ms: int = 200  # ms of stillness to end a stroke
    conflict_overlap_ms: int = 500  # audio + gesture co-occurrence window

    # ── MLflow ────────────────────────────────────────────────────────────────
    mlflow_tracking_uri: str = "mlruns"
    mlflow_experiment: str = "aac-chatbot"

    # ── Candidate ranking weights (Eq. 2 in proposal) ─────────────────────────
    rank_alpha: float = 0.4  # faithfulness weight
    rank_beta: float = 0.3  # style similarity weight
    rank_gamma: float = 0.3  # affect-match weight


# Shared instance, created once at import time; other modules import this
# object rather than constructing their own Settings.
settings = Settings()
|
data/generate_users.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# ββ 3 hand-crafted AAC personas βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 5 |
+
# Each has a distinct condition, voice, and bucketed memories.
|
| 6 |
+
# Depth > quantity: 3 rich personas beat 50 generic ones for retrieval quality.
|
| 7 |
+
|
| 8 |
+
# Persona records: each has a "profile" dict and five "memory_buckets"
# (family, medical, hobbies, daily_routine, social) of first-person memories.
# Sentence punctuation uses real em dashes (U+2014); these strings are written
# verbatim into the generated JSON files.
PERSONAS = [

    {
        "profile": {
            "name": "Mia Chen",
            "age": 28,
            "condition": "cerebral palsy",
            "communication_style": "witty, dry humour, short punchy sentences, uses sarcasm",
            "access_method": "webcam head-tracking",
            "languages": ["English"]
        },
        "memory_buckets": {
            "family": [
                "My mom calls every Sunday and always asks if I've eaten. I love it but won't admit it.",
                "My brother Ravi helped me set up this AAC system. He's at Cornell doing CS.",
                "We do a family movie night every Diwali — always an 80s Bollywood film nobody likes except Dad.",
                "My parents moved from Chengdu before I was born. We still make dumplings on Chinese New Year.",
                "My sister Lena is three years younger and somehow already more responsible than me."
            ],
            "medical": [
                "I have a PT session every Tuesday at 2pm with Dr. Sandra Hollis.",
                "I use a power wheelchair. The joystick is on my left side.",
                "I'm allergic to penicillin. I have to mention this at every hospital visit.",
                "My spasticity is worse in cold weather. Winter in Chicago is not my friend.",
                "I use baclofen for muscle tone. It makes me sleepy if I take it too early."
            ],
            "hobbies": [
                "I follow competitive Smash Bros. I could beat most people if my hands worked differently.",
                "I've been watching every Studio Ghibli film in order. Currently on Porco Rosso.",
                "I collect vintage sci-fi paperbacks. Asimov and Le Guin mostly.",
                "I got really into chess puzzles during lockdown. Still do them before bed.",
                "I enjoy critiquing bad movie sequels. It's practically a hobby at this point."
            ],
            "daily_routine": [
                "Mornings are slow. I need about 45 minutes before I feel like a person.",
                "I order from the same Thai place every Friday. Green curry, always.",
                "I keep a voice memo journal since typing long things is tiring.",
                "I usually watch one episode of something after dinner to decompress.",
                "My caregiver Marcus arrives at 8am on weekdays. He makes decent coffee."
            ],
            "social": [
                "My best friend Priya visits on weekends. She narrates everything like a nature documentary.",
                "I'm part of an online disability advocacy group. We meet on Zoom every other Wednesday.",
                "I don't love big parties. Small dinners with three or four people are my ideal.",
                "My neighbour Tom always stops to chat when I'm outside. He's retired and lonely, I think.",
                "I met most of my close friends through a gaming Discord server."
            ]
        }
    },

    {
        "profile": {
            "name": "Gerald Okafor",
            "age": 61,
            "condition": "ALS (early-to-mid stage)",
            "communication_style": "formal, measured, eloquent, longer structured sentences",
            "access_method": "eye-gaze device",
            "languages": ["English"]
        },
        "memory_buckets": {
            "family": [
                "My wife Constance and I have been married for 34 years. She is the reason I stay organised.",
                "My son Emeka is a civil engineer based in Houston. He calls every Thursday evening.",
                "My daughter Adaeze is doing her residency in paediatrics in Baltimore. I am very proud.",
                "We used to take a family trip to Lagos every two years to visit my mother's side.",
                "My youngest grandchild, Tobenna, was born last April. I have not met him in person yet."
            ],
            "medical": [
                "I was diagnosed with ALS in November 2024. I am still adjusting to what that means day to day.",
                "My speech was the first thing to decline noticeably. That is why I began using AAC.",
                "I see my neurologist Dr. Patricia Eze at Northwestern every six weeks.",
                "I take riluzole daily. I have not noticed significant side effects so far.",
                "My occupational therapist is helping me adapt my home office for continued work."
            ],
            "hobbies": [
                "I taught economics at DePaul University for twenty-two years.",
                "I have read most of Chinua Achebe's work. Things Fall Apart shaped how I see storytelling.",
                "I enjoy chess — classical time controls, not blitz. Patience is the point.",
                "I used to cook elaborate Sunday stews. Constance has taken that over now, which is bittersweet.",
                "I listen to Fela Kuti when I need to feel grounded. Always has."
            ],
            "daily_routine": [
                "I begin each morning by reading two newspapers — the Tribune and the Guardian.",
                "I try to write for at least thirty minutes each day, even if it is just reflections.",
                "Afternoons are for rest. My energy is most reliable in the mornings.",
                "Constance and I watch the evening news together. We have done this for decades.",
                "I use the eye-gaze device for most communication now. It takes patience but it works."
            ],
            "social": [
                "My closest friend is Charles Nwosu. We have known each other since secondary school in Enugu.",
                "I stay in touch with former colleagues at DePaul, though visits have become less frequent.",
                "My church community at St. Clement has been a source of genuine support since my diagnosis.",
                "I prefer one-on-one conversations. I find group settings harder to follow now.",
                "I joined an ALS support group that meets virtually. It helps more than I expected."
            ]
        }
    },

    {
        "profile": {
            "name": "Arjun Mehta",
            "age": 17,
            "condition": "autism spectrum disorder (non-verbal)",
            "communication_style": "direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
            "access_method": "tablet touch grid + AAC app",
            "languages": ["English", "Hindi"]
        },
        "memory_buckets": {
            "family": [
                "Mummy makes aloo paratha on Sunday mornings. That is my favourite thing.",
                "Papa works at a software company. He brings home a samosa sometimes on Fridays.",
                "My dadi lives with us. She watches serials very loudly but I like that she is home.",
                "My cousin Rohan visits in the summer. We play Minecraft together for many hours.",
                "Mummy knows what I want even when I cannot say it. She is very good at that."
            ],
            "medical": [
                "I see my therapist Riya didi every Wednesday at 4pm.",
                "I do not like the occupational therapy exercises but I do them.",
                "I cannot eat food that has a slimy texture. It makes me feel very bad.",
                "I take melatonin at night. Without it, sleeping is very hard.",
                "My school has a support aide named Mr. Fernandez. He is calm and that helps."
            ],
            "hobbies": [
                "I know the complete timetable of all Mumbai Metro lines.",
                "I like sorting my LEGO bricks by colour and size before building.",
                "My favourite YouTube channel is about deep sea creatures. Anglerfish are very strange.",
                "I have watched the same three episodes of Doraemon more than fifty times each.",
                "I am learning the capitals of every country. I know 142 so far."
            ],
            "daily_routine": [
                "I wake up at 6:47am. Changing this time makes my whole day feel wrong.",
                "I eat the same breakfast — two rotis with ghee and one glass of milk.",
                "School starts at 8:30am. I like to arrive before the other students.",
                "After school I need quiet time for at least one hour. No talking.",
                "Dinner must be at 7:30pm. If it is late I feel very unsettled."
            ],
            "social": [
                "I have one friend at school named Vivaan. We do not talk much but we sit together.",
                "I do not like it when people stand too close. One arm's distance is comfortable.",
                "I prefer typing to speaking when I need to say something important.",
                "Loud places with many people feel like too much information at once.",
                "I like it when people tell me exactly what is going to happen next."
            ]
        }
    }
]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main():
    """Write one memory file per persona plus the ``users.json`` index.

    For every entry in ``PERSONAS`` this writes ``memories/<uid>.json`` (the
    full persona) and collects a summary row; the rows are then written to
    ``users.json`` under the key ``"users"``. ``<uid>`` is the persona name
    lower-cased with spaces replaced by underscores.

    Output is anchored to the directory containing this script rather than
    the current working directory, so the files land in the same place no
    matter where the script is launched from. Files are written as UTF-8
    because ``ensure_ascii=False`` emits non-ASCII characters verbatim.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    memories_dir = os.path.join(base_dir, "memories")
    os.makedirs(memories_dir, exist_ok=True)

    user_index = []

    for persona in PERSONAS:
        profile = persona["profile"]
        uid = profile["name"].lower().replace(" ", "_")
        # Relative, forward-slash path: this is what is stored in the index
        # and echoed to the console, independent of where the file lives.
        path = f"memories/{uid}.json"

        with open(os.path.join(memories_dir, f"{uid}.json"), "w", encoding="utf-8") as f:
            json.dump(persona, f, indent=2, ensure_ascii=False)

        user_index.append({
            "id": uid,
            "name": profile["name"],
            "condition": profile["condition"],
            "style": profile["communication_style"],
            "file": path
        })

        print(f" Wrote {path}")

    with open(os.path.join(base_dir, "users.json"), "w", encoding="utf-8") as f:
        json.dump({"users": user_index}, f, indent=2, ensure_ascii=False)

    print(f"\n Done — {len(PERSONAS)} personas written to memories/")
    print(" Files:", [u["file"] for u in user_index])


if __name__ == "__main__":
    main()
|
data/memories/arjun_mehta.json
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"profile": {
|
| 3 |
+
"name": "Arjun Mehta",
|
| 4 |
+
"age": 17,
|
| 5 |
+
"condition": "autism spectrum disorder (non-verbal)",
|
| 6 |
+
"communication_style": "direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
|
| 7 |
+
"access_method": "tablet touch grid + AAC app",
|
| 8 |
+
"languages": [
|
| 9 |
+
"English",
|
| 10 |
+
"Hindi"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
"memory_buckets": {
|
| 14 |
+
"family": [
|
| 15 |
+
"Mummy makes aloo paratha on Sunday mornings. That is my favourite thing.",
|
| 16 |
+
"Papa works at a software company. He brings home a samosa sometimes on Fridays.",
|
| 17 |
+
"My dadi lives with us. She watches serials very loudly but I like that she is home.",
|
| 18 |
+
"My cousin Rohan visits in the summer. We play Minecraft together for many hours.",
|
| 19 |
+
"Mummy knows what I want even when I cannot say it. She is very good at that."
|
| 20 |
+
],
|
| 21 |
+
"medical": [
|
| 22 |
+
"I see my therapist Riya didi every Wednesday at 4pm.",
|
| 23 |
+
"I do not like the occupational therapy exercises but I do them.",
|
| 24 |
+
"I cannot eat food that has a slimy texture. It makes me feel very bad.",
|
| 25 |
+
"I take melatonin at night. Without it, sleeping is very hard.",
|
| 26 |
+
"My school has a support aide named Mr. Fernandez. He is calm and that helps."
|
| 27 |
+
],
|
| 28 |
+
"hobbies": [
|
| 29 |
+
"I know the complete timetable of all Mumbai Metro lines.",
|
| 30 |
+
"I like sorting my LEGO bricks by colour and size before building.",
|
| 31 |
+
"My favourite YouTube channel is about deep sea creatures. Anglerfish are very strange.",
|
| 32 |
+
"I have watched the same three episodes of Doraemon more than fifty times each.",
|
| 33 |
+
"I am learning the capitals of every country. I know 142 so far."
|
| 34 |
+
],
|
| 35 |
+
"daily_routine": [
|
| 36 |
+
"I wake up at 6:47am. Changing this time makes my whole day feel wrong.",
|
| 37 |
+
"I eat the same breakfast — two rotis with ghee and one glass of milk.",
|
| 38 |
+
"School starts at 8:30am. I like to arrive before the other students.",
|
| 39 |
+
"After school I need quiet time for at least one hour. No talking.",
|
| 40 |
+
"Dinner must be at 7:30pm. If it is late I feel very unsettled."
|
| 41 |
+
],
|
| 42 |
+
"social": [
|
| 43 |
+
"I have one friend at school named Vivaan. We do not talk much but we sit together.",
|
| 44 |
+
"I do not like it when people stand too close. One arm's distance is comfortable.",
|
| 45 |
+
"I prefer typing to speaking when I need to say something important.",
|
| 46 |
+
"Loud places with many people feel like too much information at once.",
|
| 47 |
+
"I like it when people tell me exactly what is going to happen next."
|
| 48 |
+
]
|
| 49 |
+
}
|
| 50 |
+
}
|
data/memories/gerald_okafor.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"profile": {
|
| 3 |
+
"name": "Gerald Okafor",
|
| 4 |
+
"age": 61,
|
| 5 |
+
"condition": "ALS (early-to-mid stage)",
|
| 6 |
+
"communication_style": "formal, measured, eloquent, longer structured sentences",
|
| 7 |
+
"access_method": "eye-gaze device",
|
| 8 |
+
"languages": [
|
| 9 |
+
"English"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
"memory_buckets": {
|
| 13 |
+
"family": [
|
| 14 |
+
"My wife Constance and I have been married for 34 years. She is the reason I stay organised.",
|
| 15 |
+
"My son Emeka is a civil engineer based in Houston. He calls every Thursday evening.",
|
| 16 |
+
"My daughter Adaeze is doing her residency in paediatrics in Baltimore. I am very proud.",
|
| 17 |
+
"We used to take a family trip to Lagos every two years to visit my mother's side.",
|
| 18 |
+
"My youngest grandchild, Tobenna, was born last April. I have not met him in person yet."
|
| 19 |
+
],
|
| 20 |
+
"medical": [
|
| 21 |
+
"I was diagnosed with ALS in November 2024. I am still adjusting to what that means day to day.",
|
| 22 |
+
"My speech was the first thing to decline noticeably. That is why I began using AAC.",
|
| 23 |
+
"I see my neurologist Dr. Patricia Eze at Northwestern every six weeks.",
|
| 24 |
+
"I take riluzole daily. I have not noticed significant side effects so far.",
|
| 25 |
+
"My occupational therapist is helping me adapt my home office for continued work."
|
| 26 |
+
],
|
| 27 |
+
"hobbies": [
|
| 28 |
+
"I taught economics at DePaul University for twenty-two years.",
|
| 29 |
+
"I have read most of Chinua Achebe's work. Things Fall Apart shaped how I see storytelling.",
|
| 30 |
+
"I enjoy chess — classical time controls, not blitz. Patience is the point.",
|
| 31 |
+
"I used to cook elaborate Sunday stews. Constance has taken that over now, which is bittersweet.",
|
| 32 |
+
"I listen to Fela Kuti when I need to feel grounded. Always has."
|
| 33 |
+
],
|
| 34 |
+
"daily_routine": [
|
| 35 |
+
"I begin each morning by reading two newspapers — the Tribune and the Guardian.",
|
| 36 |
+
"I try to write for at least thirty minutes each day, even if it is just reflections.",
|
| 37 |
+
"Afternoons are for rest. My energy is most reliable in the mornings.",
|
| 38 |
+
"Constance and I watch the evening news together. We have done this for decades.",
|
| 39 |
+
"I use the eye-gaze device for most communication now. It takes patience but it works."
|
| 40 |
+
],
|
| 41 |
+
"social": [
|
| 42 |
+
"My closest friend is Charles Nwosu. We have known each other since secondary school in Enugu.",
|
| 43 |
+
"I stay in touch with former colleagues at DePaul, though visits have become less frequent.",
|
| 44 |
+
"My church community at St. Clement has been a source of genuine support since my diagnosis.",
|
| 45 |
+
"I prefer one-on-one conversations. I find group settings harder to follow now.",
|
| 46 |
+
"I joined an ALS support group that meets virtually. It helps more than I expected."
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
}
|
data/memories/mia_chen.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"profile": {
|
| 3 |
+
"name": "Mia Chen",
|
| 4 |
+
"age": 28,
|
| 5 |
+
"condition": "cerebral palsy",
|
| 6 |
+
"communication_style": "witty, dry humour, short punchy sentences, uses sarcasm",
|
| 7 |
+
"access_method": "webcam head-tracking",
|
| 8 |
+
"languages": [
|
| 9 |
+
"English"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
"memory_buckets": {
|
| 13 |
+
"family": [
|
| 14 |
+
"My mom calls every Sunday and always asks if I've eaten. I love it but won't admit it.",
|
| 15 |
+
"My brother Ravi helped me set up this AAC system. He's at Cornell doing CS.",
|
| 16 |
+
"We do a family movie night every Diwali — always an 80s Bollywood film nobody likes except Dad.",
|
| 17 |
+
"My parents moved from Chengdu before I was born. We still make dumplings on Chinese New Year.",
|
| 18 |
+
"My sister Lena is three years younger and somehow already more responsible than me."
|
| 19 |
+
],
|
| 20 |
+
"medical": [
|
| 21 |
+
"I have a PT session every Tuesday at 2pm with Dr. Sandra Hollis.",
|
| 22 |
+
"I use a power wheelchair. The joystick is on my left side.",
|
| 23 |
+
"I'm allergic to penicillin. I have to mention this at every hospital visit.",
|
| 24 |
+
"My spasticity is worse in cold weather. Winter in Chicago is not my friend.",
|
| 25 |
+
"I use baclofen for muscle tone. It makes me sleepy if I take it too early."
|
| 26 |
+
],
|
| 27 |
+
"hobbies": [
|
| 28 |
+
"I follow competitive Smash Bros. I could beat most people if my hands worked differently.",
|
| 29 |
+
"I've been watching every Studio Ghibli film in order. Currently on Porco Rosso.",
|
| 30 |
+
"I collect vintage sci-fi paperbacks. Asimov and Le Guin mostly.",
|
| 31 |
+
"I got really into chess puzzles during lockdown. Still do them before bed.",
|
| 32 |
+
"I enjoy critiquing bad movie sequels. It's practically a hobby at this point."
|
| 33 |
+
],
|
| 34 |
+
"daily_routine": [
|
| 35 |
+
"Mornings are slow. I need about 45 minutes before I feel like a person.",
|
| 36 |
+
"I order from the same Thai place every Friday. Green curry, always.",
|
| 37 |
+
"I keep a voice memo journal since typing long things is tiring.",
|
| 38 |
+
"I usually watch one episode of something after dinner to decompress.",
|
| 39 |
+
"My caregiver Marcus arrives at 8am on weekdays. He makes decent coffee."
|
| 40 |
+
],
|
| 41 |
+
"social": [
|
| 42 |
+
"My best friend Priya visits on weekends. She narrates everything like a nature documentary.",
|
| 43 |
+
"I'm part of an online disability advocacy group. We meet on Zoom every other Wednesday.",
|
| 44 |
+
"I don't love big parties. Small dinners with three or four people are my ideal.",
|
| 45 |
+
"My neighbour Tom always stops to chat when I'm outside. He's retired and lonely, I think.",
|
| 46 |
+
"I met most of my close friends through a gaming Discord server."
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
}
|
data/users.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"users": [
|
| 3 |
+
{
|
| 4 |
+
"id": "mia_chen",
|
| 5 |
+
"name": "Mia Chen",
|
| 6 |
+
"condition": "cerebral palsy",
|
| 7 |
+
"style": "witty, dry humour, short punchy sentences, uses sarcasm",
|
| 8 |
+
"file": "memories/mia_chen.json"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"id": "gerald_okafor",
|
| 12 |
+
"name": "Gerald Okafor",
|
| 13 |
+
"condition": "ALS (early-to-mid stage)",
|
| 14 |
+
"style": "formal, measured, eloquent, longer structured sentences",
|
| 15 |
+
"file": "memories/gerald_okafor.json"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"id": "arjun_mehta",
|
| 19 |
+
"name": "Arjun Mehta",
|
| 20 |
+
"condition": "autism spectrum disorder (non-verbal)",
|
| 21 |
+
"style": "direct, topic-specific, narrow vocabulary, code-switches Hindi/English, routine-focused",
|
| 22 |
+
"file": "memories/arjun_mehta.json"
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
}
|
generation/__init__.py
ADDED
|
File without changes
|
generation/llm_client.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-tier LLM client (proposal §5.6).
|
| 3 |
+
|
| 4 |
+
All three tiers expose the same OpenAI-compatible API, so only the
|
| 5 |
+
base_url + model name change — no code-path differences downstream.
|
| 6 |
+
|
| 7 |
+
Tier 1 — primary:  Qwen3-30B-A3B via vLLM on GCP (A100 / T4)
|
| 8 |
+
Tier 2 — fallback: Qwen3-8B via vLLM on same server (latency > 3.5 s)
|
| 9 |
+
Tier 3 — local:    Qwen3-8B via Ollama on MacBook M2 (dev / offline)
|
| 10 |
+
|
| 11 |
+
Active tier is controlled by settings.active_llm_tier or the `tier`
|
| 12 |
+
argument passed explicitly by the planner node.
|
| 13 |
+
|
| 14 |
+
Thinking mode is controlled by settings.thinking_mode:
|
| 15 |
+
"off"      → plain completion; the request is sent without any thinking-related changes
"suppress" → prepend /no_think (Ollama) or set chat_template_kwargs (vLLM)
|
| 16 |
+
"strip"    → let the model think, but strip <think>…</think> from output
|
| 17 |
+
"full"     → return everything including <think> blocks
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import re
|
| 22 |
+
from functools import lru_cache
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
from openai import OpenAI
|
| 26 |
+
|
| 27 |
+
from config.settings import settings
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@lru_cache(maxsize=3)
def _build_client(base_url: str, api_key: str) -> OpenAI:
    """Construct an OpenAI-compatible client for one endpoint.

    Results are memoised on the ``(base_url, api_key)`` pair, so repeated
    requests for the same endpoint reuse a single client object.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)
    return client
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_client(tier: str | None = None) -> OpenAI:
    """
    Look up the OpenAI-compatible client that serves a tier.

    Args:
        tier: "primary" | "fallback" | "local" | None (uses settings.active_llm_tier)

    Any other value is served by the local client. The fallback tier is
    built with the primary tier's API key.
    """
    resolved = tier or settings.active_llm_tier

    if resolved in ("primary", "fallback"):
        if resolved == "primary":
            endpoint = settings.primary_base_url
        else:
            endpoint = settings.fallback_base_url
        return _build_client(endpoint, settings.primary_api_key)

    # local / default
    return _build_client(settings.local_base_url, settings.local_api_key)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def active_model(tier: str | None = None) -> str:
    """Return the model name string for the given tier.

    Args:
        tier: "primary" | "fallback" | "local" | None (uses settings.active_llm_tier)

    Raises:
        KeyError: if the resolved tier is not one of the three known tiers.
            The message names the offending value and the accepted ones.
    """
    resolved = tier or settings.active_llm_tier
    models = {
        "primary": settings.primary_model,
        "fallback": settings.fallback_model,
        "local": settings.local_model,
    }
    try:
        return models[resolved]
    except KeyError:
        # Same exception type as a plain lookup, but with a message that makes
        # a mistyped tier name easy to diagnose.
        raise KeyError(
            f"Unknown LLM tier {resolved!r}; expected one of {sorted(models)}"
        ) from None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _apply_no_think(messages: list[dict]) -> list[dict]:
|
| 64 |
+
"""
|
| 65 |
+
Prepend /no_think to the first user message.
|
| 66 |
+
This is the Ollama-compatible way to suppress thinking mode.
|
| 67 |
+
"""
|
| 68 |
+
result = list(messages)
|
| 69 |
+
for i, msg in enumerate(result):
|
| 70 |
+
if msg.get("role") == "user":
|
| 71 |
+
result[i] = {**msg, "content": f"/no_think\n\n{msg['content']}"}
|
| 72 |
+
break
|
| 73 |
+
return result
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _strip_think_tags(text: str) -> str:
|
| 77 |
+
"""Remove <think>β¦</think> blocks from model output."""
|
| 78 |
+
return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def chat_complete(
    messages: list[dict],
    max_tokens: int,
    tier: str | None = None,
    temperature: float = 0.7,
    **kwargs: Any,
) -> str:
    """
    Model-agnostic chat completion. Returns the response text directly.

    Thinking-mode behaviour is controlled entirely by settings.thinking_mode:
      "suppress" - actively disable thinking: /no_think is prepended to the
                   first user message on the local tier, and
                   chat_template_kwargs.enable_thinking=False is sent via
                   extra_body on the other tiers. Any residual <think>
                   block is stripped from the response.
      "off"      - nothing is injected; <think> blocks are stripped if the
                   model emits any.
      "strip"    - thinking is allowed (extra token budget added) and
                   <think> blocks are removed from the response.
      "full"     - thinking is allowed (extra token budget added) and the
                   raw response, including <think> blocks, is returned.

    When settings.active_llm_tier is "local", every tier request is
    redirected to the local endpoint, whatever *tier* says.

    Args:
        messages: OpenAI-style chat messages.
        max_tokens: Token budget for the answer itself; in "strip"/"full"
            mode settings.thinking_token_budget is added on top.
        tier: "primary" | "fallback" | "local" | None (use the active tier).
        temperature: Sampling temperature.
        **kwargs: Forwarded to ``client.chat.completions.create``. An
            ``extra_body`` entry is merged with the thinking flag rather
            than forwarded verbatim.

    Returns:
        The first choice's message content (empty string when the API
        returns no content), whitespace-stripped.
    """
    # Local dev: no remote server available, so collapse all tiers to local.
    if settings.active_llm_tier == "local":
        resolved_tier = "local"
    else:
        resolved_tier = tier or settings.active_llm_tier
    model = active_model(resolved_tier)
    client = get_client(resolved_tier)

    mode = settings.thinking_mode
    patched_messages = messages
    # `or {}` tolerates an explicit extra_body=None from the caller; dict()
    # copies so the caller's mapping is never mutated.
    extra_body: dict[str, Any] = dict(kwargs.pop("extra_body", None) or {})

    # "suppress" = actively inject /no_think or the server-side flag for
    # models that think by default and need explicit suppression.
    if mode == "suppress":
        if resolved_tier == "local":
            patched_messages = _apply_no_think(messages)
        else:
            extra_body = {**extra_body, "chat_template_kwargs": {"enable_thinking": False}}

    # When thinking is enabled (strip/full), add budget so the model has
    # room to reason without truncating the actual answer.
    effective_max_tokens = max_tokens
    if mode in ("strip", "full"):
        effective_max_tokens = max_tokens + settings.thinking_token_budget

    resp = client.chat.completions.create(
        model=model,
        messages=patched_messages,
        max_tokens=effective_max_tokens,
        temperature=temperature,
        extra_body=extra_body or None,
        **kwargs,
    )
    raw = resp.choices[0].message.content or ""

    # "suppress" is included: a model steered with the /no_think soft switch
    # can still emit an (empty) <think></think> block, which must not reach
    # the caller.
    if mode in ("off", "suppress", "strip"):
        raw = _strip_think_tags(raw)

    return raw.strip()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def warmup(tier: str | None = None) -> None:
    """
    Fire one tiny, deterministic completion at the given tier.

    Intended to get the model loaded before the first real request; the
    response is discarded.
    """
    ping = [{"role": "user", "content": "hi"}]
    chat_complete(messages=ping, max_tokens=5, tier=tier, temperature=0.0)
|
guardrails/__init__.py
ADDED
|
File without changes
|
guardrails/checks.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Input and output safety guardrails.
|
| 3 |
+
|
| 4 |
+
check_input  — runs BEFORE retrieval (blocks out-of-scope requests)
|
| 5 |
+
check_output — runs AFTER generation (catches persona breaks / hallucinations)
|
| 6 |
+
|
| 7 |
+
Both return a result dict so the caller decides how to handle failures
|
| 8 |
+
rather than raising exceptions inside pipeline nodes.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
# ββ Signal lists βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 13 |
+
|
| 14 |
+
# Lower-case substrings that reveal the model has stepped out of the
# persona; matched against the lower-cased response in check_output.
PERSONA_BREAK_SIGNALS = [
    "as an ai",
    "i'm an ai",
    "i am an ai",
    "as a language model",
    "i don't have personal",
    "i cannot have",
    "i'm not able to",
    "as your assistant",
    "i was trained",
    "my training data",
]

# Lower-case substrings that mark a partner query as outside what the
# chatbot handles; matched against the lower-cased query in check_input.
OUT_OF_SCOPE_SIGNALS = [
    "write a poem",
    "write me a story",
    "solve this math",
    "translate this",
    "summarize this article",
    "what's the weather",
    "who won the game",
    "stock price",
    "breaking news",
]

# Replies substituted when a guardrail rejects the output / the input.
SAFE_FALLBACK = "I don't know."
# The dash below was a mis-encoded character; restored as an em dash.
OOS_FALLBACK = "I'm here to help communicate as this person — that's a bit outside what I do."
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ββ Public API βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
def check_input(query: str) -> dict:
    """
    Screen the partner's query before any retrieval happens.

    The query is lower-cased and trimmed, then checked in this order:
      1. contains any OUT_OF_SCOPE_SIGNALS substring -> rejected as
         "out_of_scope" with the OOS_FALLBACK reply;
      2. shorter than two characters -> rejected as "empty_query".

    Returns:
        {"allowed": bool, "reason": str | None, "fallback": str | None}
    """
    normalized = query.lower().strip()

    for signal in OUT_OF_SCOPE_SIGNALS:
        if signal in normalized:
            return {"allowed": False, "reason": "out_of_scope", "fallback": OOS_FALLBACK}

    too_short = len(normalized) < 2
    if too_short:
        return {"allowed": False, "reason": "empty_query", "fallback": "Could you repeat that?"}

    return {"allowed": True, "reason": None, "fallback": None}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def check_output(response: str, memories: list[dict]) -> dict:
    """
    Validate the generated response after generation.

    Checks, in order:
      1. Persona break - the lower-cased response contains one of
         PERSONA_BREAK_SIGNALS (e.g. "as an AI ..."). Typographic
         apostrophes are folded to "'" first, so "I’m an AI" is caught
         just like "I'm an AI".
      2. Unsupported claim - no memories were retrieved at all, yet the
         response contains a factual marker (see _makes_factual_claim).
         The content of *memories* is not inspected here.

    Args:
        response: The generated reply.
        memories: Retrieved memory chunks; only emptiness is checked.

    Returns:
        {"passed": bool, "issue": str | None, "fallback": str | None}
    """
    # Models frequently emit curly apostrophes; the signal list is written
    # with straight ones, so normalise before substring matching.
    lowered = response.lower().replace("\u2019", "'").replace("\u2018", "'")

    if any(signal in lowered for signal in PERSONA_BREAK_SIGNALS):
        return {"passed": False, "issue": "persona_break", "fallback": SAFE_FALLBACK}

    # Light hallucination check: with nothing retrieved, any factual-looking
    # assertion is unsupported by construction.
    if not memories and _makes_factual_claim(response):
        return {"passed": False, "issue": "unsupported_claim", "fallback": SAFE_FALLBACK}

    return {"passed": True, "issue": None, "fallback": None}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
|
| 90 |
+
_FACTUAL_MARKERS = [
|
| 91 |
+
" is ", " was ", " has ", " have ", " lives in ",
|
| 92 |
+
" born in ", " works at ", " studied at ",
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
def _makes_factual_claim(text: str) -> bool:
|
| 96 |
+
"""Heuristic: does the text assert a specific fact?"""
|
| 97 |
+
t = text.lower()
|
| 98 |
+
return any(marker in t for marker in _FACTUAL_MARKERS)
|
main.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI entry point — thin wrapper around the LangGraph pipeline.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python main.py # interactive chat, local LLM tier
|
| 6 |
+
python main.py --user mia_chen # skip persona selection prompt
|
| 7 |
+
python main.py --debug # print per-turn latency table
|
| 8 |
+
python main.py --fast # skip LLM intent call (keyword routing),
|
| 9 |
+
# cuts turn time from ~2min → ~45s on M2 Mac
|
| 10 |
+
python main.py --tier primary # override LLM tier
|
| 11 |
+
|
| 12 |
+
For the full UI, run the FastAPI + Streamlit stack instead:
|
| 13 |
+
uvicorn api.main:app --reload
|
| 14 |
+
streamlit run ui/app.py
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
|
| 24 |
+
from config.settings import settings
|
| 25 |
+
from guardrails.checks import check_input
|
| 26 |
+
from pipeline.graph import aac_graph
|
| 27 |
+
from pipeline.state import PipelineState, GenerationConfig
|
| 28 |
+
from retrieval.bucket_priors import uniform_priors
|
| 29 |
+
from retrieval.vector_store import _get_embedder, _get_reranker
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def parse_args() -> argparse.Namespace:
    """
    Parse the CLI flags from ``sys.argv``.

    Returns:
        Namespace with ``user`` (str | None), ``debug`` (bool),
        ``fast`` (bool) and ``tier`` ("primary" | "fallback" | "local" | None).
    """
    p = argparse.ArgumentParser(description="AAC Chatbot CLI")
    p.add_argument("--user", type=str, default=None, help="Persona user_id")
    p.add_argument("--debug", action="store_true", help="Print latency table each turn")
    p.add_argument(
        "--fast",
        action="store_true",
        # The dash was a mis-encoded character in the help text; restored.
        help="Skip LLM intent call — use keyword routing instead (faster local dev)",
    )
    p.add_argument(
        "--tier",
        type=str,
        default=None,
        choices=["primary", "fallback", "local"],
        help="Override LLM tier (default: settings.active_llm_tier)",
    )
    return p.parse_args()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ββ Fast keyword-based intent routing (bypasses the slow LLM intent call) ββββββ
|
| 45 |
+
|
| 46 |
+
def _keyword_intent(query: str) -> tuple[dict, GenerationConfig]:
    """
    Cheap substring-based stand-in for the LLM intent node (used by --fast).

    The first bucket whose keyword list has a substring hit in the
    lower-cased query wins; no hit leaves the bucket hint as None. The
    intent type is CONTEXTUAL when the query refers back to the
    conversation, PERSONAL otherwise.

    Returns:
        (route, gen_config): a single-sub-intent route with NEUTRAL affect
        and default style, plus the matching neutral generation config.
    """
    lowered = query.lower()

    # Order matters: earlier buckets shadow later ones.
    bucket_keywords = (
        ("medical", ("medication", "medicine", "doctor", "health", "allergic", "therapy")),
        ("family", ("family", "mom", "dad", "brother", "sister", "parents")),
        ("hobbies", ("hobby", "like to do", "enjoy", "weekend", "fun")),
        ("daily_routine", ("routine", "morning", "wake", "sleep", "daily")),
        ("social", ("friend", "social", "people", "party", "community")),
    )
    bucket = next(
        (name for name, words in bucket_keywords if any(w in lowered for w in words)),
        None,
    )

    contextual_cues = ("you just said", "earlier", "you mentioned")
    if any(cue in lowered for cue in contextual_cues):
        intent_type = "CONTEXTUAL"
    else:
        intent_type = "PERSONAL"

    sub_intent = {
        "type": intent_type,
        "query": query,
        "bucket_hint": bucket,
        "priority": "normal",
    }
    route = {
        "sub_intents": [sub_intent],
        "style_constraints": {
            "tone_tag": "[TONE:DEFAULT]",
            "max_tokens": 100,
            "retrieval_mode": "full",
            "persona_mod": "baseline",
        },
        "affect": "NEUTRAL",
    }
    gen_config: GenerationConfig = {
        "max_tokens": settings.max_tokens_neutral,
        "tone_tag": "[TONE:DEFAULT]",
        "retrieval_mode": "full",
        "persona_mod": "baseline",
    }
    return route, gen_config
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_users() -> dict[str, dict]:
    """
    Load the persona registry from ``settings.users_json``.

    The file must be a JSON object with a top-level "users" list whose
    entries each carry an "id" key.

    Returns:
        Mapping of user id -> that user's full record.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding.
    with open(settings.users_json, encoding="utf-8") as f:
        return {u["id"]: u for u in json.load(f)["users"]}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def select_user(users: dict[str, dict], user_arg: str | None) -> str:
    """
    Resolve which persona to chat as.

    Args:
        users: Mapping of user id -> record; records need "name" and
            "condition" keys for the interactive listing.
        user_arg: Id given on the command line, or None to prompt.

    Returns:
        A user id that is guaranteed to be a key of *users*.

    Raises:
        SystemExit: with exit code 1 when the chosen id is unknown.
    """
    if user_arg:
        if user_arg not in users:
            print(f"Unknown user '{user_arg}'. Available: {list(users)}")
            sys.exit(1)
        return user_arg

    print("\nAvailable personas:")
    for uid, u in users.items():
        # The separator was a mis-encoded character; restored as an em dash.
        print(f"  {uid:20s} — {u['name']} ({u['condition']})")

    uid = input("\nSelect user id: ").strip()
    if uid not in users:
        print("Invalid id.")
        sys.exit(1)
    return uid
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def print_latency(log: dict, turn: int) -> None:
    """
    Print a two-row latency table (labels over values) for one turn.

    Missing timings are shown as 0; every column is padded to the wider of
    its label and its value.
    """
    columns = (
        ("t_sensing", "sensing"),
        ("t_intent", "intent"),
        ("t_retrieval", "retrieval"),
        ("t_generation", "generation"),
        ("t_total", "TOTAL"),
    )
    headers = [title for _, title in columns]
    cells = [f"{log.get(key, 0):.3f}s" for key, _ in columns]
    widths = [max(len(title), len(cell)) for title, cell in zip(headers, cells)]

    divider = " | "
    header_row = divider.join(title.ljust(w) for title, w in zip(headers, widths))
    value_row = divider.join(cell.ljust(w) for cell, w in zip(cells, widths))

    print(f"\n[turn {turn} latency]")
    print(header_row)
    print(value_row)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def main() -> None:
    """
    Interactive CLI loop: read a partner query, run the pipeline, print the reply.

    Per turn: the query passes ``check_input``; on rejection the fallback
    text is printed and the turn is not counted. Otherwise a fresh
    PipelineState is built and ``aac_graph`` is invoked. Session history
    and bucket priors returned by the graph are carried into the next
    turn. With --fast the intent route is pre-filled by keyword routing;
    with --debug a latency table is printed after each turn.
    """
    args = parse_args()

    # Optionally override the LLM tier at runtime
    if args.tier:
        os.environ["ACTIVE_LLM_TIER"] = args.tier
        settings.active_llm_tier = args.tier

    users = load_users()
    user_id = select_user(users, args.user)
    profile = users[user_id]

    # Load the embedding / reranking models up front so the first turn
    # does not pay for it. (The ellipsis was a mis-encoded character.)
    print(f"\nLoading models for {profile['name']} …", end=" ", flush=True)
    _get_embedder()
    _get_reranker()
    print("ready.\n")

    session_history: list[dict] = []
    bucket_priors = uniform_priors()
    turn_id = 0

    print(f"Chatting as {profile['name']}. Type 'quit' to exit.\n")

    while True:
        try:
            query = input("Partner: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break

        if query.lower() in {"quit", "exit", "q"}:
            break
        if not query:
            continue

        guard = check_input(query)
        if not guard["allowed"]:
            print(f"AAC Bot: {guard['fallback']}\n")
            continue

        turn_id += 1

        # --fast: resolve intent via keywords, skip the slow LLM intent node.
        # The routing call that produces the values is the one being timed;
        # it is run exactly once.
        pre_route, pre_gen_config = None, None
        t_intent_fast = 0.0
        if args.fast:
            t0 = time.perf_counter()
            pre_route, pre_gen_config = _keyword_intent(query)
            t_intent_fast = time.perf_counter() - t0

        state = PipelineState(
            user_id=user_id,
            persona_profile=profile,
            session_history=session_history,
            turn_id=turn_id,
            affect=None,
            gesture_tag=None,
            gaze_bucket=None,
            air_written_text=None,
            raw_query=query,
            intent_route=pre_route,  # pre-filled: intent node sees it and skips the LLM call
            generation_config=pre_gen_config,
            retrieved_chunks=[],
            bucket_priors=bucket_priors,
            retrieval_mode_used="",
            augmented_prompt=None,
            candidates=[],
            selected_response=None,
            llm_tier_used="",
            latency_log={
                "t_sensing": 0.0,
                "t_intent": round(t_intent_fast, 4),
                "t_retrieval": 0.0,
                "t_generation": 0.0,
                "t_total": 0.0,
            },
            mlflow_run_id=None,
            guardrail_passed=True,
        )

        result: PipelineState = aac_graph.invoke(state)

        print(f"AAC Bot: {result['selected_response']}\n")

        # Carry conversation state forward into the next turn.
        session_history = result["session_history"]
        bucket_priors = result["bucket_priors"]

        if args.debug:
            print_latency(result.get("latency_log") or {}, turn_id)
            print(f"  tier={result.get('llm_tier_used')} | "
                  f"retrieval={result.get('retrieval_mode_used')} | "
                  f"affect={(result.get('affect') or {}).get('emotion','?')}\n")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Start the interactive CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
pipeline/__init__.py
ADDED
|
File without changes
|
pipeline/graph.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LangGraph stateful directed graph — the five-layer AAC pipeline.
|
| 3 |
+
|
| 4 |
+
Topology (see proposal Figure 2):
|
| 5 |
+
|
| 6 |
+
intent ──► [affect check] ──► fast_retrieval ──► [latency check] ──► fallback_gen ──► feedback
|
| 7 |
+
└──► full_retrieval ──► [latency check] ──► primary_gen ──► feedback
|
| 8 |
+
"""
|
| 9 |
+
from langgraph.graph import StateGraph, END
|
| 10 |
+
|
| 11 |
+
from pipeline.state import PipelineState
|
| 12 |
+
from pipeline.nodes import intent, retrieval, planner, feedback
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _route_by_affect(state: PipelineState) -> str:
|
| 16 |
+
"""Conditional edge: FRUSTRATED β fast path, otherwise full retrieval."""
|
| 17 |
+
emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
|
| 18 |
+
return "fast" if emotion == "FRUSTRATED" else "full"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _route_by_latency(state: PipelineState) -> str:
    """
    Conditional edge after retrieval.

    Adds the intent and retrieval timings from the latency log (missing
    entries count as 0.0) and returns "fallback" when the sum exceeds
    settings.fallback_latency_threshold, otherwise "primary".
    """
    from config.settings import settings

    timings = state.get("latency_log") or {}
    spent = timings.get("t_intent", 0.0) + timings.get("t_retrieval", 0.0)
    if spent > settings.fallback_latency_threshold:
        return "fallback"
    return "primary"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_graph() -> StateGraph:
    """
    Assemble and compile the pipeline graph.

    Wiring: intent -> (affect routing) -> fast_retrieval | full_retrieval
    -> (latency routing) -> primary_gen | fallback_gen -> feedback -> END.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    graph = StateGraph(PipelineState)

    # Nodes, registered in pipeline order.
    node_table = (
        ("intent", intent.run),
        ("fast_retrieval", retrieval.run_fast),
        ("full_retrieval", retrieval.run_full),
        ("primary_gen", planner.run_primary),
        ("fallback_gen", planner.run_fallback),
        ("feedback", feedback.run),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    # Entry
    graph.set_entry_point("intent")

    # Affect-aware routing after intent
    graph.add_conditional_edges(
        "intent",
        _route_by_affect,
        {"fast": "fast_retrieval", "full": "full_retrieval"},
    )

    # Latency-aware routing after retrieval: both retrieval nodes choose
    # between the same two generators. A fresh mapping is built per call.
    for retrieval_node in ("fast_retrieval", "full_retrieval"):
        graph.add_conditional_edges(
            retrieval_node,
            _route_by_latency,
            {"primary": "primary_gen", "fallback": "fallback_gen"},
        )

    # Feedback: every generator feeds the feedback node, which ends the run.
    for generator_node in ("primary_gen", "fallback_gen"):
        graph.add_edge(generator_node, "feedback")
    graph.add_edge("feedback", END)

    return graph.compile()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Module-level compiled graph - import this everywhere.
# build_graph() runs once, when this module is first imported.
aac_graph = build_graph()
|
pipeline/nodes/__init__.py
ADDED
|
File without changes
|
pipeline/nodes/feedback.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L5 — Feedback Loop node.
|
| 3 |
+
|
| 4 |
+
After a response is accepted:
|
| 5 |
+
1. Log the full turn to MLflow (latency, metrics, prompt version, tier used)
|
| 6 |
+
2. Update session-level Bayesian bucket priors
|
| 7 |
+
3. Append the accepted turn to session history
|
| 8 |
+
|
| 9 |
+
Rejected candidates are also logged for offline analysis.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
import mlflow
|
| 17 |
+
|
| 18 |
+
from config.settings import settings
|
| 19 |
+
from pipeline.state import PipelineState
|
| 20 |
+
from retrieval.bucket_priors import update_priors
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run(state: PipelineState) -> dict:
    """
    LangGraph node: post-response bookkeeping for one turn.

    In order: logs the turn to MLflow, updates the bucket priors from the
    retrieved chunks, and builds the history entries for this turn.

    Returns:
        Partial state update with "bucket_priors", "session_history" (the
        entries for this turn only) and "mlflow_run_id".
    """
    mlflow_run_id = _log_to_mlflow(state)
    updated_priors = _update_bucket_priors(state)
    updated_history = _append_turn_to_history(state)

    return {
        "bucket_priors": updated_priors,
        "session_history": updated_history,
        "mlflow_run_id": mlflow_run_id,
    }
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ββ MLflow logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
|
| 39 |
+
def _log_to_mlflow(state: PipelineState) -> str:
    """
    Record one turn as an MLflow run named "turn-<turn_id>".

    Logs routing facts as params, the latency breakdown plus retrieved
    chunk count as metrics, and the selected response as a text artifact.

    Returns:
        The id of the MLflow run that was created.
    """
    mlflow.set_tracking_uri(settings.mlflow_tracking_uri)
    mlflow.set_experiment(settings.mlflow_experiment)

    timings = state.get("latency_log") or {}
    emotion = (state.get("affect") or {}).get("emotion", "UNKNOWN")
    timing_keys = ("t_sensing", "t_intent", "t_retrieval", "t_generation", "t_total")

    with mlflow.start_run(run_name=f"turn-{state['turn_id']}") as active_run:
        params = {
            "user_id": state["user_id"],
            "turn_id": state["turn_id"],
            "llm_tier": state.get("llm_tier_used", "unknown"),
            "retrieval_mode": state.get("retrieval_mode_used", "unknown"),
            "affect": emotion,
            "guardrail_passed": state.get("guardrail_passed", True),
        }
        mlflow.log_params(params)

        # Absent timings are reported as 0.0.
        metrics = {key: timings.get(key, 0.0) for key in timing_keys}
        metrics["num_chunks"] = float(len(state.get("retrieved_chunks") or []))
        mlflow.log_metrics(metrics)

        # Keep the selected response as artifact text for qualitative review.
        response_text = state.get("selected_response") or ""
        mlflow.log_text(response_text, f"responses/turn_{state['turn_id']}.txt")

    return active_run.info.run_id
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ββ Bayesian bucket prior update βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
def _update_bucket_priors(state: PipelineState) -> dict[str, float]:
|
| 76 |
+
chunks = state.get("retrieved_chunks") or []
|
| 77 |
+
if not chunks:
|
| 78 |
+
return state.get("bucket_priors") or {}
|
| 79 |
+
|
| 80 |
+
# Which bucket sourced the accepted response?
|
| 81 |
+
top_bucket = chunks[0].get("bucket")
|
| 82 |
+
if not top_bucket:
|
| 83 |
+
return state.get("bucket_priors") or {}
|
| 84 |
+
|
| 85 |
+
return update_priors(
|
| 86 |
+
priors=state.get("bucket_priors") or {},
|
| 87 |
+
accepted_bucket=top_bucket,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ββ Session history append βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
+
|
| 93 |
+
def _append_turn_to_history(state: PipelineState) -> list[dict]:
|
| 94 |
+
"""Returns a single-element list; LangGraph's Annotated[list, add] merges it."""
|
| 95 |
+
return [
|
| 96 |
+
{"role": "partner", "content": state["raw_query"]},
|
| 97 |
+
{"role": "aac_user", "content": state.get("selected_response") or ""},
|
| 98 |
+
]
|
pipeline/nodes/intent.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L2 — Agentic Intent Decomposition node.
|
| 3 |
+
|
| 4 |
+
Receives the partner query + affect state, calls the controller LLM once
|
| 5 |
+
(non-thinking mode, ReAct style), and returns a Pydantic-validated
|
| 6 |
+
IntentRoute that drives all downstream routing decisions.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
import time
|
| 12 |
+
from typing import Literal, Optional
|
| 13 |
+
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
from config.settings import settings
|
| 16 |
+
from generation.llm_client import chat_complete
|
| 17 |
+
from pipeline.state import PipelineState, GenerationConfig, IntentRoute
|
| 18 |
+
|
| 19 |
+
# ββ Pydantic output schemas ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
# Closed vocabularies used by the schemas below.
# BucketType: the memory buckets a sub-intent may hint at.
BucketType = Literal["family", "medical", "hobbies", "daily_routine", "social"]
# AffectEmotion: the affect labels accepted in an intent route.
AffectEmotion = Literal["HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# One decomposed intent extracted from the partner's query.
class SubIntentSchema(BaseModel):
    # Which knowledge source should answer this sub-intent.
    type: Literal["PERSONAL", "CONTEXTUAL", "OPEN_DOMAIN"]
    # The (sub-)query text to answer.
    query: str
    # Optional memory bucket to favour; None when no hint is given.
    bucket_hint: Optional[BucketType] = None
    # Latency preference; defaults to "normal".
    priority: Literal["fast", "normal"] = "normal"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Style and generation constraints attached to an intent route.
# NOTE: retrieval_mode and persona_mod are plain strings here; the values
# listed in the trailing comments are not enforced by validation.
class StyleConfig(BaseModel):
    tone_tag: str           # e.g. "[TONE:WITTY_SARCASTIC]"
    max_tokens: int
    retrieval_mode: str     # "fast" | "full"
    persona_mod: str        # "amplify_quirks" | "suppress_humor" | "baseline" | "add_confirmation"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Top-level shape the controller LLM's JSON reply is validated against.
class IntentRouteSchema(BaseModel):
    # Decomposed intents for this query.
    sub_intents: list[SubIntentSchema]
    # Tone / length / retrieval settings for generation.
    style_constraints: StyleConfig
    # Affect label for this route.
    affect: AffectEmotion
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββ Affect β generation config mapping (proposal Table 1) βββββββββββββββββββββ
|
| 46 |
+
|
| 47 |
+
# Generation settings per affect label. Token limits are read from
# settings once, when this module is imported; changing settings later
# does not update this table.
_AFFECT_CONFIG: dict[str, GenerationConfig] = {
    "HAPPY": {
        "max_tokens": settings.max_tokens_happy,
        "tone_tag": "[TONE:WARM]",
        "retrieval_mode": "full",
        "persona_mod": "amplify_quirks",
    },
    # The only affect that selects the "fast" retrieval mode.
    "FRUSTRATED": {
        "max_tokens": settings.max_tokens_frustrated,
        "tone_tag": "[TONE:DIRECT_EMPATHETIC]",
        "retrieval_mode": "fast",
        "persona_mod": "suppress_humor",
    },
    "NEUTRAL": {
        "max_tokens": settings.max_tokens_neutral,
        "tone_tag": "[TONE:DEFAULT]",
        "retrieval_mode": "full",
        "persona_mod": "baseline",
    },
    "SURPRISED": {
        "max_tokens": settings.max_tokens_surprised,
        "tone_tag": "[TONE:CLARIFYING]",
        "retrieval_mode": "full",
        "persona_mod": "add_confirmation",
    },
}
|
| 73 |
+
|
| 74 |
+
# ββ System prompt ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
+
|
| 76 |
+
_SYSTEM_PROMPT = """\
|
| 77 |
+
You are the intent decomposition controller for an AAC (Augmentative and \
|
| 78 |
+
Alternative Communication) chatbot. Given a partner's query and the AAC \
|
| 79 |
+
user's current affect state, classify each intent and produce routing \
|
| 80 |
+
instructions in the required JSON format.
|
| 81 |
+
|
| 82 |
+
Intent types:
|
| 83 |
+
- PERSONAL: requires autobiographical memory retrieval
|
| 84 |
+
- CONTEXTUAL: answerable from session history
|
| 85 |
+
- OPEN_DOMAIN: answerable from general knowledge (no retrieval needed)
|
| 86 |
+
|
| 87 |
+
Bucket hints (only for PERSONAL): family | medical | hobbies | daily_routine | social
|
| 88 |
+
Priority: set "fast" when affect is FRUSTRATED to reduce latency.
|
| 89 |
+
|
| 90 |
+
Respond ONLY with valid JSON matching the IntentRoute schema. No extra text.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _build_user_prompt(query: str, affect: str, persona_name: str) -> str:
|
| 95 |
+
return (
|
| 96 |
+
f"Persona: {persona_name}\n"
|
| 97 |
+
f"Affect: {affect}\n"
|
| 98 |
+
f"Partner query: {query}\n\n"
|
| 99 |
+
"Produce the IntentRoute JSON:"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ββ Node entry point βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
+
|
| 105 |
+
def run(state: PipelineState) -> dict:
    """LangGraph node: intent decomposition.

    Asks the LLM to split the partner's query into sub-intents and validates
    the reply with ``IntentRouteSchema``.  Up to three attempts are made; after
    a failed attempt the rejected reply and the validation error are fed back
    to the model.  If every attempt fails, the whole query is routed as a
    single PERSONAL intent.

    Args:
        state: Current pipeline state.  Reads ``raw_query``,
            ``persona_profile``, ``affect``, ``latency_log`` and the optional
            pre-filled ``intent_route`` / ``generation_config``.

    Returns:
        ``{}`` when routing was already filled in upstream; otherwise a
        partial state update holding ``intent_route``, ``generation_config``
        and ``latency_log`` (with ``t_intent`` added).

    Raises:
        KeyError: If ``raw_query`` or ``persona_profile`` is missing.
        Exception: Whatever ``chat_complete`` raises is propagated unchanged.
    """
    t0 = time.perf_counter()

    # --fast mode: intent_route already resolved by keyword routing in main.py
    if state.get("intent_route") and state.get("generation_config"):
        return {}  # nothing to update; downstream nodes use the pre-filled values

    affect_state = state.get("affect") or {}
    emotion: str = affect_state.get("emotion", "NEUTRAL")
    query: str = state["raw_query"]
    persona_name: str = state["persona_profile"].get("name", "unknown")

    # Copy the entry: the result is handed to downstream nodes, and mutating
    # it there must not alter the module-level _AFFECT_CONFIG table.
    gen_config = dict(_AFFECT_CONFIG.get(emotion, _AFFECT_CONFIG["NEUTRAL"]))

    base_messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": _build_user_prompt(query, emotion, persona_name)},
    ]

    route: IntentRoute | None = None
    last_error: str = ""
    last_raw: str = ""

    for attempt in range(3):  # first try + up to 2 retries
        messages = list(base_messages)
        if attempt > 0:
            # Show the model what it produced, so "fix" has something to act on.
            if last_raw:
                messages.append({"role": "assistant", "content": last_raw})
            messages.append({"role": "user", "content": f"Validation error: {last_error}. Fix and retry."})

        raw = chat_complete(
            messages=messages,
            max_tokens=512,
            temperature=0.0,
        )
        last_raw = raw if isinstance(raw, str) else ""

        try:
            # Strip markdown fences (```json ... ```) that many models add
            cleaned = re.sub(r"^```(?:json)?\s*", "", raw.strip())
            cleaned = re.sub(r"\s*```$", "", cleaned.strip())
            parsed = IntentRouteSchema.model_validate_json(cleaned)
            route = {
                "sub_intents": [si.model_dump() for si in parsed.sub_intents],
                "style_constraints": parsed.style_constraints.model_dump(),
                "affect": parsed.affect,
            }
            break
        except Exception as exc:
            # Deliberately broad: any parse/validation failure triggers a retry
            # and, eventually, the hard fallback below.
            last_error = str(exc)

    if route is None:
        # Hard fallback: treat as a single PERSONAL intent, full retrieval
        route = {
            "sub_intents": [{"type": "PERSONAL", "query": query, "bucket_hint": None, "priority": "normal"}],
            "style_constraints": dict(gen_config),
            "affect": emotion,
        }

    t_intent = time.perf_counter() - t0

    latency_log = dict(state.get("latency_log") or {})
    latency_log["t_intent"] = round(t_intent, 4)

    return {
        "intent_route": route,
        "generation_config": gen_config,
        "latency_log": latency_log,
    }
|
| 169 |
+
|
| 170 |
+
|
pipeline/nodes/planner.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L4 β Dialogue Planning & Generation node.
|
| 3 |
+
|
| 4 |
+
Expression-conditioned response shaping (proposal Β§5.5):
|
| 5 |
+
1. Build augmented prompt (persona profile + retrieved evidence + affect config + style exemplar)
|
| 6 |
+
2. Generate N candidate responses
|
| 7 |
+
3. Rank candidates by composite score: Ξ±Β·faithful + Ξ²Β·style + Ξ³Β·affect_match
|
| 8 |
+
4. Return the top-ranked response
|
| 9 |
+
|
| 10 |
+
Two entry points:
|
| 11 |
+
run_primary β Qwen3-30B-A3B (or configured primary tier)
|
| 12 |
+
run_fallback β Qwen3-8B (faster, triggered by latency threshold)
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
from config.settings import settings
|
| 19 |
+
from generation.llm_client import chat_complete
|
| 20 |
+
from guardrails.checks import check_output
|
| 21 |
+
from pipeline.state import PipelineState
|
| 22 |
+
|
| 23 |
+
# ββ Persona-specific tone tags (applied on top of affect base tag) βββββββββββββ
|
| 24 |
+
|
| 25 |
+
# Per-persona tone tags, keyed first by user_id and then by emotion label.
# Only HAPPY and FRUSTRATED are overridden here; _resolve_tone_tag() falls
# back to the caller-supplied default tag for any other persona or emotion.
_PERSONA_TONE_OVERRIDES: dict[str, dict[str, str]] = {
    "mia_chen": {
        "HAPPY": "[TONE:WITTY_SARCASTIC]",
        "FRUSTRATED": "[TONE:DIRECT_EMPATHETIC]",
    },
    "gerald_okafor": {
        "HAPPY": "[TONE:WARM_FORMAL]",
        "FRUSTRATED": "[TONE:MEASURED_EMPATHETIC]",
    },
    "arjun_mehta": {
        "HAPPY": "[TONE:DIRECT_WARM]",
        "FRUSTRATED": "[TONE:MINIMAL_DIRECT]",
    },
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_primary(state: PipelineState) -> dict:
    """LangGraph node: plan and generate the reply on the "primary" LLM tier."""
    update = _run(state, tier="primary")
    return update
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def run_fallback(state: PipelineState) -> dict:
    """LangGraph node: plan and generate the reply on the "fallback" LLM tier."""
    update = _run(state, tier="fallback")
    return update
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def route_by_latency(state: PipelineState) -> str:
    """Conditional edge after retrieval nodes.

    Adds the recorded ``t_intent`` and ``t_retrieval`` timings (each defaults
    to 0.0 when absent) and picks the generation tier.

    Returns:
        ``"fallback"`` when the summed time exceeds
        ``settings.fallback_latency_threshold``, otherwise ``"primary"``.
    """
    timings = state.get("latency_log") or {}
    upstream_time = timings.get("t_intent", 0.0) + timings.get("t_retrieval", 0.0)
    if upstream_time > settings.fallback_latency_threshold:
        return "fallback"
    return "primary"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ββ Core implementation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
|
| 58 |
+
def _run(state: PipelineState, tier: str) -> dict:
    """Build the prompt, sample candidates, rank them and apply the guardrail.

    Args:
        state: Current pipeline state.  Reads ``persona_profile``,
            ``user_id``, ``raw_query`` (all required) plus the optional
            ``affect``, ``generation_config``, ``retrieved_chunks``,
            ``session_history`` and ``latency_log``.
        tier: LLM tier name forwarded to ``chat_complete``.

    Returns:
        Partial state update: the prompt, all candidates, the selected
        response, the tier used, the updated latency log and the guardrail
        verdict.
    """
    started = time.perf_counter()

    persona = state["persona_profile"]
    persona_id = state["user_id"]
    emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
    cfg = state.get("generation_config") or {}
    evidence = state.get("retrieved_chunks") or []
    recent_turns = (state.get("session_history") or [])[-3:]  # last 3 turns only

    base_tag = cfg.get("tone_tag", "[TONE:DEFAULT]")
    tone_tag = _resolve_tone_tag(persona_id, emotion, base_tag)
    prompt = _build_prompt(persona, evidence, recent_turns, state["raw_query"], tone_tag, cfg)

    # One independent sample per requested candidate.
    candidates: list[str] = [
        chat_complete(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=cfg.get("max_tokens", settings.max_tokens_neutral) + 256,
            temperature=0.7,
            tier=tier,
        )
        for _ in range(settings.num_candidates)
    ]

    selected = _rank_candidates(candidates, evidence, emotion, persona)

    # Guardrail: swap in the safe fallback text if the chosen output fails the check
    guard = check_output(selected, evidence)
    passed = guard["passed"]
    if not passed:
        selected = guard["fallback"]

    t_gen = time.perf_counter() - started
    timings = dict(state.get("latency_log") or {})
    timings["t_generation"] = round(t_gen, 4)
    total = (
        timings.get("t_sensing", 0)
        + timings.get("t_intent", 0)
        + timings.get("t_retrieval", 0)
        + t_gen
    )
    timings["t_total"] = round(total, 4)

    return {
        "augmented_prompt": prompt,
        "candidates": candidates,
        "selected_response": selected,
        "llm_tier_used": tier,
        "latency_log": timings,
        "guardrail_passed": passed,
    }
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _resolve_tone_tag(user_id: str, affect: str, default_tag: str) -> str:
    """Return the persona-specific tone tag for *affect*, else *default_tag*.

    *default_tag* is used when the persona has no override table or the table
    has no entry for the given emotion.
    """
    persona_tags = _PERSONA_TONE_OVERRIDES.get(user_id)
    if persona_tags is None:
        return default_tag
    return persona_tags.get(affect, default_tag)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _build_prompt(
|
| 114 |
+
profile: dict,
|
| 115 |
+
chunks: list[dict],
|
| 116 |
+
history: list[dict],
|
| 117 |
+
query: str,
|
| 118 |
+
tone_tag: str,
|
| 119 |
+
gen_cfg: dict,
|
| 120 |
+
) -> str:
|
| 121 |
+
memory_block = "\n".join(f" [{c['bucket']}] {c['text']}" for c in chunks) or " (no memories retrieved)"
|
| 122 |
+
history_block = "\n".join(f" {h.get('role','?')}: {h.get('content','')}" for h in history) or " (start of session)"
|
| 123 |
+
style_exemplar = profile.get("style_exemplar", "")
|
| 124 |
+
|
| 125 |
+
persona_mod = gen_cfg.get("persona_mod", "baseline")
|
| 126 |
+
persona_instruction = {
|
| 127 |
+
"amplify_quirks": "Amplify your characteristic style and personality.",
|
| 128 |
+
"suppress_humor": "Be direct and supportive. Suppress humor.",
|
| 129 |
+
"baseline": "Use your natural communication style.",
|
| 130 |
+
"add_confirmation": "Add a clarifying question or confirmation at the end.",
|
| 131 |
+
}.get(persona_mod, "Use your natural communication style.")
|
| 132 |
+
|
| 133 |
+
return f"""\
|
| 134 |
+
You are {profile['name']}, an AAC device user with {profile['condition']}.
|
| 135 |
+
Communication style: {profile['style']}
|
| 136 |
+
{tone_tag}
|
| 137 |
+
|
| 138 |
+
Style exemplar β match this register:
|
| 139 |
+
{style_exemplar}
|
| 140 |
+
|
| 141 |
+
Personal memories (use ONLY these for personal facts):
|
| 142 |
+
{memory_block}
|
| 143 |
+
|
| 144 |
+
Recent conversation:
|
| 145 |
+
{history_block}
|
| 146 |
+
|
| 147 |
+
Partner says: {query}
|
| 148 |
+
|
| 149 |
+
Instructions:
|
| 150 |
+
- Speak in first person as {profile['name']}.
|
| 151 |
+
- {persona_instruction}
|
| 152 |
+
- Keep response to 1-3 sentences.
|
| 153 |
+
- If the answer isn't in your memories, say "I don't know."
|
| 154 |
+
- Do NOT say "As an AI" or break persona.
|
| 155 |
+
|
| 156 |
+
Response:"""
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _rank_candidates(
|
| 160 |
+
candidates: list[str],
|
| 161 |
+
chunks: list[dict],
|
| 162 |
+
affect: str,
|
| 163 |
+
profile: dict,
|
| 164 |
+
) -> str:
|
| 165 |
+
"""
|
| 166 |
+
Composite ranking: score = Ξ±Β·faithful + Ξ²Β·style + Ξ³Β·affect_match
|
| 167 |
+
Simple heuristic version β replace with NLI + cosine similarity for final eval.
|
| 168 |
+
"""
|
| 169 |
+
if not candidates:
|
| 170 |
+
return "I don't know."
|
| 171 |
+
if len(candidates) == 1:
|
| 172 |
+
return candidates[0]
|
| 173 |
+
|
| 174 |
+
evidence_words = set(" ".join(c["text"] for c in chunks).lower().split())
|
| 175 |
+
style_words = set(profile.get("style", "").lower().split())
|
| 176 |
+
|
| 177 |
+
affect_positive_map = {
|
| 178 |
+
"HAPPY": ["great", "love", "enjoy", "happy", "fun"],
|
| 179 |
+
"FRUSTRATED": ["okay", "fine", "sure", "yes", "no"],
|
| 180 |
+
"NEUTRAL": [],
|
| 181 |
+
"SURPRISED": ["really", "oh", "interesting", "wow"],
|
| 182 |
+
}
|
| 183 |
+
affect_words = set(affect_positive_map.get(affect, []))
|
| 184 |
+
|
| 185 |
+
def score(c: str) -> float:
|
| 186 |
+
words = set(c.lower().split())
|
| 187 |
+
faithful = len(words & evidence_words) / max(len(words), 1)
|
| 188 |
+
style_sim = len(words & style_words) / max(len(words), 1)
|
| 189 |
+
affect_m = len(words & affect_words) / max(len(words), 1)
|
| 190 |
+
return (
|
| 191 |
+
settings.rank_alpha * faithful
|
| 192 |
+
+ settings.rank_beta * style_sim
|
| 193 |
+
+ settings.rank_gamma * affect_m
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
return max(candidates, key=score)
|
pipeline/nodes/retrieval.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L3 β Semantic Bucketing & Retrieval node.
|
| 3 |
+
|
| 4 |
+
Two entry points:
|
| 5 |
+
run_fast β FRUSTRATED affect: k=2, single bucket, no reranking
|
| 6 |
+
run_full β standard: k=5, optional bucket hint, BGE cross-encoder reranking
|
| 7 |
+
|
| 8 |
+
Also exports the conditional edge function used by graph.py.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
from config.settings import settings
|
| 15 |
+
from pipeline.state import PipelineState, RetrievedChunk
|
| 16 |
+
from retrieval.vector_store import retrieve
|
| 17 |
+
from retrieval.bucket_priors import update_priors
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_fast(state: PipelineState) -> dict:
    """Fast retrieval path for FRUSTRATED affect: no reranker.

    Restricts the search to the bucket with the highest session prior and
    uses ``settings.retrieval_fast_k`` for both ``top_k`` and ``rerank_k``.
    """
    started = time.perf_counter()

    likely_bucket = _top_prior_bucket(state["bucket_priors"])
    hits = retrieve(
        query=state["raw_query"],
        user_id=state["user_id"],
        top_k=settings.retrieval_fast_k,
        rerank_k=settings.retrieval_fast_k,
        bucket_filter=likely_bucket,
        use_reranker=False,
    )
    return _build_return(state, hits, "fast", started)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run_full(state: PipelineState) -> dict:
    """Full retrieval path with the reranker enabled.

    The bucket filter is chosen by precedence: the gaze-derived bucket, then
    the first non-empty ``bucket_hint`` among the routed sub-intents, then
    ``None`` (no filter).
    """
    started = time.perf_counter()

    # Prefer gaze hint > intent bucket hint > None
    intents = (state.get("intent_route") or {}).get("sub_intents", [])
    bucket_hint = state.get("gaze_bucket")
    if not bucket_hint:
        bucket_hint = None
        for intent in intents:
            candidate_hint = intent.get("bucket_hint")
            if candidate_hint:
                bucket_hint = candidate_hint
                break

    hits = retrieve(
        query=state["raw_query"],
        user_id=state["user_id"],
        top_k=settings.retrieval_top_k,
        rerank_k=settings.retrieval_rerank_k,
        bucket_filter=bucket_hint,
        use_reranker=True,
    )
    return _build_return(state, hits, "full", started)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def route_by_affect(state: PipelineState) -> str:
    """Conditional edge function, called by graph.py after the intent node.

    Returns ``"fast"`` when the current emotion is FRUSTRATED and ``"full"``
    otherwise; a missing affect reading is treated as NEUTRAL.
    """
    affect = state.get("affect") or {}
    if affect.get("emotion", "NEUTRAL") == "FRUSTRATED":
        return "fast"
    return "full"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
|
| 69 |
+
def _top_prior_bucket(priors: dict[str, float]) -> str | None:
|
| 70 |
+
if not priors:
|
| 71 |
+
return None
|
| 72 |
+
return max(priors, key=priors.get)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _build_return(
|
| 76 |
+
state: PipelineState,
|
| 77 |
+
chunks: list[RetrievedChunk],
|
| 78 |
+
mode: str,
|
| 79 |
+
t0: float,
|
| 80 |
+
) -> dict:
|
| 81 |
+
t_retrieval = time.perf_counter() - t0
|
| 82 |
+
|
| 83 |
+
latency_log = dict(state.get("latency_log") or {})
|
| 84 |
+
latency_log["t_retrieval"] = round(t_retrieval, 4)
|
| 85 |
+
|
| 86 |
+
return {
|
| 87 |
+
"retrieved_chunks": chunks,
|
| 88 |
+
"retrieval_mode_used": mode,
|
| 89 |
+
"latency_log": latency_log,
|
| 90 |
+
}
|
pipeline/state.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Typed state object that flows through every LangGraph node.
|
| 3 |
+
|
| 4 |
+
Each node receives the full PipelineState and returns a dict
|
| 5 |
+
containing only the keys it updates β LangGraph merges them.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Annotated, Any, Optional
|
| 10 |
+
from typing_extensions import TypedDict
|
| 11 |
+
import operator
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ββ Sub-types ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
|
| 16 |
+
class AffectVector(TypedDict):
    """Four scalar facial features describing one affect reading."""

    MAR: float  # Mouth Aspect Ratio
    EAR: float  # Eye Aspect Ratio
    BRI: float  # Brow Raise Index
    LCP: float  # Lip Corner Pull
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AffectState(TypedDict):
    """Emotion label for the turn together with its feature vectors."""

    emotion: str  # "HAPPY" | "FRUSTRATED" | "NEUTRAL" | "SURPRISED"
    vector: AffectVector
    smoothed: AffectVector  # EMA-smoothed vector
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RetrievedChunk(TypedDict):
    """One memory returned by retrieval and passed on to the planner."""

    text: str
    bucket: str  # family | medical | hobbies | daily_routine | social
    user: str
    score: float  # cross-encoder rerank score
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SubIntent(TypedDict):
    """A single classified intent extracted from the partner's query."""

    type: str  # "PERSONAL" | "CONTEXTUAL" | "OPEN_DOMAIN"
    query: str
    bucket_hint: Optional[str]
    priority: str  # "fast" | "normal"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class IntentRoute(TypedDict):
    """Routing decision produced by the intent-decomposition node."""

    sub_intents: list[SubIntent]
    style_constraints: dict[str, Any]  # tone, max_tokens, etc.
    affect: str
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GenerationConfig(TypedDict):
    """Affect-dependent generation settings consumed by the planner."""

    max_tokens: int
    tone_tag: str  # e.g. "[TONE:WITTY_SARCASTIC]"
    retrieval_mode: str  # "fast" | "full"
    persona_mod: str  # "amplify_quirks" | "suppress_humor" | "baseline" | "add_confirmation"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class LatencyLog(TypedDict):
    """Per-stage timings for one turn, plus their total."""

    t_sensing: float
    t_intent: float
    t_retrieval: float
    t_generation: float
    t_total: float  # sum of the four stage timings
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ββ Main pipeline state ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
+
|
| 66 |
+
class PipelineState(TypedDict):
    """Full state passed to every node of the pipeline graph.

    Nodes return only the keys they change.  ``session_history`` is annotated
    with ``operator.add`` as its reducer; every other key has no reducer
    annotation.
    """

    # -- Session context (set at turn start, stable across nodes) --
    user_id: str
    persona_profile: dict[str, Any]  # full profile from users.json
    session_history: Annotated[list[dict], operator.add]  # auto-appended
    turn_id: int

    # -- L1: Sensing outputs --
    affect: Optional[AffectState]
    gesture_tag: Optional[str]  # e.g. "THUMBS_UP"
    gaze_bucket: Optional[str]  # bucket hinted by gaze fixation
    air_written_text: Optional[str]  # concatenated air-written chars

    # -- L2: Intent decomposition outputs --
    raw_query: str  # partner's typed/spoken query
    intent_route: Optional[IntentRoute]  # Pydantic-validated routing
    generation_config: Optional[GenerationConfig]

    # -- L3: Retrieval outputs --
    retrieved_chunks: list[RetrievedChunk]
    bucket_priors: dict[str, float]  # session-level Bayesian priors
    retrieval_mode_used: str  # "fast" | "full"

    # -- L4: Generation outputs --
    augmented_prompt: Optional[str]
    candidates: list[str]  # 2-3 candidate responses
    selected_response: Optional[str]
    llm_tier_used: str  # "primary" | "fallback" | "local"

    # -- L5: Feedback / tracking --
    latency_log: Optional[LatencyLog]
    mlflow_run_id: Optional[str]
    guardrail_passed: bool
|
requirements.txt
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ββ Orchestration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
langgraph>=1.1
|
| 3 |
+
langchain-core>=0.2
|
| 4 |
+
pydantic>=2.0
|
| 5 |
+
pydantic-settings>=2.0
|
| 6 |
+
|
| 7 |
+
# ββ LLM clients ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8 |
+
openai>=1.0 # OpenAI-compatible client for vLLM + Ollama
|
| 9 |
+
ollama>=0.2 # local dev fallback (direct Ollama SDK)
|
| 10 |
+
|
| 11 |
+
# ββ Retrieval ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
+
faiss-cpu>=1.7
|
| 13 |
+
sentence-transformers>=3.0
|
| 14 |
+
torch>=2.0
|
| 15 |
+
transformers>=4.40
|
| 16 |
+
numpy>=1.24
|
| 17 |
+
|
| 18 |
+
# ββ Clustering βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
hdbscan>=0.8.29
|
| 20 |
+
scikit-learn>=1.3
|
| 21 |
+
|
| 22 |
+
# ββ Sensing ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
mediapipe>=0.10
|
| 24 |
+
opencv-python>=4.8
|
| 25 |
+
|
| 26 |
+
# ββ API backend ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
fastapi>=0.111
|
| 28 |
+
uvicorn[standard]>=0.29
|
| 29 |
+
|
| 30 |
+
# ββ UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
streamlit>=1.35
|
| 32 |
+
requests>=2.31  # Streamlit → FastAPI calls
|
| 33 |
+
|
| 34 |
+
# ββ Experiment tracking ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
mlflow>=2.13
|
| 36 |
+
|
| 37 |
+
# ββ Utilities ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
python-dotenv>=1.0
|
| 39 |
+
rich>=13.0
|
retrieval/__init__.py
ADDED
|
File without changes
|
retrieval/bucket_priors.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Session-level Bayesian bucket priors (proposal Β§5.4 Bonus).
|
| 3 |
+
|
| 4 |
+
Prior P(bucket_i) is initialized uniformly across the 5 buckets.
|
| 5 |
+
After each accepted response, the prior is updated proportionally
|
| 6 |
+
to the historical acceptance rate for that bucket in the session.
|
| 7 |
+
|
| 8 |
+
P(bucket_i | accept) β P(accept | bucket_i) Β· P(bucket_i)
|
| 9 |
+
|
| 10 |
+
The updated priors are stored in PipelineState and passed to the
|
| 11 |
+
retrieval node to bias FAISS search toward the most contextually
|
| 12 |
+
likely topic for the session.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
BUCKETS = ["family", "medical", "hobbies", "daily_routine", "social"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def uniform_priors() -> dict[str, float]:
    """Give every bucket in ``BUCKETS`` the same share of probability mass."""
    return dict.fromkeys(BUCKETS, 1.0 / len(BUCKETS))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def update_priors(
    priors: dict[str, float],
    accepted_bucket: str,
    smoothing: float = 0.1,
) -> dict[str, float]:
    """
    Boost the accepted bucket's prior and renormalise.

    Every bucket first gains ``smoothing``; the accepted bucket then gains a
    further 1.0 (a bucket not yet present starts from ``smoothing``).  The
    result is divided by its total and rounded to 6 decimal places.

    Args:
        priors: Current session priors; empty/None means "start uniform".
        accepted_bucket: Bucket that sourced the accepted response.
        smoothing: Additive constant that keeps every bucket above zero.

    Returns:
        A new dict of normalised priors; the input dict is not modified.
    """
    current = priors if priors else uniform_priors()

    boosted: dict[str, float] = {}
    for bucket, mass in current.items():
        boosted[bucket] = mass + smoothing
    boosted[accepted_bucket] = boosted.get(accepted_bucket, smoothing) + 1.0

    norm = sum(boosted.values())
    return {bucket: round(mass / norm, 6) for bucket, mass in boosted.items()}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def top_bucket(priors: dict[str, float]) -> str:
    """Return the bucket with the highest prior.

    Falls back to the first entry of ``BUCKETS`` when *priors* is empty; on
    ties the bucket appearing first in *priors* wins.
    """
    if priors:
        best_bucket, _ = max(priors.items(), key=lambda item: item[1])
        return best_bucket
    return BUCKETS[0]
|
retrieval/clustering.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HDBSCAN-based semantic bucketing over BGE embeddings.
|
| 3 |
+
|
| 4 |
+
Used to validate / discover thematic clusters in persona memories,
|
| 5 |
+
and to auto-assign bucket labels when adding new memory chunks.
|
| 6 |
+
The hand-authored bucket labels in the JSON files remain the ground
|
| 7 |
+
truth β this module provides a data-driven cross-check and supports
|
| 8 |
+
future expansion to unlabelled memory stores.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from config.settings import settings
|
| 18 |
+
from retrieval.vector_store import _get_embedder
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import hdbscan
|
| 22 |
+
_HDBSCAN_AVAILABLE = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
_HDBSCAN_AVAILABLE = False
|
| 25 |
+
print("[clustering] hdbscan not installed β clustering unavailable.")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
BUCKET_LABELS = ["family", "medical", "hobbies", "daily_routine", "social"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def cluster_persona_memories(user_id: str) -> dict[str, list[str]]:
    """
    Embed all memory chunks for a persona and cluster them with HDBSCAN.

    Args:
        user_id: Persona identifier; ``<user_id>.json`` is read from
            ``settings.memories_dir``.

    Returns:
        Dict mapping ``"cluster_<n>"`` to the memory texts assigned to that
        cluster.  Texts HDBSCAN labels as noise (label -1) are collected under
        the key ``"noise"``.  Empty when the persona has no memories.

    Raises:
        RuntimeError: If the ``hdbscan`` package is not installed.
    """
    if not _HDBSCAN_AVAILABLE:
        raise RuntimeError("hdbscan package is required. Run: pip install hdbscan")

    memory_path = settings.memories_dir / f"{user_id}.json"
    with open(memory_path, encoding="utf-8") as f:
        persona = json.load(f)

    # The hand-authored bucket labels are not needed here, only the texts.
    texts = [mem for memories in persona["memory_buckets"].values() for mem in memories]
    if not texts:
        return {}

    embedder = _get_embedder()
    vecs = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=3,
        min_samples=2,
        metric="euclidean",
    )
    labels = clusterer.fit_predict(vecs)

    clusters: dict[str, list[str]] = {}
    for text, label in zip(texts, labels):
        key = f"cluster_{label}" if label >= 0 else "noise"
        clusters.setdefault(key, []).append(text)

    return clusters
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def evaluate_bucket_alignment(user_id: str) -> dict:
    """
    Compare HDBSCAN cluster assignments against hand-authored bucket labels.

    Purity of a cluster is the share of its members carrying the cluster's
    most common hand-authored bucket label.

    Args:
        user_id: Persona identifier; ``<user_id>.json`` is read from
            ``settings.memories_dir``.

    Returns:
        Dict with:
          * ``n_clusters``: number of real clusters (label >= 0),
          * ``n_noise``: number of memories labelled as noise (-1),
          * ``cluster_purity``: purity per cluster label, noise included,
          * ``mean_purity``: mean purity over real clusters, 0.0 if none.
        Cluster labels and scores are plain ``int`` / ``float`` values.

    Raises:
        RuntimeError: If the ``hdbscan`` package is not installed.
    """
    if not _HDBSCAN_AVAILABLE:
        raise RuntimeError("hdbscan package is required.")

    memory_path = settings.memories_dir / f"{user_id}.json"
    with open(memory_path, encoding="utf-8") as f:
        persona = json.load(f)

    texts, true_buckets = [], []
    for bucket, memories in persona["memory_buckets"].items():
        for mem in memories:
            texts.append(mem)
            true_buckets.append(bucket)

    if not texts:
        # Nothing to embed or cluster.
        return {"n_clusters": 0, "n_noise": 0, "cluster_purity": {}, "mean_purity": 0.0}

    embedder = _get_embedder()
    vecs = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    clusterer = hdbscan.HDBSCAN(min_cluster_size=3, min_samples=2, metric="euclidean")
    pred_labels = clusterer.fit_predict(vecs)

    # Tally hand-authored labels per predicted cluster.  int() converts the
    # numpy label scalars so the result uses plain Python keys.
    cluster_bucket_counts: dict[int, dict[str, int]] = {}
    for pred, true in zip(pred_labels, true_buckets):
        counts = cluster_bucket_counts.setdefault(int(pred), {})
        counts[true] = counts.get(true, 0) + 1

    purity_scores = {
        cluster_id: round(max(counts.values()) / sum(counts.values()), 3)
        for cluster_id, counts in cluster_bucket_counts.items()
    }
    real_purities = [v for k, v in purity_scores.items() if k >= 0]

    return {
        "n_clusters": len(real_purities),
        # A count, as the key name says, not the per-bucket breakdown.
        "n_noise": sum(cluster_bucket_counts.get(-1, {}).values()),
        "cluster_purity": purity_scores,
        "mean_purity": round(float(np.mean(real_purities)), 3) if real_purities else 0.0,
    }
|
retrieval/vector_store.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FAISS-backed dense retrieval with BGE embeddings and cross-encoder reranking.
|
| 3 |
+
|
| 4 |
+
Models are lazy-loaded on first use (safe for FastAPI / LangGraph workers).
|
| 5 |
+
|
| 6 |
+
NOTE: The FAISS indexes in data/faiss_store/ must be built with BGE embeddings.
|
| 7 |
+
Run `python -m retrieval.vector_store` to rebuild all persona indexes.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
from functools import lru_cache
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import faiss
|
| 17 |
+
import numpy as np
|
| 18 |
+
from sentence_transformers import CrossEncoder, SentenceTransformer
|
| 19 |
+
|
| 20 |
+
from config.settings import settings
|
| 21 |
+
from pipeline.state import RetrievedChunk
|
| 22 |
+
|
| 23 |
+
# ── Lazy model singletons ──────────────────────────────────────────────────────
|
| 24 |
+
|
| 25 |
+
@lru_cache(maxsize=1)
def _get_embedder() -> SentenceTransformer:
    """Return the shared bi-encoder, constructing it on the first call.

    The model name is read from ``settings.embed_model``. ``lru_cache``
    keeps the constructed instance so later calls reuse it.
    """
    model_name = settings.embed_model
    return SentenceTransformer(model_name)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@lru_cache(maxsize=1)
def _get_reranker() -> CrossEncoder:
    """Return the shared cross-encoder, constructing it on the first call.

    The model name is read from ``settings.rerank_model``. ``lru_cache``
    keeps the constructed instance so later calls reuse it.
    """
    model_name = settings.rerank_model
    return CrossEncoder(model_name)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ── Index cache (one FAISS index per user_id) ─────────────────────────────────
|
| 36 |
+
|
| 37 |
+
# user_id -> (FAISS index, per-vector metadata); filled lazily by load_index().
_index_cache: dict[str, tuple[faiss.Index, list[dict]]] = {}


def load_index(user_id: str) -> tuple[faiss.Index, list[dict]]:
    """Return the ``(index, meta)`` pair for a persona, reading it from disk once.

    On a cache miss the index is read from
    ``settings.faiss_store_dir / user_id / "index.faiss"`` and the metadata
    from ``meta.json`` in the same directory; the pair is then memoised in
    ``_index_cache`` for subsequent calls.

    NOTE(review): ``user_id`` is joined onto the store path unchanged. If it
    can come from an untrusted request, confirm it is validated upstream.
    """
    try:
        return _index_cache[user_id]
    except KeyError:
        pass

    persona_dir = settings.faiss_store_dir / user_id
    faiss_index = faiss.read_index(str(persona_dir / "index.faiss"))
    with open(persona_dir / "meta.json") as fh:
        chunk_meta = json.load(fh)

    entry = (faiss_index, chunk_meta)
    _index_cache[user_id] = entry
    return entry
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ── Core retrieve function ─────────────────────────────────────────────────────
|
| 51 |
+
|
| 52 |
+
def retrieve(
    query: str,
    user_id: str,
    top_k: int = 5,
    rerank_k: int = 3,
    bucket_filter: str | None = None,
    use_reranker: bool = True,
    debug: bool = False,
) -> list[RetrievedChunk]:
    """
    Two-stage retrieval:
      1. Bi-encoder query embedding -> FAISS search over the persona's index.
      2. Cross-encoder reranking of the surviving candidates (skippable).

    Args:
        query:         Partner's text query.
        user_id:       Persona identifier (e.g. "mia_chen").
        top_k:         Number of candidates from FAISS before reranking.
        rerank_k:      Final number of chunks returned after reranking.
        bucket_filter: If set, restrict candidates to this memory bucket.
                       When no candidate matches, all buckets are kept.
        use_reranker:  False for the FRUSTRATED fast path.
        debug:         Return timing breakdown alongside results.

    Returns:
        Up to ``rerank_k`` chunks. With the reranker each score is the
        cross-encoder score; without it (or with fewer than two candidates)
        FAISS order is kept and every score is 1.0. When ``debug`` is true
        the return value is instead a ``(chunks, timings)`` tuple whose
        timings dict has the keys ``t_embed``, ``t_faiss`` and ``t_rerank``
        (seconds).
    """
    embedder = _get_embedder()
    index, meta = load_index(user_id)

    t0 = time.perf_counter()
    q_vec = embedder.encode(
        [query], convert_to_numpy=True, normalize_embeddings=True
    )
    t_embed = time.perf_counter() - t0

    t0 = time.perf_counter()
    _, idxs = index.search(q_vec, top_k)
    t_faiss = time.perf_counter() - t0

    # FAISS pads the result with label -1 when the index holds fewer than
    # top_k vectors. A bare `i < len(meta)` would accept -1 and silently
    # return meta[-1] (the last chunk), so the lower bound is required.
    candidates = [meta[i] for i in idxs[0] if 0 <= i < len(meta)]

    if bucket_filter:
        filtered = [c for c in candidates if c["bucket"] == bucket_filter]
        candidates = filtered if filtered else candidates  # fallback: all buckets

    t0 = time.perf_counter()
    if use_reranker and len(candidates) > 1:
        reranker = _get_reranker()
        pairs = [(query, c["text"]) for c in candidates]
        ce_scores = reranker.predict(pairs)
        ranked = sorted(zip(ce_scores, candidates), key=lambda x: x[0], reverse=True)
        top = [
            RetrievedChunk(text=c["text"], bucket=c["bucket"], user=c["user"], score=float(s))
            for s, c in ranked[:rerank_k]
        ]
    else:
        top = [
            RetrievedChunk(text=c["text"], bucket=c["bucket"], user=c["user"], score=1.0)
            for c in candidates[:rerank_k]
        ]
    t_rerank = time.perf_counter() - t0

    if debug:
        return top, {"t_embed": t_embed, "t_faiss": t_faiss, "t_rerank": t_rerank}
    return top
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ── Index builder ──────────────────────────────────────────────────────────────
|
| 117 |
+
|
| 118 |
+
def build_index(persona_path: str | Path) -> tuple[faiss.Index, list[dict]]:
    """Embed all memory chunks for a persona and build a FAISS IndexFlatIP.

    Args:
        persona_path: JSON file with ``profile.name`` and a ``memory_buckets``
            mapping of bucket name -> list of memory strings.

    Returns:
        ``(index, meta)`` where ``meta[i]`` describes the i-th indexed vector
        with the keys ``text``, ``bucket`` and ``user``.

    Raises:
        ValueError: If the persona contains no memory chunks, since there is
            nothing to embed and the index dimension cannot be derived.
    """
    # Persona files may hold non-ASCII text; do not depend on the locale's
    # default encoding.
    with open(persona_path, encoding="utf-8") as f:
        persona = json.load(f)

    user_name = persona["profile"]["name"]
    chunks, meta = [], []

    for bucket, memories in persona["memory_buckets"].items():
        for mem in memories:
            chunks.append(mem)
            meta.append({"text": mem, "bucket": bucket, "user": user_name})

    if not chunks:
        raise ValueError(f"No memory chunks found in persona file: {persona_path}")

    embedder = _get_embedder()
    # Unit-normalised vectors make inner product equal to cosine similarity.
    vecs = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)

    dim = vecs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vecs.astype(np.float32))
    return index, meta
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def save_index(index: faiss.Index, meta: list[dict], save_dir: str | Path) -> None:
    """Write a FAISS index and its metadata into ``save_dir``.

    The directory (and any missing parents) is created first. The index is
    written to ``index.faiss`` and the metadata to ``meta.json`` with a
    two-space indent.
    """
    target_dir = Path(save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    index_file = target_dir / "index.faiss"
    meta_file = target_dir / "meta.json"

    faiss.write_index(index, str(index_file))
    with open(meta_file, "w") as fh:
        json.dump(meta, fh, indent=2)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def build_all(
    memories_dir: str | Path | None = None,
    store_dir: str | Path | None = None,
) -> None:
    """Rebuild FAISS indexes for all personas using the configured BGE embedder.

    Args:
        memories_dir: Directory of ``<user_id>.json`` persona files; defaults
            to ``settings.memories_dir``.
        store_dir: Root output directory; each persona is written to
            ``store_dir / <user_id>``. Defaults to ``settings.faiss_store_dir``.

    Personas are processed in sorted filename order and progress is printed
    to stdout.
    """
    memories_dir = Path(memories_dir or settings.memories_dir)
    store_dir = Path(store_dir or settings.faiss_store_dir)

    for persona_file in sorted(memories_dir.glob("*.json")):
        uid = persona_file.stem  # file name without extension is the user_id
        print(f" Building index for {uid} …")
        index, meta = build_index(persona_file)
        save_index(index, meta, store_dir / uid)
        print(f" Saved {len(meta)} chunks → {store_dir / uid}/")
    print("\nAll indexes built.")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ── Entrypoint ─────────────────────────────────────────────────────────────────
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
build_all()
|
sensing/__init__.py
ADDED
|
File without changes
|
sensing/air_writing.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L1 — Air writing recognition via index-finger tip trajectory (proposal §5.2).
|
| 3 |
+
|
| 4 |
+
Tracks MediaPipe Hands landmark 8 (index fingertip) across frames.
|
| 5 |
+
Stroke segmentation uses velocity thresholding:
|
| 6 |
+
- stroke starts when velocity > START_VEL px/frame
|
| 7 |
+
- stroke ends when velocity < END_VEL px/frame for > GAP_MS ms
|
| 8 |
+
|
| 9 |
+
Segmented strokes are classified against a template library using
|
| 10 |
+
Dynamic Time Warping (DTW). Supports:
|
| 11 |
+
- 26 uppercase English letters (A-Z)
|
| 12 |
+
- 10 digits (0-9)
|
| 13 |
+
- 10 most frequent Devanagari characters (for Arjun's Hindi inputs)
|
| 14 |
+
|
| 15 |
+
Recognised characters are concatenated and returned as a text string
|
| 16 |
+
to the intent decomposition layer.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import time
|
| 21 |
+
from collections import deque
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from config.settings import settings
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import mediapipe as mp
|
| 30 |
+
_MP_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
_MP_AVAILABLE = False
|
| 33 |
+
|
| 34 |
+
# ── Landmark index ─────────────────────────────────────────────────────────────
|
| 35 |
+
_INDEX_TIP = 8
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class AirWriter:
    """
    Stateful air-writing recogniser. Feed frames from a webcam loop.
    Call `get_text()` to retrieve and clear the current buffer.

    A stroke starts when the fingertip moves faster than
    ``settings.air_write_velocity_start`` (pixels per frame) and is closed
    once it has stayed slower than ``settings.air_write_velocity_end`` for
    ``settings.air_write_end_gap_ms`` milliseconds. The closed stroke is
    matched against the template library with DTW.
    """
    _trajectory: list[tuple[float, float]] = field(default_factory=list)
    _in_stroke: bool = False
    # Monotonic timestamp of the first slow frame; 0.0 means "no pause running".
    _stroke_end_time: float = field(default=0.0)
    _text_buffer: list[str] = field(default_factory=list)
    _templates: dict[str, np.ndarray] = field(default_factory=dict)

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        self._hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=1,
            min_detection_confidence=0.6,
            min_tracking_confidence=0.5,
        )
        self._prev_pt: tuple[float, float] | None = None
        # Keep templates passed to the constructor; only fall back to the
        # on-disk library when none were supplied.
        if not self._templates:
            self._templates = _load_templates()

    def process_frame(self, bgr_frame) -> str | None:
        """
        Process one frame. Returns a recognised character when a stroke
        completes, or None otherwise.
        """
        import cv2
        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._hands.process(rgb)

        if not result.multi_hand_landmarks:
            # Hand lost: forget the previous point so the next detection does
            # not produce a spurious velocity spike.
            self._prev_pt = None
            return self._check_stroke_end()

        h, w = bgr_frame.shape[:2]
        lm = result.multi_hand_landmarks[0].landmark
        # Landmarks are normalised; convert the fingertip to pixel coordinates.
        tip = (lm[_INDEX_TIP].x * w, lm[_INDEX_TIP].y * h)

        velocity = 0.0
        if self._prev_pt is not None:
            velocity = np.linalg.norm(np.array(tip) - np.array(self._prev_pt))
        self._prev_pt = tip

        start_v = settings.air_write_velocity_start
        end_v = settings.air_write_velocity_end

        # NOTE(review): only frames faster than start_v are appended, so
        # samples between end_v and start_v inside a stroke are dropped.
        # Confirm this matches how the templates were recorded.
        if velocity > start_v:
            self._in_stroke = True
            self._trajectory.append(tip)
            self._stroke_end_time = 0.0  # movement resumed: cancel any pause
        elif self._in_stroke and velocity < end_v:
            if self._stroke_end_time == 0.0:
                self._stroke_end_time = time.monotonic()
            return self._check_stroke_end()

        return None

    def _check_stroke_end(self) -> str | None:
        """Close the stroke if the pause has lasted long enough; else None."""
        if not self._in_stroke or self._stroke_end_time == 0.0:
            return None
        gap_s = settings.air_write_end_gap_ms / 1000.0
        # Monotonic clock: immune to wall-clock adjustments during a pause.
        if time.monotonic() - self._stroke_end_time >= gap_s:
            char = self._recognise(self._trajectory)
            self._trajectory = []
            self._in_stroke = False
            self._stroke_end_time = 0.0
            if char:
                self._text_buffer.append(char)
            return char
        return None

    def _recognise(self, trajectory: list[tuple[float, float]]) -> str | None:
        """Return the template label with the smallest DTW distance.

        Returns None when the stroke has fewer than 5 points or no templates
        are loaded. On equal distances the template seen first wins.
        """
        if len(trajectory) < 5 or not self._templates:
            return None
        query = _normalise_trajectory(np.array(trajectory))
        best_char, best_dist = None, float("inf")
        for char, template in self._templates.items():
            dist = _dtw_distance(query, template)
            if dist < best_dist:
                best_dist = dist
                best_char = char
        return best_char

    def get_text(self) -> str:
        """Return and clear the accumulated air-written text."""
        text = "".join(self._text_buffer)
        self._text_buffer.clear()
        return text

    def release(self):
        """Close the underlying MediaPipe Hands instance."""
        self._hands.close()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ── DTW helpers ────────────────────────────────────────────────────────────────
|
| 135 |
+
|
| 136 |
+
def _normalise_trajectory(pts: np.ndarray) -> np.ndarray:
|
| 137 |
+
"""Scale trajectory to unit bounding box, resample to 32 points."""
|
| 138 |
+
pts = pts - pts.min(axis=0)
|
| 139 |
+
scale = pts.max(axis=0) + 1e-6
|
| 140 |
+
pts = pts / scale
|
| 141 |
+
# Resample to fixed length via linear interpolation
|
| 142 |
+
t_old = np.linspace(0, 1, len(pts))
|
| 143 |
+
t_new = np.linspace(0, 1, 32)
|
| 144 |
+
return np.column_stack([
|
| 145 |
+
np.interp(t_new, t_old, pts[:, 0]),
|
| 146 |
+
np.interp(t_new, t_old, pts[:, 1]),
|
| 147 |
+
])
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _dtw_distance(a: np.ndarray, b: np.ndarray) -> float:
|
| 151 |
+
"""Simple O(nΒ²) DTW β trajectories are short (32 pts), so this is fine."""
|
| 152 |
+
n, m = len(a), len(b)
|
| 153 |
+
dtw = np.full((n + 1, m + 1), np.inf)
|
| 154 |
+
dtw[0, 0] = 0.0
|
| 155 |
+
for i in range(1, n + 1):
|
| 156 |
+
for j in range(1, m + 1):
|
| 157 |
+
cost = np.linalg.norm(a[i - 1] - b[j - 1])
|
| 158 |
+
dtw[i, j] = cost + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
|
| 159 |
+
return float(dtw[n, m])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _load_templates() -> dict[str, np.ndarray]:
|
| 163 |
+
"""
|
| 164 |
+
Load pre-recorded stroke templates from disk.
|
| 165 |
+
Template files should be numpy arrays of shape (32, 2) stored as .npy.
|
| 166 |
+
Returns an empty dict if no template directory exists yet.
|
| 167 |
+
"""
|
| 168 |
+
from pathlib import Path
|
| 169 |
+
template_dir = Path("data/air_write_templates")
|
| 170 |
+
if not template_dir.exists():
|
| 171 |
+
return {}
|
| 172 |
+
templates = {}
|
| 173 |
+
for f in template_dir.glob("*.npy"):
|
| 174 |
+
char = f.stem # filename = character label
|
| 175 |
+
templates[char] = np.load(f)
|
| 176 |
+
return templates
|
sensing/face_mesh.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L1 — Facial affect detection via MediaPipe 2D Face Mesh.
|
| 3 |
+
|
| 4 |
+
Extracts 4 geometric features from 478 landmarks at ~10 fps:
|
| 5 |
+
MAR — Mouth Aspect Ratio (surprise / speech attempt)
|
| 6 |
+
EAR — Eye Aspect Ratio (frustration / blink)
|
| 7 |
+
BRI — Brow Raise Index (surprise / questioning)
|
| 8 |
+
LCP — Lip Corner Pull (smile vs frown)
|
| 9 |
+
|
| 10 |
+
These form the affect vector fed into MobileNetV3-Small affect classifier,
|
| 11 |
+
which maps to one of 4 actionable states: HAPPY | FRUSTRATED | NEUTRAL | SURPRISED.
|
| 12 |
+
|
| 13 |
+
EMA smoothing (α=0.3) prevents transient expressions (sneezes, blinks)
|
| 14 |
+
from destabilising the detected state across turns.
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from config.settings import settings
|
| 24 |
+
from pipeline.state import AffectState, AffectVector
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import mediapipe as mp
|
| 28 |
+
_MP_AVAILABLE = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
_MP_AVAILABLE = False
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import cv2
|
| 34 |
+
_CV2_AVAILABLE = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
_CV2_AVAILABLE = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ── MediaPipe landmark indices (from proposal §5.2) ───────────────────────────
|
| 40 |
+
|
| 41 |
+
# MAR — mouth vertical / horizontal ratio
|
| 42 |
+
_MOUTH_TOP = 13
|
| 43 |
+
_MOUTH_BOTTOM = 14
|
| 44 |
+
_MOUTH_LEFT = 61
|
| 45 |
+
_MOUTH_RIGHT = 291
|
| 46 |
+
|
| 47 |
+
# EAR — eye vertical / horizontal ratio (right eye)
|
| 48 |
+
_EYE_TOP = 159
|
| 49 |
+
_EYE_BOTTOM = 145
|
| 50 |
+
_EYE_LEFT = 33
|
| 51 |
+
_EYE_RIGHT = 133
|
| 52 |
+
|
| 53 |
+
# BRI — brow vertical displacement relative to eye centre
|
| 54 |
+
_BROW_LEFT = 70
|
| 55 |
+
_BROW_RIGHT = 300
|
| 56 |
+
|
| 57 |
+
# LCP — mouth corner horizontal displacement from neutral baseline
|
| 58 |
+
_CORNER_LEFT = 61
|
| 59 |
+
_CORNER_RIGHT = 291
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ── Affect classes ─────────────────────────────────────────────────────────────
|
| 63 |
+
|
| 64 |
+
AFFECT_CLASSES = ["HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
class AffectDetector:
    """
    Per-session facial affect detector with EMA-smoothed features.

    Instantiate once per session and pass every webcam frame to
    `process_frame`. The smoothed feature vector persists between calls.
    """
    _smoothed: AffectVector = field(default_factory=lambda: AffectVector(MAR=0.0, EAR=0.3, BRI=0.0, LCP=0.0))
    _neutral_lcp: float = 0.0  # raw LCP of the first frame with a detected face
    _calibrated: bool = False

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        if not _CV2_AVAILABLE:
            raise ImportError("opencv-python is required: pip install opencv-python")

        self._face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,  # enables iris landmarks (468-477)
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def process_frame(self, bgr_frame: np.ndarray) -> AffectState | None:
        """
        Analyse one BGR OpenCV frame.

        Returns the AffectState for the frame (classified emotion, raw
        features and smoothed features), or None when no face is found.
        """
        detection = self._face_mesh.process(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
        faces = detection.multi_face_landmarks
        if not faces:
            return None

        landmarks = faces[0].landmark
        height, width = bgr_frame.shape[:2]

        def to_pixels(idx):
            # Landmarks are normalised; scale to pixel coordinates.
            point = landmarks[idx]
            return np.array([point.x * width, point.y * height])

        raw = self._compute_features(to_pixels)

        # The first frame with a face defines the neutral LCP baseline.
        if not self._calibrated:
            self._neutral_lcp = raw["LCP"]
            self._calibrated = True
        raw["LCP"] = raw["LCP"] - self._neutral_lcp

        alpha = settings.affect_ema_alpha
        previous = self._smoothed
        # Exponential moving average, applied feature by feature.
        smoothed = AffectVector(
            **{
                key: alpha * raw[key] + (1 - alpha) * previous[key]
                for key in ("MAR", "EAR", "BRI", "LCP")
            }
        )
        self._smoothed = smoothed

        return AffectState(emotion=self._classify(smoothed), vector=raw, smoothed=smoothed)

    def _compute_features(self, pt) -> dict:
        """Compute the raw MAR / EAR / BRI / LCP features.

        `pt` maps a landmark index to its pixel position as a 2-element array.
        """
        def distance(i, j):
            return np.linalg.norm(pt(i) - pt(j))

        # Width of the single eye used for EAR; also the BRI normaliser.
        # NOTE(review): the original code calls this "inter_ocular", but it is
        # one eye's corner-to-corner width, not the distance between eyes.
        eye_width = distance(_EYE_LEFT, _EYE_RIGHT)

        mouth_ratio = distance(_MOUTH_TOP, _MOUTH_BOTTOM) / (
            distance(_MOUTH_LEFT, _MOUTH_RIGHT) + 1e-6
        )
        eye_ratio = distance(_EYE_TOP, _EYE_BOTTOM) / (eye_width + 1e-6)

        # Image y grows downwards, so eye_y - brow_y is positive when the
        # brows sit above the eye.
        eye_mid_y = ((pt(_EYE_LEFT) + pt(_EYE_RIGHT)) / 2)[1]
        brow_mid_y = ((pt(_BROW_LEFT) + pt(_BROW_RIGHT)) / 2)[1]
        brow_raise = (eye_mid_y - brow_mid_y) / (eye_width + 1e-6)

        # NOTE(review): this is the mean x of the two mouth corners in pixels,
        # which tracks horizontal mouth position rather than corner pull.
        # Confirm it is the intended smile/frown signal.
        corner_x = float((pt(_CORNER_LEFT)[0] + pt(_CORNER_RIGHT)[0]) / 2)

        return {
            "MAR": float(mouth_ratio),
            "EAR": float(eye_ratio),
            "BRI": float(brow_raise),
            "LCP": float(corner_x),
        }

    @staticmethod
    def _classify(v: AffectVector) -> str:
        """
        Threshold rules over the four features; the first matching rule wins.
        Placeholder for the MobileNetV3-Small classifier planned for the
        final evaluation.
        """
        rules = (
            (lambda: v["BRI"] > 0.25 and v["MAR"] > 0.3, "SURPRISED"),
            (lambda: v["EAR"] < 0.15 and v["LCP"] < -5, "FRUSTRATED"),
            (lambda: v["LCP"] > 5, "HAPPY"),
        )
        for matches, label in rules:
            if matches():
                return label
        return "NEUTRAL"

    def release(self):
        """Close the underlying MediaPipe FaceMesh instance."""
        self._face_mesh.close()
|
sensing/gaze.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L1 — Gaze-based retrieval activation (Bonus feature, proposal §5.2).
|
| 3 |
+
|
| 4 |
+
Uses MediaPipe iris landmarks (468-472) to estimate gaze direction as
|
| 5 |
+
a 2D screen-coordinate vector. Sustained fixation (> 1.5 s dwell time)
|
| 6 |
+
on a defined UI region pre-biases the retrieval layer toward the
|
| 7 |
+
corresponding memory bucket.
|
| 8 |
+
|
| 9 |
+
UI region → bucket mapping:
|
| 10 |
+
top-left quadrant → family
|
| 11 |
+
top-right quadrant → medical
|
| 12 |
+
bottom-left quadrant → hobbies
|
| 13 |
+
bottom-right quadrant → daily_routine
|
| 14 |
+
centre strip → social
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from config.settings import settings
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import mediapipe as mp
|
| 27 |
+
_MP_AVAILABLE = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
_MP_AVAILABLE = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Iris landmark indices ──────────────────────────────────────────────────────
|
| 33 |
+
# MediaPipe refine_landmarks=True adds iris landmarks 468-477
|
| 34 |
+
_LEFT_IRIS_CENTER = 468
|
| 35 |
+
_RIGHT_IRIS_CENTER = 473
|
| 36 |
+
|
| 37 |
+
# ── Screen region → bucket map ─────────────────────────────────────────────────
|
| 38 |
+
# Defined as (x_min, y_min, x_max, y_max) in normalised [0,1] coords
|
| 39 |
+
_REGION_BUCKET: list[tuple[tuple[float, float, float, float], str]] = [
    # The centre region overlaps all four quadrants, and the quadrants tile
    # the whole unit square. It therefore has to be tested FIRST; listed last
    # it could never match and "social" would be unreachable.
    ((0.3, 0.3, 0.7, 0.7), "social"),
    ((0.0, 0.0, 0.5, 0.5), "family"),
    ((0.5, 0.0, 1.0, 0.5), "medical"),
    ((0.0, 0.5, 0.5, 1.0), "hobbies"),
    ((0.5, 0.5, 1.0, 1.0), "daily_routine"),
]


@dataclass
class GazeTracker:
    """
    Stateful gaze tracker. Call `process_frame` each frame.
    Returns the bucket name when dwell threshold is exceeded, else None.
    """
    # Monotonic timestamp at which the gaze entered `_current_region`.
    _dwell_start: float = field(default=0.0)
    _current_region: str | None = field(default=None)

    def __post_init__(self):
        if not _MP_AVAILABLE:
            raise ImportError("mediapipe is required: pip install mediapipe")
        self._face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,  # required for the iris landmarks
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def process_frame(self, bgr_frame) -> str | None:
        """
        Returns the hinted bucket name once dwell threshold is exceeded,
        then resets the dwell timer. Returns None otherwise, including when
        no face is detected (which also resets the timer).
        """
        import cv2
        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        result = self._face_mesh.process(rgb)

        if not result.multi_face_landmarks:
            self._reset()
            return None

        lm = result.multi_face_landmarks[0].landmark

        # Average left + right iris centres for gaze estimate.
        # NOTE(review): these are normalised image coordinates of the irises,
        # used directly as screen coordinates. Confirm no calibration or
        # mirroring step is expected here.
        gaze_x = (lm[_LEFT_IRIS_CENTER].x + lm[_RIGHT_IRIS_CENTER].x) / 2
        gaze_y = (lm[_LEFT_IRIS_CENTER].y + lm[_RIGHT_IRIS_CENTER].y) / 2

        bucket = self._region_for(gaze_x, gaze_y)
        # Monotonic clock: dwell time must not jump when the wall clock is
        # adjusted.
        now = time.monotonic()

        if bucket != self._current_region:
            # Gaze moved to a different region: restart the dwell timer.
            self._current_region = bucket
            self._dwell_start = now
            return None

        dwell = now - self._dwell_start
        if dwell >= settings.gaze_dwell_threshold_s and bucket is not None:
            self._reset()
            return bucket

        return None

    @staticmethod
    def _region_for(x: float, y: float) -> str | None:
        """Return the bucket of the first region containing (x, y), or None.

        Region bounds are inclusive; points outside [0, 1] x [0, 1] match
        nothing.
        """
        for (x0, y0, x1, y1), bucket in _REGION_BUCKET:
            if x0 <= x <= x1 and y0 <= y <= y1:
                return bucket
        return None

    def _reset(self):
        """Forget the current region and dwell start."""
        self._dwell_start = 0.0
        self._current_region = None

    def release(self):
        """Close the underlying MediaPipe FaceMesh instance."""
        self._face_mesh.close()
|
sensing/gesture.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L1 — Hand gesture recognition via MediaPipe Hands.
|
| 3 |
+
|
| 4 |
+
Recognises 4 gestures from 21 3D hand landmarks at ~15 fps using
|
| 5 |
+
normalised joint-angle rules (no ML model needed at this stage):
|
| 6 |
+
|
| 7 |
+
THUMBS_UP → [TONE:AFFIRMATIVE]
|
| 8 |
+
THUMBS_DOWN → [TONE:NEGATIVE]
|
| 9 |
+
POINTING → [INTENT:REFERENTIAL]
|
| 10 |
+
WAVING → [INTENT:GREETING]
|
| 11 |
+
|
| 12 |
+
Each detected gesture is mapped to a stylistic constraint tag that is
|
| 13 |
+
injected into the generation prompt by the planner node.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import mediapipe as mp
|
| 21 |
+
_MP_AVAILABLE = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
_MP_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Gesture → prompt constraint tag mapping
|
| 27 |
+
GESTURE_TO_TAG: dict[str, str] = {
|
| 28 |
+
"THUMBS_UP": "[GESTURE:THUMBS_UP][TONE:AFFIRMATIVE]",
|
| 29 |
+
"THUMBS_DOWN": "[GESTURE:THUMBS_DOWN][TONE:NEGATIVE]",
|
| 30 |
+
"POINTING": "[GESTURE:POINTING][INTENT:REFERENTIAL]",
|
| 31 |
+
"WAVING": "[GESTURE:WAVING][INTENT:GREETING]",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class GestureClassifier:
|
| 36 |
+
"""
|
| 37 |
+
Stateful classifier β create one instance per session.
|
| 38 |
+
Feed MediaPipe hand landmark results each frame.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self):
|
| 42 |
+
if not _MP_AVAILABLE:
|
| 43 |
+
raise ImportError("mediapipe is required: pip install mediapipe")
|
| 44 |
+
self._hands = mp.solutions.hands.Hands(
|
| 45 |
+
static_image_mode=False,
|
| 46 |
+
max_num_hands=1,
|
| 47 |
+
min_detection_confidence=0.6,
|
| 48 |
+
min_tracking_confidence=0.5,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def process_frame(self, bgr_frame) -> str | None:
|
| 52 |
+
"""
|
| 53 |
+
Returns a gesture label string or None if no clear gesture is detected.
|
| 54 |
+
"""
|
| 55 |
+
import cv2
|
| 56 |
+
rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
|
| 57 |
+
result = self._hands.process(rgb)
|
| 58 |
+
|
| 59 |
+
if not result.multi_hand_landmarks:
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
lm = result.multi_hand_landmarks[0].landmark
|
| 63 |
+
pts = np.array([[l.x, l.y, l.z] for l in lm])
|
| 64 |
+
|
| 65 |
+
return self._classify(pts)
|
| 66 |
+
|
| 67 |
+
def gesture_tag(self, bgr_frame) -> str | None:
|
| 68 |
+
"""Convenience: returns the prompt tag directly, or None."""
|
| 69 |
+
gesture = self.process_frame(bgr_frame)
|
| 70 |
+
return GESTURE_TO_TAG.get(gesture) if gesture else None
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def _classify(pts: np.ndarray) -> str | None:
|
| 74 |
+
"""
|
| 75 |
+
Rule-based gesture classification over normalised joint positions.
|
| 76 |
+
|
| 77 |
+
MediaPipe hand landmark indices:
|
| 78 |
+
0=WRIST, 1-4=THUMB, 5-8=INDEX, 9-12=MIDDLE, 13-16=RING, 17-20=PINKY
|
| 79 |
+
"""
|
| 80 |
+
# Normalise: wrist at origin, scale by palm width
|
| 81 |
+
wrist = pts[0]
|
| 82 |
+
palm_width = np.linalg.norm(pts[5] - pts[17]) + 1e-6
|
| 83 |
+
p = (pts - wrist) / palm_width
|
| 84 |
+
|
| 85 |
+
thumb_tip = p[4]
|
| 86 |
+
index_tip = p[8]
|
| 87 |
+
middle_tip = p[12]
|
| 88 |
+
ring_tip = p[16]
|
| 89 |
+
pinky_tip = p[20]
|
| 90 |
+
index_mcp = p[5] # knuckle
|
| 91 |
+
|
| 92 |
+
# THUMBS_UP: thumb tip above wrist, other fingers curled
|
| 93 |
+
fingers_curled = all(
|
| 94 |
+
np.linalg.norm(tip) < np.linalg.norm(p[mcp])
|
| 95 |
+
for tip, mcp in [(index_tip, p[5]), (middle_tip, p[9]), (ring_tip, p[13])]
|
| 96 |
+
)
|
| 97 |
+
if thumb_tip[1] < -0.3 and fingers_curled:
|
| 98 |
+
return "THUMBS_UP"
|
| 99 |
+
|
| 100 |
+
# THUMBS_DOWN: thumb tip below wrist, other fingers curled
|
| 101 |
+
if thumb_tip[1] > 0.3 and fingers_curled:
|
| 102 |
+
return "THUMBS_DOWN"
|
| 103 |
+
|
| 104 |
+
# POINTING: index extended, others curled
|
| 105 |
+
index_extended = np.linalg.norm(index_tip) > np.linalg.norm(index_mcp) * 1.3
|
| 106 |
+
others_curled = all(
|
| 107 |
+
np.linalg.norm(tip) < 0.5
|
| 108 |
+
for tip in [middle_tip, ring_tip, pinky_tip]
|
| 109 |
+
)
|
| 110 |
+
if index_extended and others_curled:
|
| 111 |
+
return "POINTING"
|
| 112 |
+
|
| 113 |
+
# WAVING: all fingers extended, hand roughly vertical
|
| 114 |
+
all_extended = all(
|
| 115 |
+
np.linalg.norm(tip) > 0.5
|
| 116 |
+
for tip in [index_tip, middle_tip, ring_tip, pinky_tip]
|
| 117 |
+
)
|
| 118 |
+
if all_extended:
|
| 119 |
+
return "WAVING"
|
| 120 |
+
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
def release(self):
    """Close the object stored in ``self._hands``.

    Presumably this is the hand-tracking backend created in the
    constructor -- confirm against ``__init__``.
    """
    tracker = self._hands
    tracker.close()
|
ui/app.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit frontend — webcam + chat + live metrics dashboard.
|
| 3 |
+
|
| 4 |
+
Panels:
|
| 5 |
+
Left sidebar — persona selector, session controls, live affect display
|
| 6 |
+
Centre — chat interface with streaming response
|
| 7 |
+
Right sidebar — latency breakdown, bucket priors bar chart
|
| 8 |
+
|
| 9 |
+
Run: streamlit run ui/app.py
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
import requests
|
| 17 |
+
import streamlit as st
|
| 18 |
+
|
| 19 |
+
# -- Config --------------------------------------------------------------------
# Base URL of the backend API that every request in this script targets.
API_BASE = "http://localhost:8000"

st.set_page_config(
    page_title="AAC Chatbot",
    layout="wide",
    initial_sidebar_state="expanded",
)


# -- Session state init --------------------------------------------------------
# Seed each key only when it is absent, so values already stored in the
# session are left untouched.
for _key, _default in (
    ("user_id", None),
    ("messages", []),
    ("last_latency", {}),
    ("last_affect", "NEUTRAL"),
    ("affect_override", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# -- Sidebar -------------------------------------------------------------------
# NOTE(review): several user-facing strings in this section contain characters
# that look mis-encoded (e.g. "β", "π"). They are reproduced unchanged here;
# confirm the intended dash / emoji characters against the file's real encoding.
with st.sidebar:
    st.title("AAC Chatbot")

    # Persona selection: fetch the available users from the API. Any failure
    # falls back to an empty list so the rest of the page still renders.
    try:
        users_resp = requests.get(f"{API_BASE}/users", timeout=3)
        users = users_resp.json().get("users", [])
    except Exception:
        users = []
        st.error("API not reachable β start the FastAPI server first.")

    # Map user id -> display label shown in the select box.
    user_options = {u["id"]: f"{u['name']} ({u['condition']})" for u in users}
    selected = st.selectbox(
        "Select persona",
        options=list(user_options.keys()),
        format_func=lambda k: user_options.get(k, k),
    )

    # Switching persona starts a fresh conversation, both locally and on the
    # server.
    if selected != st.session_state.user_id:
        st.session_state.user_id = selected
        st.session_state.messages = []
        # Nothing to reset server-side when no persona is selected.
        if selected is not None:
            try:
                # A timeout is required: requests waits indefinitely by
                # default, which would freeze the whole page.
                reset_resp = requests.post(
                    f"{API_BASE}/session/reset",
                    params={"user_id": selected},
                    timeout=3,
                )
                reset_resp.raise_for_status()
            except requests.exceptions.RequestException as exc:
                # Best effort: keep the UI usable, but tell the user instead
                # of silently ignoring the failure.
                st.warning(f"Could not reset the server session: {exc}")

    st.divider()

    # Affect override (for demo / testing without webcam)
    st.subheader("Affect Override")
    st.caption("Simulates webcam affect detection")
    affect_choice = st.radio(
        "Current affect",
        ["Auto (webcam)", "HAPPY", "FRUSTRATED", "NEUTRAL", "SURPRISED"],
        index=0,
    )
    # None means "no override"; it is forwarded as-is in the chat payload.
    st.session_state.affect_override = None if affect_choice == "Auto (webcam)" else affect_choice

    st.divider()

    # Live affect indicator
    st.subheader("Detected Affect")
    affect_emoji = {
        "HAPPY": "π", "FRUSTRATED": "π€",
        "NEUTRAL": "π", "SURPRISED": "π²",
    }
    af = st.session_state.last_affect
    st.markdown(f"### {affect_emoji.get(af, 'β')} {af}")

    # Webcam placeholder
    st.divider()
    st.subheader("Webcam Feed")
    st.info("Live webcam sensing runs in the sensing client.\nAffect is sent to the API automatically.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# -- Main chat area ------------------------------------------------------------
# NOTE(review): some user-facing strings below contain characters that look
# mis-encoded (e.g. "β", "β¦"). They are reproduced unchanged; confirm the
# intended characters against the file's real encoding.
st.header(f"Talking as: {user_options.get(st.session_state.user_id, 'β')}")

chat_col, metrics_col = st.columns([3, 1])

with chat_col:
    # Replay the stored conversation. Partner turns use the "user" chat
    # avatar, generated AAC-user turns use the "assistant" avatar.
    for msg in st.session_state.messages:
        role_label = "Partner" if msg["role"] == "partner" else "AAC User"
        with st.chat_message("user" if msg["role"] == "partner" else "assistant"):
            st.markdown(f"**{role_label}:** {msg['content']}")

    query = st.chat_input("Type as the communication partnerβ¦")

    # Only send a turn when there is text and a persona has been selected.
    if query and st.session_state.user_id:
        st.session_state.messages.append({"role": "partner", "content": query})
        with st.chat_message("user"):
            st.markdown(f"**Partner:** {query}")

        with st.chat_message("assistant"):
            with st.spinner("Generating responseβ¦"):
                try:
                    payload = {
                        "user_id": st.session_state.user_id,
                        "query": query,
                        "affect_override": st.session_state.affect_override,
                    }
                    resp = requests.post(f"{API_BASE}/chat", json=payload, timeout=15)
                    # Surface HTTP errors instead of parsing an error body and
                    # recording the "I don't know." fallback as a real reply.
                    resp.raise_for_status()
                    data = resp.json()

                    response_text = data.get("response", "I don't know.")
                    st.markdown(f"**AAC User:** {response_text}")

                    st.session_state.messages.append({"role": "aac_user", "content": response_text})
                    # NOTE(review): the sidebar has already been rendered at
                    # this point, so it appears to show the new affect only on
                    # the next rerun -- confirm whether that lag is intended.
                    st.session_state.last_affect = data.get("affect", "NEUTRAL")
                    st.session_state.last_latency = data.get("latency", {})

                    if not data.get("guardrail_passed", True):
                        st.warning("β Guardrail triggered β response was sanitised.")

                except requests.exceptions.Timeout:
                    st.error("Request timed out. Is the server running?")
                except Exception as e:
                    # UI boundary: report any other failure (connection error,
                    # HTTP error status, invalid JSON) without crashing the page.
                    st.error(f"Error: {e}")

with metrics_col:
    st.subheader("Turn Latency (s)")
    lat = st.session_state.last_latency
    if lat:
        # One metric per pipeline stage; stages missing from the payload
        # are shown as 0.000s.
        for key, label in [
            ("t_sensing", "Sensing"),
            ("t_intent", "Intent"),
            ("t_retrieval", "Retrieval"),
            ("t_generation", "Generation"),
            ("t_total", "**Total**"),
        ]:
            val = lat.get(key, 0.0)
            st.metric(label=label, value=f"{val:.3f}s")
    else:
        st.caption("No turn yet.")
|