mekosotto commited on
Commit
0b5f569
·
2 Parent(s): 582bce2150cf3b

Merge feat/orchestrator-rag: orchestrator agent + RAG feedback layer

Browse files

13 tasks delivered via subagent-driven-development:
- src/rag/ (chunker, fastembed, FAISS store, retriever, ingest CLI)
- src/agents/ (Tool dataclass + 4 wrappers, function-calling orchestrator loop)
- POST /agent/run + GET /diag/agent endpoints
- Streamlit '🤖 Agent' tab with decision-trace expander
- 3 seed KB markdown fixtures (Lipinski, ComBat, MNE+ICA)
- Dockerfile + Dockerfile.hf build-time RAG ingest
- AGENTS.md §15 + §16, README pointers

233 tests pass, plus 1 live test gated on OPENROUTER_API_KEY and the BBB model artifact being present.

.gitignore CHANGED
@@ -34,3 +34,11 @@ mlartifacts/
34
  .idea/
35
  .vscode/
36
  .DS_Store
 
 
 
 
 
 
 
 
 
34
  .idea/
35
  .vscode/
36
  .DS_Store
37
+
38
+ # RAG knowledge base — ignore user-supplied content; allow only README/.gitkeep
39
+ data/knowledge_base/*
40
+ !data/knowledge_base/README.md
41
+ !data/knowledge_base/.gitkeep
42
+
43
+ # RAG built artifacts
44
+ data/processed/faiss_index/
AGENTS.md CHANGED
@@ -305,3 +305,50 @@ deterministic template path for a fully-reproducible demo.
305
 
306
  The README's YAML front-matter declares the Space metadata
307
  (SDK=docker, port=7860, app_file=src/frontend/app.py).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  The README's YAML front-matter declares the Space metadata
307
  (SDK=docker, port=7860, app_file=src/frontend/app.py).
308
+
309
+ ## 15. Orchestrator Agent Surface
310
+
311
+ `src/agents/orchestrator.py` exposes a single-agent function-calling
312
+ loop over the openai SDK (no LangChain / framework dep). The agent
313
+ holds 4 tools, defined in `src/agents/tools.py`:
314
+
315
+ - `run_bbb_pipeline(smiles, top_k)` — wraps `POST /predict/bbb`
316
+ - `run_eeg_pipeline(input_path)` — wraps `POST /pipeline/eeg`
317
+ - `run_mri_pipeline(input_dir, sites_csv)` — wraps `POST /pipeline/mri`
318
+ - `retrieve_context(query, k)` — wraps `src/rag/retrieve.py`
319
+
320
+ The system prompt (`src/agents/prompts.py:ORCHESTRATOR_SYSTEM_PROMPT`)
321
+ locks the workflow: pick exactly one pipeline → run it → formulate a
322
+ focused retrieval query → call retrieve_context → synthesize a
323
+ 3-5 sentence response that cites at least one chunk. Language of the
324
+ final response is mirrored from the user's question.
325
+
326
+ `POST /agent/run` is the public surface. Default model is
327
+ `google/gemini-2.0-flash-exp:free` on OpenRouter (function-calling
328
+ support verified). Override via `NEUROBRIDGE_AGENT_MODEL` env var.
329
+ Returns 503 when `OPENROUTER_API_KEY` is unset.
330
+
331
+ Diagnostics: `GET /diag/agent` returns key presence, configured model,
332
+ RAG index status (chunk count), and the registered tool names.
333
+
334
+ ## 16. RAG Surface
335
+
336
+ `src/rag/` is the retrieval layer. Stack: `fastembed`
337
+ (`BAAI/bge-small-en-v1.5`, 384-dim, ONNX, no torch dep) for
338
+ embeddings + `faiss-cpu` (`IndexFlatIP` after L2-norm = cosine) for
339
+ vector search.
340
+
341
+ Knowledge base lives at `data/knowledge_base/` (gitignored;
342
+ user-supplied `.md` / `.txt` / `.pdf`). Build the FAISS index with:
343
+
344
+ python -m src.rag.ingest [<input_dir> [<output_dir>]]
345
+
346
+ Defaults: input=`data/knowledge_base/`, output=`data/processed/faiss_index/`.
347
+ The Dockerfile runs this at build time so deployed Spaces start with
348
+ a populated index. Empty KB → empty index → `retrieve_context`
349
+ returns 0 chunks; the agent surfaces this and answers from the
350
+ pipeline result alone.
351
+
352
+ `tests/fixtures/kb_sample/` ships 3 seed markdown files (Lipinski,
353
+ ComBat, MNE+ICA) — these double as test fixtures and as the demo
354
+ seed if no user-supplied PDFs are added.
Dockerfile CHANGED
@@ -43,6 +43,14 @@ RUN mkdir -p data/raw data/processed && \
43
  python -c "from pathlib import Path; from src.pipelines.eeg_pipeline import run_pipeline; run_pipeline(input_path=Path('tests/fixtures/eeg_sample.fif'), output_path=Path('data/processed/eeg_features.parquet'))" && \
44
  python -c "from pathlib import Path; from src.pipelines.mri_pipeline import run_pipeline; run_pipeline(input_dir=Path('tests/fixtures/mri_sample'), sites_csv=Path('tests/fixtures/mri_sample/sites.csv'), output_path=Path('data/processed/mri_features.parquet'))"
45
 
 
 
 
 
 
 
 
 
46
  # --- HF Spaces convention ---
47
  EXPOSE 7860
48
 
 
43
  python -c "from pathlib import Path; from src.pipelines.eeg_pipeline import run_pipeline; run_pipeline(input_path=Path('tests/fixtures/eeg_sample.fif'), output_path=Path('data/processed/eeg_features.parquet'))" && \
44
  python -c "from pathlib import Path; from src.pipelines.mri_pipeline import run_pipeline; run_pipeline(input_dir=Path('tests/fixtures/mri_sample'), sites_csv=Path('tests/fixtures/mri_sample/sites.csv'), output_path=Path('data/processed/mri_features.parquet'))"
45
 
46
+ # --- RAG knowledge base ingest ---
47
+ # Build the FAISS index from any seed docs in tests/fixtures/kb_sample/
48
+ # (always present) plus data/knowledge_base/ (optional, user-supplied via
49
+ # additional COPY layer or volume mount). Empty KB → empty index, agent
50
+ # still functions, retrieve_context just returns no chunks.
51
+ COPY tests/fixtures/kb_sample/ ./data/knowledge_base/seed/
52
+ RUN python -m src.rag.ingest data/knowledge_base data/processed/faiss_index
53
+
54
  # --- HF Spaces convention ---
55
  EXPOSE 7860
56
 
Dockerfile.hf CHANGED
@@ -43,6 +43,14 @@ RUN mkdir -p data/raw data/processed && \
43
  python -c "from pathlib import Path; from src.pipelines.eeg_pipeline import run_pipeline; run_pipeline(input_path=Path('tests/fixtures/eeg_sample.fif'), output_path=Path('data/processed/eeg_features.parquet'))" && \
44
  python -c "from pathlib import Path; from src.pipelines.mri_pipeline import run_pipeline; run_pipeline(input_dir=Path('tests/fixtures/mri_sample'), sites_csv=Path('tests/fixtures/mri_sample/sites.csv'), output_path=Path('data/processed/mri_features.parquet'))"
45
 
 
 
 
 
 
 
 
 
46
  # --- HF Spaces convention ---
47
  EXPOSE 7860
48
 
 
43
  python -c "from pathlib import Path; from src.pipelines.eeg_pipeline import run_pipeline; run_pipeline(input_path=Path('tests/fixtures/eeg_sample.fif'), output_path=Path('data/processed/eeg_features.parquet'))" && \
44
  python -c "from pathlib import Path; from src.pipelines.mri_pipeline import run_pipeline; run_pipeline(input_dir=Path('tests/fixtures/mri_sample'), sites_csv=Path('tests/fixtures/mri_sample/sites.csv'), output_path=Path('data/processed/mri_features.parquet'))"
45
 
46
+ # --- RAG knowledge base ingest ---
47
+ # Build the FAISS index from any seed docs in tests/fixtures/kb_sample/
48
+ # (always present) plus data/knowledge_base/ (optional, user-supplied via
49
+ # additional COPY layer or volume mount). Empty KB → empty index, agent
50
+ # still functions, retrieve_context just returns no chunks.
51
+ COPY tests/fixtures/kb_sample/ ./data/knowledge_base/seed/
52
+ RUN python -m src.rag.ingest data/knowledge_base data/processed/faiss_index
53
+
54
  # --- HF Spaces convention ---
55
  EXPOSE 7860
56
 
README.md CHANGED
@@ -225,6 +225,11 @@ finishes in under 4 seconds on a 2024 laptop.
225
  - **New surfaces:** `POST /explain/eeg`, `POST /explain/mri`, `GET /experiments/runs`, `POST /experiments/diff`
226
  - **New deploy artifacts:** `Dockerfile.hf`, `supervisord.conf`
227
  - **LLM hardening (post-Day 8):** real OpenRouter LLM is now the default in deployed Spaces — `Dockerfile`/`Dockerfile.hf` no longer hard-code `NEUROBRIDGE_DISABLE_LLM=1`. Free-tier fallback chain (10 models, smartest → smallest) in [`src/llm/explainer.py`](src/llm/explainer.py), 401/400 status classification, and language-matching / intent-split prompt. Diagnostic endpoint `GET /diag/openrouter` ([`src/api/main.py`](src/api/main.py)) + Streamlit sidebar "🔧 Diagnose LLM" button. Live verification helper: [`scripts/diagnose_openrouter.py`](scripts/diagnose_openrouter.py).
 
 
 
 
 
228
 
229
  ## Day 7 — Demo Recipe
230
 
 
225
  - **New surfaces:** `POST /explain/eeg`, `POST /explain/mri`, `GET /experiments/runs`, `POST /experiments/diff`
226
  - **New deploy artifacts:** `Dockerfile.hf`, `supervisord.conf`
227
  - **LLM hardening (post-Day 8):** real OpenRouter LLM is now the default in deployed Spaces — `Dockerfile`/`Dockerfile.hf` no longer hard-code `NEUROBRIDGE_DISABLE_LLM=1`. Free-tier fallback chain (10 models, smartest → smallest) in [`src/llm/explainer.py`](src/llm/explainer.py), 401/400 status classification, and language-matching / intent-split prompt. Diagnostic endpoint `GET /diag/openrouter` ([`src/api/main.py`](src/api/main.py)) + Streamlit sidebar "🔧 Diagnose LLM" button. Live verification helper: [`scripts/diagnose_openrouter.py`](scripts/diagnose_openrouter.py).
228
+ - **Orchestrator agent (Task 13):** [`src/agents/orchestrator.py`](src/agents/orchestrator.py), [`src/agents/tools.py`](src/agents/tools.py), [`src/agents/prompts.py`](src/agents/prompts.py)
229
+ - **RAG layer:** [`src/rag/`](src/rag/) — chunker, embedder (fastembed), FAISS store, retriever, ingest CLI
230
+ - **Agent endpoint:** `POST /agent/run` (orchestrator + RAG); diagnostic at `GET /diag/agent`
231
+ - **Streamlit Agent tab:** "🤖 Agent" tab in [`src/frontend/app.py`](src/frontend/app.py) — input box + decision-trace expander
232
+ - **RAG knowledge base:** drop `.md`/`.pdf` into [`data/knowledge_base/`](data/knowledge_base/) — see its README
233
 
234
  ## Day 7 — Demo Recipe
235
 
data/knowledge_base/.gitkeep ADDED
File without changes
data/knowledge_base/README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Knowledge Base
2
+
3
+ Drop reference documents here (`.md`, `.txt`, or `.pdf`). They will be
4
+ ingested by `python -m src.rag.ingest` at Docker build time and surfaced
5
+ to the orchestrator agent via the `retrieve_context` tool.
6
+
7
+ ## Recommended seed set
8
+
9
+ For a clinical-ML / NeuroBridge demo:
10
+
11
+ - **BBB / molecules**: Lipinski's Rule of Five (1997, 2001), Pajouhesh & Lenz
12
+ CNS multiparameter optimization (2005)
13
+ - **MRI / harmonization**: Fortin et al. ComBat for cortical thickness (2017),
14
+ Fortin et al. ComBat for diffusion (2018), Johnson et al. original ComBat
15
+ (2007, gene expression)
16
+ - **EEG / artifacts**: Hyvärinen ICA primer (1999), MNE-Python overview
17
+ (Gramfort 2013)
18
+
19
+ ## Format notes
20
+
21
+ - PDFs work via `pypdf`. OCR-only PDFs (scanned images) won't extract text;
22
+ pre-OCR them first.
23
+ - Markdown is preferred — full text + headers chunk cleanly.
24
+ - Files are gitignored by default. Mount them via Docker volume in
25
+ production, or COPY them in via a sub-path before the `RUN` ingest line.
26
+
27
+ ## Re-indexing
28
+
29
+ After adding/removing files, re-run:
30
+
31
+ python -m src.rag.ingest
32
+
33
+ This rewrites `data/processed/faiss_index/` from scratch (no incremental
34
+ update — the index is small enough to rebuild in seconds).
requirements.txt CHANGED
@@ -37,6 +37,11 @@ pytest==8.3.3
37
  pytest-cov==5.0.0
38
  httpx==0.27.2 # FastAPI test client
39
 
 
 
 
 
 
40
  # --- Frontend (B2B dashboard) ---
41
  streamlit==1.39.0
42
 
 
37
  pytest-cov==5.0.0
38
  httpx==0.27.2 # FastAPI test client
39
 
40
+ # --- RAG (knowledge retrieval for agent feedback loop) ---
41
+ fastembed==0.4.2 # ONNX-based embeddings, no torch dep
42
+ faiss-cpu==1.8.0 # vector store
43
+ pypdf==5.0.1 # PDF text extraction
44
+
45
  # --- Frontend (B2B dashboard) ---
46
  streamlit==1.39.0
47
 
src/agents/__init__.py ADDED
File without changes
src/agents/orchestrator.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Orchestrator agent: function-calling loop over a list of Tools.
2
+
3
+ No agent framework — uses the openai SDK's chat-completions function-calling
4
+ interface directly. This is the same SDK already used by src/llm/explainer.py,
5
+ keeping the dependency surface minimal.
6
+
7
+ Public entry: `Orchestrator(llm_client, tools, system_prompt, model).run(user_input)`.
8
+ Returns an `AgentResult` with synthesized text + full tool-call trace.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ from typing import Any
14
+
15
+ from src.agents.schemas import AgentResult, ToolTraceItem
16
+ from src.agents.tools import Tool
17
+ from src.core.logger import get_logger
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
+ class Orchestrator:
23
+ """Single-agent function-calling loop. Stops on (a) text response, (b) max steps."""
24
+
25
+ def __init__(
26
+ self,
27
+ llm_client: Any,
28
+ tools: list[Tool],
29
+ system_prompt: str,
30
+ model: str,
31
+ max_steps: int = 5,
32
+ temperature: float = 0.0,
33
+ ) -> None:
34
+ self._client = llm_client
35
+ self._tools_by_name = {t.name: t for t in tools}
36
+ self._tool_schemas = [t.openai_schema() for t in tools]
37
+ self._system_prompt = system_prompt
38
+ self._model = model
39
+ self._max_steps = max_steps
40
+ self._temperature = temperature
41
+
42
+ def run(self, user_input: str) -> AgentResult:
43
+ messages: list[dict[str, Any]] = [
44
+ {"role": "system", "content": self._system_prompt},
45
+ {"role": "user", "content": user_input},
46
+ ]
47
+ trace: list[ToolTraceItem] = []
48
+
49
+ for _step in range(self._max_steps):
50
+ response = self._client.chat.completions.create(
51
+ model=self._model,
52
+ messages=messages,
53
+ tools=self._tool_schemas,
54
+ tool_choice="auto",
55
+ temperature=self._temperature,
56
+ )
57
+ msg = response.choices[0].message
58
+
59
+ if not getattr(msg, "tool_calls", None):
60
+ return AgentResult(
61
+ text=(msg.content or "").strip(),
62
+ trace=trace,
63
+ model=self._model,
64
+ finish_reason="complete",
65
+ )
66
+
67
+ messages.append({
68
+ "role": "assistant",
69
+ "content": msg.content,
70
+ "tool_calls": [tc.model_dump() for tc in msg.tool_calls],
71
+ })
72
+
73
+ for tc in msg.tool_calls:
74
+ name = tc.function.name
75
+ tool = self._tools_by_name.get(name)
76
+ if tool is None:
77
+ err = f"unknown tool: {name}"
78
+ trace.append(ToolTraceItem(name=name, args={}, error=err))
79
+ messages.append({
80
+ "role": "tool",
81
+ "tool_call_id": tc.id,
82
+ "content": json.dumps({"error": err}),
83
+ })
84
+ continue
85
+ try:
86
+ args = json.loads(tc.function.arguments or "{}")
87
+ result = tool.invoke(args)
88
+ trace.append(ToolTraceItem(name=name, args=args, result=result))
89
+ messages.append({
90
+ "role": "tool",
91
+ "tool_call_id": tc.id,
92
+ "content": json.dumps({"result": result}, default=str),
93
+ })
94
+ except Exception as e:
95
+ err = str(e)
96
+ trace.append(ToolTraceItem(name=name, args={}, error=err))
97
+ messages.append({
98
+ "role": "tool",
99
+ "tool_call_id": tc.id,
100
+ "content": json.dumps({"error": err}),
101
+ })
102
+
103
+ return AgentResult(
104
+ text="Max steps reached without a final answer.",
105
+ trace=trace,
106
+ model=self._model,
107
+ finish_reason="max_steps",
108
+ )
src/agents/prompts.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """System prompts for the orchestrator agent.
2
+
3
+ Kept in a dedicated module so prompt edits are diff-readable and reviewable
4
+ in isolation from the orchestrator loop.
5
+ """
6
+ from __future__ import annotations
7
+
8
+
9
+ ORCHESTRATOR_SYSTEM_PROMPT = """\
10
+ You are the NeuroBridge clinical-ML orchestrator. You have four tools:
11
+
12
+ - run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
13
+ - run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
14
+ - run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
15
+ - retrieve_context(query, k=4) → for grounding chunks from the knowledge base
16
+
17
+ Workflow — follow exactly:
18
+
19
+ 1. Look at the user input. Decide which ONE pipeline tool fits:
20
+ - SMILES (short, all-letters/digits, no slashes, no .ext) → run_bbb_pipeline
21
+ - Path ending in .fif or .edf → run_eeg_pipeline
22
+ - Path that is a directory (no file extension at the tail) → run_mri_pipeline
23
+ If ambiguous, prefer SMILES if it parses; otherwise return:
24
+ "Cannot identify modality. Provide a SMILES, .fif/.edf path, or NIfTI directory."
25
+
26
+ 2. Call the chosen pipeline tool exactly once with the user input.
27
+
28
+ 3. After the pipeline returns, formulate ONE focused retrieval query that
29
+ captures the scientific concept behind the prediction (NOT the raw input).
30
+ Examples of good queries:
31
+ - "BBB permeability of small lipophilic molecules" (after BBB predict)
32
+ - "ICA artifact removal in multi-channel EEG" (after EEG run)
33
+ - "ComBat scanner site harmonization in multi-center MRI" (after MRI run)
34
+ Then call retrieve_context with that query.
35
+
36
+ 4. Synthesize a final response in 3-5 sentences:
37
+ - State the concrete pipeline result (label, confidence, key numbers).
38
+ - Cite at least one specific fact from the retrieved chunks (mention the
39
+ source file in parentheses, e.g. "(lipinski_rule_of_five.md)").
40
+ - Match the user's question language: Turkish in → Turkish out, etc.
41
+ - If retrieve_context returned 0 chunks, say so explicitly and answer
42
+ using only the pipeline result.
43
+
44
+ Hard constraints:
45
+ - Call exactly ONE pipeline tool, then exactly ONE retrieve_context, then stop.
46
+ - Do NOT invent facts. Only use numbers from the pipeline tool output and
47
+ text from the retrieved chunks.
48
+ - No preamble, no apologies, no meta-commentary about being an AI.
49
+ """
src/agents/schemas.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic input/output schemas for orchestrator tools and the agent result.
2
+
3
+ These schemas double as OpenAI function-calling parameter definitions
4
+ (via `model_json_schema()`) and as runtime validation gates. Keep field
5
+ names lowercase + snake_case so prompts and JSON outputs align.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
# --- Pipeline tool inputs ---------------------------------------------------
# One input model per tool; Field constraints double as the OpenAI
# function-calling parameter schema (see Tool.openai_schema in tools.py).

class BBBPipelineInput(BaseModel):
    """Input for `run_bbb_pipeline` — a single SMILES string."""
    smiles: str = Field(..., description="A single molecular SMILES string, e.g. 'CCO'")
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP attributions to return")


class EEGPipelineInput(BaseModel):
    """Input for `run_eeg_pipeline` — path to an EEG file (.fif or .edf)."""
    input_path: str = Field(..., description="Path to EEG recording file (.fif or .edf)")
    # Epoch window length in seconds; bounds keep pipeline runtime sane.
    epoch_duration_s: float = Field(2.0, gt=0.1, le=60.0)


class MRIPipelineInput(BaseModel):
    """Input for `run_mri_pipeline` — directory of NIfTI files + sites CSV."""
    input_dir: str = Field(..., description="Directory containing .nii.gz volumes")
    sites_csv: str = Field(..., description="CSV mapping subject_id → site")


class RetrieveContextInput(BaseModel):
    """Input for `retrieve_context` — natural-language query into the KB."""
    query: str = Field(..., min_length=2, description="Search query for the knowledge base")
    k: int = Field(4, ge=1, le=10, description="Number of chunks to return")


# --- Pipeline tool outputs --------------------------------------------------
# Populated by the executor closures in tools.py, which copy fields off the
# corresponding API route responses.

class BBBPipelineOutput(BaseModel):
    """Output of `run_bbb_pipeline` (fields copied from the /predict/bbb response)."""
    smiles: str
    label: int
    label_text: str
    confidence: float
    top_features: list[dict[str, Any]]
    drift_z: float | None = None


class EEGPipelineOutput(BaseModel):
    """Output of `run_eeg_pipeline` (fields copied from the /pipeline/eeg response)."""
    input_path: str
    output_path: str
    rows: int
    columns: int
    duration_sec: float


class MRIPipelineOutput(BaseModel):
    """Output of `run_mri_pipeline` (fields copied from the /pipeline/mri response)."""
    input_dir: str
    output_path: str
    rows: int
    columns: int
    duration_sec: float


class RetrieveContextOutput(BaseModel):
    """Output of `retrieve_context` — the query plus raw retriever hits."""
    query: str
    chunks: list[dict[str, Any]]


# --- Agent result -----------------------------------------------------------

class ToolTraceItem(BaseModel):
    """One step in the orchestrator's tool-call trace."""
    # Exactly one of `result` / `error` is set by the orchestrator
    # (see Orchestrator in orchestrator.py).
    name: str
    args: dict[str, Any]
    result: dict[str, Any] | None = None
    error: str | None = None


class AgentResult(BaseModel):
    """Final orchestrator response: synthesized text + full trace."""
    text: str
    trace: list[ToolTraceItem] = Field(default_factory=list)
    model: str | None = None
    finish_reason: str = "complete"  # complete | max_steps | error
src/agents/tools.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
2
+ function-callable tool the orchestrator can invoke.
3
+
4
+ Public entry: `build_default_tools(rag_index_dir)` returns the 4 tools.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Callable
11
+
12
+ from pydantic import BaseModel, ValidationError
13
+
14
+ from src.agents.schemas import (
15
+ BBBPipelineInput,
16
+ BBBPipelineOutput,
17
+ EEGPipelineInput,
18
+ EEGPipelineOutput,
19
+ MRIPipelineInput,
20
+ MRIPipelineOutput,
21
+ RetrieveContextInput,
22
+ RetrieveContextOutput,
23
+ )
24
+ from src.core.logger import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
@dataclass
class Tool:
    """One callable tool exposed to the orchestrator.

    Contract: `execute(input_model_instance) -> output_model_instance`.
    `invoke(args_dict)` validates the raw dict, runs `execute`, and
    returns the output as a plain dict.
    """
    name: str
    description: str
    input_model: type[BaseModel]
    output_model: type[BaseModel]
    execute: Callable[[Any], BaseModel]

    def openai_schema(self) -> dict[str, Any]:
        """OpenAI/OpenRouter function-calling schema for this tool."""
        raw = self.input_model.model_json_schema()
        # Some clients reject cosmetic top-level keys ($defs / title) —
        # keep only the structural ones: type / properties / required.
        parameters: dict[str, Any] = {
            "type": "object",
            "properties": raw.get("properties", {}),
            "required": raw.get("required", []),
        }
        function_spec: dict[str, Any] = {
            "name": self.name,
            "description": self.description,
            "parameters": parameters,
        }
        return {"type": "function", "function": function_spec}

    def invoke(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate `args` against `input_model`, execute, dump to a dict.

        Raises ValueError (wrapping the pydantic ValidationError) on
        invalid input so the orchestrator's error path stays uniform.
        """
        try:
            validated = self.input_model.model_validate(args)
        except ValidationError as exc:
            raise ValueError(f"invalid input for {self.name}: {exc}") from exc
        return self.execute(validated).model_dump()
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Tool implementations — thin wrappers around existing pipelines + RAG.
72
+ # Heavy work stays in the underlying modules; these only adapt I/O.
73
+ # ---------------------------------------------------------------------------
74
+
75
+
76
def _make_bbb_executor() -> Callable[[BBBPipelineInput], BBBPipelineOutput]:
    """Closure factory: BBB permeability prediction + SHAP, translates HTTPException."""

    def execute(inp: BBBPipelineInput) -> BBBPipelineOutput:
        # Deferred imports: keep module import cheap and avoid API cycles.
        from fastapi import HTTPException

        from src.api import routes as api_routes
        from src.api.schemas import BBBPredictRequest

        request = BBBPredictRequest(smiles=inp.smiles, top_k=inp.top_k)
        try:
            resp = api_routes.predict_bbb(request)
        except HTTPException as exc:
            # Re-raise as ValueError so Tool.invoke's caller sees one error type.
            raise ValueError(f"bbb tool failed: {exc.detail}") from exc
        return BBBPipelineOutput(
            smiles=inp.smiles,
            label=resp.label,
            label_text=resp.label_text,
            confidence=resp.confidence,
            top_features=[feat.model_dump() for feat in resp.top_features],
            drift_z=resp.drift_z,
        )

    return execute
97
+
98
+
99
def _make_eeg_executor(processed_dir: Path) -> Callable[[EEGPipelineInput], EEGPipelineOutput]:
    """Closure factory: EEG pipeline, writes output under processed_dir."""

    def execute(inp: EEGPipelineInput) -> EEGPipelineOutput:
        # Deferred imports: keep module import cheap and avoid API cycles.
        from fastapi import HTTPException

        from src.api import routes as api_routes
        from src.api.schemas import EEGRequest

        target = processed_dir / "eeg_features.parquet"
        request = EEGRequest(
            input_path=inp.input_path,
            output_path=str(target),
            epoch_duration_s=inp.epoch_duration_s,
        )
        try:
            resp = api_routes.run_eeg(request)
        except HTTPException as exc:
            # Re-raise as ValueError so Tool.invoke's caller sees one error type.
            raise ValueError(f"eeg tool failed: {exc.detail}") from exc
        return EEGPipelineOutput(
            input_path=inp.input_path,
            output_path=resp.output_path,
            rows=resp.rows,
            columns=resp.columns,
            duration_sec=resp.duration_sec,
        )

    return execute
124
+
125
+
126
def _make_mri_executor(processed_dir: Path) -> Callable[[MRIPipelineInput], MRIPipelineOutput]:
    """Closure factory: MRI pipeline, writes output under processed_dir."""

    def execute(inp: MRIPipelineInput) -> MRIPipelineOutput:
        # Deferred imports: keep module import cheap and avoid API cycles.
        from fastapi import HTTPException

        from src.api import routes as api_routes
        from src.api.schemas import MRIRequest

        target = processed_dir / "mri_features.parquet"
        request = MRIRequest(
            input_dir=inp.input_dir,
            sites_csv=inp.sites_csv,
            output_path=str(target),
        )
        try:
            resp = api_routes.run_mri(request)
        except HTTPException as exc:
            # Re-raise as ValueError so Tool.invoke's caller sees one error type.
            raise ValueError(f"mri tool failed: {exc.detail}") from exc
        return MRIPipelineOutput(
            input_dir=inp.input_dir,
            output_path=resp.output_path,
            rows=resp.rows,
            columns=resp.columns,
            duration_sec=resp.duration_sec,
        )

    return execute
151
+
152
+
153
def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
    """Closure: capture the index dir; lazy-load the retriever on first call."""
    retriever: Any = None

    def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
        nonlocal retriever
        # No configured dir, or no built index yet → empty result, never an error.
        if rag_index_dir is None or not (rag_index_dir / "index.bin").exists():
            return RetrieveContextOutput(query=inp.query, chunks=[])
        if retriever is None:
            from src.rag.retrieve import RAGRetriever
            retriever = RAGRetriever.load(rag_index_dir)
        hits = retriever.search(inp.query, k=inp.k)
        return RetrieveContextOutput(query=inp.query, chunks=hits)

    return execute
167
+
168
+
169
def build_default_tools(
    rag_index_dir: Path | None,
    processed_dir: Path = Path("data/processed"),
) -> list[Tool]:
    """Return the 4 tools the orchestrator gets by default."""
    bbb_tool = Tool(
        name="run_bbb_pipeline",
        description=(
            "Predict blood-brain-barrier permeability for a SINGLE SMILES "
            "string. Use this when the user input looks like a molecule "
            "(short alphanumeric string with no file extension, e.g. 'CCO', "
            "'c1ccccc1'). Returns label, confidence, top SHAP features, drift."
        ),
        input_model=BBBPipelineInput,
        output_model=BBBPipelineOutput,
        execute=_make_bbb_executor(),
    )
    eeg_tool = Tool(
        name="run_eeg_pipeline",
        description=(
            "Run the EEG signal-processing pipeline (bandpass + ICA + "
            "epoching + feature extraction) on an EEG recording file. Use "
            "when input_path ends in .fif or .edf. Returns row/column "
            "counts + duration."
        ),
        input_model=EEGPipelineInput,
        output_model=EEGPipelineOutput,
        execute=_make_eeg_executor(processed_dir),
    )
    mri_tool = Tool(
        name="run_mri_pipeline",
        description=(
            "Run the multi-site MRI ComBat-harmonization pipeline. Use "
            "when input is a directory containing .nii.gz volumes paired "
            "with a sites.csv. Returns row/column counts + duration."
        ),
        input_model=MRIPipelineInput,
        output_model=MRIPipelineOutput,
        execute=_make_mri_executor(processed_dir),
    )
    retrieve_tool = Tool(
        name="retrieve_context",
        description=(
            "Retrieve up to k passages from the curated reference knowledge "
            "base. Use AFTER a pipeline tool returns, to ground your final "
            "synthesis in cited literature. Formulate a focused query "
            "based on the pipeline output (e.g., 'BBB permeability of "
            "small lipophilic molecules' or 'ComBat site harmonization')."
        ),
        input_model=RetrieveContextInput,
        output_model=RetrieveContextOutput,
        execute=_make_retrieve_executor(rag_index_dir),
    )
    return [bbb_tool, eeg_tool, mri_tool, retrieve_tool]
src/api/main.py CHANGED
@@ -11,6 +11,7 @@ from src.api.routes import (
11
  predict_router,
12
  explain_router,
13
  experiments_router,
 
14
  )
15
  from src.api.schemas import HealthResponse
16
 
@@ -24,6 +25,7 @@ app.include_router(pipeline_router)
24
  app.include_router(predict_router)
25
  app.include_router(explain_router)
26
  app.include_router(experiments_router)
 
27
 
28
 
29
  @app.get("/health", response_model=HealthResponse)
@@ -100,3 +102,40 @@ def diag_openrouter() -> dict:
100
  out["probe"] = {"status": "ERR", "exception": type(e).__name__, "message": str(e)[:200]}
101
 
102
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  predict_router,
12
  explain_router,
13
  experiments_router,
14
+ agent_router,
15
  )
16
  from src.api.schemas import HealthResponse
17
 
 
25
  app.include_router(predict_router)
26
  app.include_router(explain_router)
27
  app.include_router(experiments_router)
28
+ app.include_router(agent_router)
29
 
30
 
31
  @app.get("/health", response_model=HealthResponse)
 
102
  out["probe"] = {"status": "ERR", "exception": type(e).__name__, "message": str(e)[:200]}
103
 
104
  return out
105
+
106
+
107
@app.get("/diag/agent")
def diag_agent() -> dict:
    """Reachability probe for the orchestrator agent surface.

    Reports key presence (length + 12-char prefix only — never the full
    secret), the configured agent model, knowledge-base index status,
    and the registered tool names.
    """
    import json
    import os
    from pathlib import Path

    from src.agents.tools import build_default_tools

    api_key = os.environ.get("OPENROUTER_API_KEY") or ""
    agent_model = os.environ.get("NEUROBRIDGE_AGENT_MODEL", "google/gemini-2.0-flash-exp:free")

    index_dir = Path("data/processed/faiss_index")
    rag_status: dict = {"index_dir": str(index_dir), "exists": False, "chunk_count": 0}
    # The index is usable only when both the FAISS file and its chunk
    # metadata are present.
    if (index_dir / "index.bin").exists() and (index_dir / "chunks.json").exists():
        rag_status["exists"] = True
        try:
            chunk_meta = json.loads((index_dir / "chunks.json").read_text())
            rag_status["chunk_count"] = len(chunk_meta)
        except Exception as exc:
            rag_status["error"] = f"chunks.json unreadable: {exc}"

    tools = build_default_tools(rag_index_dir=index_dir if rag_status["exists"] else None)
    return {
        "has_key": bool(api_key),
        "key_len": len(api_key),
        "key_prefix": api_key[:12] if api_key else None,
        "agent_model": agent_model,
        "rag": rag_status,
        "tool_names": [tool.name for tool in tools],
    }
src/api/routes.py CHANGED
@@ -18,6 +18,9 @@ import pandas as pd
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
 
 
 
21
  BBBExplainRequest,
22
  BBBExplainResponse,
23
  BBBPredictRequest,
@@ -500,3 +503,63 @@ def diff_runs(req: RunDiffRequest) -> RunDiffResponse:
500
  )
501
  )
502
  return RunDiffResponse(rows=rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
21
+ AgentRunRequest,
22
+ AgentRunResponse,
23
+ AgentToolTraceItem,
24
  BBBExplainRequest,
25
  BBBExplainResponse,
26
  BBBPredictRequest,
 
503
  )
504
  )
505
  return RunDiffResponse(rows=rows)
506
+
507
+
508
+ # --- Agent router ----------------------------------------------------------
509
+
510
+ agent_router = APIRouter(prefix="/agent")
511
+
512
+
513
+ _DEFAULT_RAG_INDEX_DIR = Path("data/processed/faiss_index")
514
+ _AGENT_MODEL_ENV = "NEUROBRIDGE_AGENT_MODEL"
515
+ _AGENT_DEFAULT_MODEL = "google/gemini-2.0-flash-exp:free"
516
+
517
+
518
def _build_orchestrator():
    """Construct the default orchestrator. Patchable in tests.

    Raises HTTPException(503) when OPENROUTER_API_KEY is not configured.
    The RAG index directory is handed to the tool registry only when both
    persisted artifacts (index.bin + chunks.json) are present — an existing
    but empty/partial directory would otherwise fail at retrieval time.
    This mirrors the readiness check performed by GET /diag/agent.
    """
    from openai import OpenAI

    from src.agents.orchestrator import Orchestrator
    from src.agents.prompts import ORCHESTRATOR_SYSTEM_PROMPT
    from src.agents.tools import build_default_tools

    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="OPENROUTER_API_KEY not set; agent surface unavailable.",
        )
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=30.0,
    )
    index_ready = (
        (_DEFAULT_RAG_INDEX_DIR / "index.bin").exists()
        and (_DEFAULT_RAG_INDEX_DIR / "chunks.json").exists()
    )
    tools = build_default_tools(rag_index_dir=_DEFAULT_RAG_INDEX_DIR if index_ready else None)
    model = os.environ.get(_AGENT_MODEL_ENV, _AGENT_DEFAULT_MODEL)
    return Orchestrator(
        llm_client=client,
        tools=tools,
        system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
        model=model,
        max_steps=5,
    )
547
+
548
+
549
@agent_router.post("/run", response_model=AgentRunResponse)
def run_agent(req: AgentRunRequest) -> AgentRunResponse:
    """Run the orchestrator on `user_input`. Picks a pipeline + grounds via RAG."""
    orchestrator = _build_orchestrator()
    # Append the optional question so the model can language-match its answer.
    prompt = (
        f"{req.user_input}\n\nUser question: {req.user_question}"
        if req.user_question
        else req.user_input
    )
    outcome = orchestrator.run(prompt)
    trace_items = [
        AgentToolTraceItem(name=step.name, args=step.args, result=step.result, error=step.error)
        for step in outcome.trace
    ]
    return AgentRunResponse(
        text=outcome.text,
        trace=trace_items,
        model=outcome.model,
        finish_reason=outcome.finish_reason,
    )
src/api/schemas.py CHANGED
@@ -228,3 +228,27 @@ class RunDiffRow(BaseModel):
228
  class RunDiffResponse(BaseModel):
229
  """Response for POST /experiments/diff: side-by-side metric/param diff."""
230
  rows: list[RunDiffRow]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  class RunDiffResponse(BaseModel):
229
  """Response for POST /experiments/diff: side-by-side metric/param diff."""
230
  rows: list[RunDiffRow]
231
+
232
+
233
+ # --- Agent surface (orchestrator + RAG) ------------------------------------
234
+
235
class AgentRunRequest(BaseModel):
    """User input to the orchestrator (request body for POST /agent/run)."""

    user_input: str = Field(
        ...,
        min_length=1,
        description="SMILES, file path, or directory path",
    )
    user_question: str | None = Field(
        None,
        description="Optional natural-language question to language-match the response",
    )
+
242
+
243
class AgentToolTraceItem(BaseModel):
    """One tool invocation in the agent's decision trace (args, result or error)."""
    name: str
    args: dict = Field(default_factory=dict)
    result: dict | None = None
    error: str | None = None
+
249
+
250
class AgentRunResponse(BaseModel):
    """Response for POST /agent/run: synthesized text plus the tool-call trace."""
    text: str
    trace: list[AgentToolTraceItem] = Field(default_factory=list)
    model: str | None = None
    finish_reason: str = "complete"
src/frontend/app.py CHANGED
@@ -935,9 +935,9 @@ def _check_api_health() -> tuple[bool, str]:
935
  return False, type(e).__name__.lower()
936
 
937
 
938
- def _post(endpoint: str, payload: dict) -> dict:
939
  """POST to the FastAPI surface; let httpx raise on non-2xx."""
940
- resp = httpx.post(f"{_API_URL}{endpoint}", json=payload, timeout=120.0)
941
  resp.raise_for_status()
942
  return resp.json()
943
 
@@ -1752,12 +1752,13 @@ def main() -> None:
1752
  "Run `uvicorn src.api.main:app --port 8000` or `docker compose up`."
1753
  )
1754
 
1755
- bbb_tab, eeg_tab, mri_tab, assistant_tab, experiments_tab = st.tabs([
1756
  "Molecule",
1757
  "Signal",
1758
  "Image",
1759
  "AI Assistant",
1760
  "Experiments",
 
1761
  ])
1762
 
1763
  with bbb_tab:
@@ -1771,6 +1772,55 @@ def main() -> None:
1771
  with experiments_tab:
1772
  _render_experiments_tab()
1773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1774
 
1775
  if __name__ == "__main__":
1776
  main()
 
935
  return False, type(e).__name__.lower()
936
 
937
 
938
def _post(endpoint: str, payload: dict, timeout: float = 120.0) -> dict:
    """POST `payload` to the FastAPI surface; httpx raises on non-2xx."""
    response = httpx.post(f"{_API_URL}{endpoint}", json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()
943
 
 
1752
  "Run `uvicorn src.api.main:app --port 8000` or `docker compose up`."
1753
  )
1754
 
1755
+ bbb_tab, eeg_tab, mri_tab, assistant_tab, experiments_tab, agent_tab = st.tabs([
1756
  "Molecule",
1757
  "Signal",
1758
  "Image",
1759
  "AI Assistant",
1760
  "Experiments",
1761
+ "🤖 Agent",
1762
  ])
1763
 
1764
  with bbb_tab:
 
1772
  with experiments_tab:
1773
  _render_experiments_tab()
1774
 
1775
+ with agent_tab:
1776
+ st.markdown("### Orchestrator Agent")
1777
+ st.caption(
1778
+ "Pick the pipeline automatically, run it, then ground the response "
1779
+ "in curated reference docs (RAG)."
1780
+ )
1781
+
1782
+ with st.form("agent_form"):
1783
+ agent_input = st.text_input(
1784
+ "Input",
1785
+ value="CCO",
1786
+ help="SMILES (e.g., CCO), .fif/.edf path, or NIfTI directory path",
1787
+ )
1788
+ agent_question = st.text_input(
1789
+ "Question (optional)",
1790
+ value="",
1791
+ help="Ask in any language — the agent will mirror it in the response",
1792
+ )
1793
+ submitted = st.form_submit_button("Run agent")
1794
+
1795
+ if submitted and agent_input:
1796
+ with st.spinner("Agent is reasoning..."):
1797
+ try:
1798
+ payload: dict = {"user_input": agent_input}
1799
+ if agent_question:
1800
+ payload["user_question"] = agent_question
1801
+ response = _post("/agent/run", payload, timeout=120.0)
1802
+ except Exception as e:
1803
+ st.error(f"Agent run failed: {e}")
1804
+ else:
1805
+ st.markdown("#### Response")
1806
+ st.write(response.get("text", ""))
1807
+ st.caption(
1808
+ f"model: `{response.get('model', '?')}` · "
1809
+ f"finish: `{response.get('finish_reason', '?')}`"
1810
+ )
1811
+ trace = response.get("trace", [])
1812
+ expander_title = f"🧠 Decision trace ({len(trace)} step{'s' if len(trace) != 1 else ''})"
1813
+ with st.expander(expander_title, expanded=True):
1814
+ if not trace:
1815
+ st.write("_(no tool calls)_")
1816
+ for i, step in enumerate(trace, start=1):
1817
+ st.markdown(f"**{i}. `{step['name']}`**")
1818
+ if step.get("error"):
1819
+ st.error(step["error"])
1820
+ else:
1821
+ st.json(step.get("args", {}))
1822
+ st.json(step.get("result", {}))
1823
+
1824
 
1825
  if __name__ == "__main__":
1826
  main()
src/rag/__init__.py ADDED
File without changes
src/rag/chunker.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Paragraph-aware recursive character splitter for RAG ingestion.
2
+
3
+ Public entry: `chunk_text(text, max_chars, overlap)`. Splits on the first
4
+ of [paragraph break, sentence end, newline, space] that fits inside the
5
+ window. Empty / whitespace-only inputs return [].
6
+ """
7
+ from __future__ import annotations
8
+
9
+
10
# Boundary candidates in priority order: paragraph break, sentence end,
# newline, then plain space.
_SEPARATORS: tuple[str, ...] = ("\n\n", ". ", "\n", " ")


def chunk_text(text: str, max_chars: int = 600, overlap: int = 80) -> list[str]:
    """Split `text` into chunks of at most `max_chars`, carrying `overlap`
    characters between consecutive windows.

    Each window end is pulled back to the last separator that fits fully
    inside the window (highest-priority separator wins), so chunks tend to
    end on clean boundaries. Empty / whitespace-only input returns [].
    """
    stripped = text.strip()
    if not stripped:
        return []
    total = len(stripped)
    if total <= max_chars:
        return [stripped]

    pieces: list[str] = []
    cursor = 0
    while cursor < total:
        window_end = min(cursor + max_chars, total)
        if window_end < total:
            # Land on the cleanest boundary inside [cursor, window_end].
            for sep in _SEPARATORS:
                cut = stripped.rfind(sep, cursor, window_end)
                if cut > cursor:
                    window_end = cut + len(sep)
                    break
        piece = stripped[cursor:window_end].strip()
        if piece:
            pieces.append(piece)
        if window_end >= total:
            break
        # Advance by at least one character so the loop always terminates,
        # even when overlap >= the distance covered.
        cursor = max(cursor + 1, window_end - overlap)
    return pieces
src/rag/embed.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fastembed wrapper — ONNX-based, CPU-only, no torch dep.
2
+
3
+ Public entry: `Embedder().encode(texts) -> np.ndarray[N, D]`. Model is
4
+ loaded lazily on first call. Output is float32 to match FAISS's expected
5
+ input dtype.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+
11
+ from src.core.logger import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
# bge-small-en-v1.5: 384-dim, ~33MB ONNX, MTEB top-tier for size class.
_MODEL_NAME = "BAAI/bge-small-en-v1.5"
EMBEDDING_DIM = 384


class Embedder:
    """Lazy-loaded fastembed wrapper. One instance per process is enough."""

    def __init__(self, model_name: str = _MODEL_NAME) -> None:
        self._model_name = model_name
        # Deferred until the first encode() so module import stays cheap.
        self._model = None

    def _ensure_model(self) -> None:
        """Load the ONNX model on first use; no-op afterwards."""
        if self._model is not None:
            return
        from fastembed import TextEmbedding
        logger.info("Loading fastembed model %s (one-time)", self._model_name)
        self._model = TextEmbedding(model_name=self._model_name)

    def encode(self, texts: list[str]) -> np.ndarray:
        """Embed `texts` into a float32 (N, 384) array; [] yields (0, 384)."""
        if not texts:
            return np.zeros((0, EMBEDDING_DIM), dtype=np.float32)
        self._ensure_model()
        vectors = list(self._model.embed(texts))
        return np.array(vectors, dtype=np.float32)
src/rag/ingest.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Walk a knowledge-base directory, chunk each file, embed, persist FAISS index.
2
+
3
+ CLI entry point: `python -m src.rag.ingest [<input_dir> [<output_dir>]]`.
4
+ Defaults: input=`data/knowledge_base/`, output=`data/processed/faiss_index/`.
5
+
6
+ Supported file types: `.md`, `.txt`, `.pdf`. Other extensions are ignored
7
+ with a logged WARNING.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ from src.core.logger import get_logger
15
+ from src.rag.chunker import chunk_text
16
+ from src.rag.embed import EMBEDDING_DIM, Embedder
17
+ from src.rag.store import FAISSStore
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
+ _DEFAULT_INPUT = Path("data/knowledge_base")
23
+ _DEFAULT_OUTPUT = Path("data/processed/faiss_index")
24
+ _SUPPORTED = {".md", ".txt", ".pdf"}
25
+
26
+
27
def _read_pdf(path: Path) -> str:
    """Extract text from every PDF page, pages joined by blank lines."""
    from pypdf import PdfReader
    pages = PdfReader(str(path)).pages
    return "\n\n".join(page.extract_text() or "" for page in pages)
31
+
32
+
33
+ def _read_file(path: Path) -> str:
34
+ suffix = path.suffix.lower()
35
+ if suffix == ".pdf":
36
+ return _read_pdf(path)
37
+ return path.read_text(encoding="utf-8", errors="replace")
38
+
39
+
40
def ingest_directory(input_dir: Path, output_dir: Path) -> int:
    """Ingest every supported file in `input_dir` into a FAISS index at `output_dir`.

    Unsupported extensions are skipped silently by the glob filter; files
    that fail to read are skipped with a WARNING. Returns the total number
    of chunks indexed. An empty index is still persisted.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    sources = sorted(p for p in input_dir.rglob("*") if p.suffix.lower() in _SUPPORTED)
    logger.info("Ingesting %d file(s) from %s", len(sources), input_dir)

    records: list[dict] = []
    for src in sources:
        try:
            body = _read_file(src)
        except Exception as e:
            logger.warning("Skipping %s (read failed: %s)", src, e)
            continue
        records.extend(
            {
                "text": piece,
                "source": str(src.relative_to(input_dir)),
                "chunk_index": i,
            }
            for i, piece in enumerate(chunk_text(body))
        )

    store = FAISSStore(dim=EMBEDDING_DIM)
    if records:
        vectors = Embedder().encode([r["text"] for r in records])
        store.add(vectors, records)

    store.save(output_dir)
    logger.info("Indexed %d chunk(s) → %s", len(records), output_dir)
    return len(records)
74
+
75
+
76
def main() -> None:
    """CLI entry: `python -m src.rag.ingest [<input_dir> [<output_dir>]]`."""
    argv = sys.argv[1:]
    source = Path(argv[0]) if argv else _DEFAULT_INPUT
    target = Path(argv[1]) if len(argv) > 1 else _DEFAULT_OUTPUT
    # Summary is logged at INFO inside ingest_directory; no print() in src/.
    ingest_directory(source, target)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
src/rag/retrieve.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query → top-k chunks. Encapsulates the embedder + store pair so callers
2
+ don't have to assemble both. Loads from disk lazily.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from pathlib import Path
7
+
8
+ from src.core.logger import get_logger
9
+ from src.rag.embed import EMBEDDING_DIM, Embedder
10
+ from src.rag.store import FAISSStore
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
class RAGRetriever:
    """Bundle (embedder, store). Use `RAGRetriever.load(dir)` to construct."""

    def __init__(self, store: FAISSStore, embedder: Embedder) -> None:
        self._store = store
        self._embedder = embedder

    @classmethod
    def load(cls, index_dir: Path) -> "RAGRetriever":
        """Load a persisted index from `index_dir` and pair it with a fresh embedder."""
        persisted = FAISSStore.load(Path(index_dir), dim=EMBEDDING_DIM)
        return cls(store=persisted, embedder=Embedder())

    def __len__(self) -> int:
        return len(self._store)

    def search(self, query: str, k: int = 5) -> list[dict]:
        """Return up to `k` chunks most relevant to `query`, sorted by score desc.

        Each chunk dict carries `text`, `source`, `chunk_index`, `score`.
        Returns [] for empty query or empty store.
        """
        if not query.strip() or len(self._store) == 0:
            return []
        query_vec = self._embedder.encode([query])[0]
        return [
            {**chunk, "score": score}
            for chunk, score in self._store.search(query_vec, k=k)
        ]
src/rag/store.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FAISS vector store with parallel chunk metadata.
2
+
3
+ Public entry: `FAISSStore(dim)`. Vectors are L2-normalized on add and
4
+ search so inner-product == cosine similarity. Chunks are arbitrary dicts;
5
+ `text` and `source` keys are recommended but not enforced.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import faiss
14
+ import numpy as np
15
+
16
+
17
class FAISSStore:
    """Inner-product (cosine after L2-norm) FAISS store with chunk metadata."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self._index: faiss.Index = faiss.IndexFlatIP(dim)
        self._chunks: list[dict[str, Any]] = []

    def __len__(self) -> int:
        return len(self._chunks)

    def add(self, vectors: np.ndarray, chunks: list[dict[str, Any]]) -> None:
        """Append `vectors` with parallel `chunks`; vectors are normalized in a copy."""
        n_vec = vectors.shape[0]
        if n_vec != len(chunks):
            raise ValueError(
                f"size mismatch: {n_vec} vectors vs {len(chunks)} chunks"
            )
        if n_vec == 0:
            return
        # Copy before normalize_L2 — it mutates its argument in place.
        normalized = np.array(vectors, dtype=np.float32, copy=True)
        faiss.normalize_L2(normalized)
        self._index.add(normalized)
        self._chunks.extend(chunks)

    def search(self, query: np.ndarray, k: int = 5) -> list[tuple[dict[str, Any], float]]:
        """Top-`k` (chunk, score) pairs by cosine similarity; [] when empty."""
        if not self._chunks:
            return []
        q = np.array(query, dtype=np.float32, copy=True)
        if q.ndim == 1:
            q = q[np.newaxis, :]
        faiss.normalize_L2(q)
        scores, ids = self._index.search(q, min(k, len(self._chunks)))
        # FAISS pads missing results with -1 indices; drop them.
        return [
            (self._chunks[int(i)], float(s))
            for i, s in zip(ids[0], scores[0])
            if i != -1
        ]

    def save(self, dir_path: Path) -> None:
        """Persist index.bin + chunks.json under `dir_path` (created if missing)."""
        dir_path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index, str(dir_path / "index.bin"))
        (dir_path / "chunks.json").write_text(json.dumps(self._chunks, indent=2))

    @classmethod
    def load(cls, dir_path: Path, dim: int) -> "FAISSStore":
        """Rehydrate a store previously written by `save`."""
        store = cls(dim=dim)
        store._index = faiss.read_index(str(dir_path / "index.bin"))
        store._chunks = json.loads((dir_path / "chunks.json").read_text())
        return store
tests/agents/__init__.py ADDED
File without changes
tests/agents/test_agent_route.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for POST /agent/run — uses a stub orchestrator factory."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+ from unittest.mock import patch
6
+
7
+ import pytest
8
+ from fastapi.testclient import TestClient
9
+
10
+ from src.agents.schemas import AgentResult, ToolTraceItem
11
+ from src.api.main import app
12
+
13
+
14
+ client = TestClient(app)
15
+
16
+
17
class _FakeOrchestrator:
    """Returns a canned AgentResult; ignores input."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def run(self, user_input: str) -> AgentResult:
        bbb_step = ToolTraceItem(
            name="run_bbb_pipeline",
            args={"smiles": user_input},
            result={"label": 1, "label_text": "permeable"},
        )
        rag_step = ToolTraceItem(
            name="retrieve_context",
            args={"query": "BBB"},
            result={"chunks": []},
        )
        return AgentResult(
            text=f"Synthesized answer for: {user_input}",
            trace=[bbb_step, rag_step],
            model="stub-model",
            finish_reason="complete",
        )
34
+
35
+
36
class TestAgentRoute:
    def test_post_returns_synthesized_text_and_trace(self) -> None:
        with patch("src.api.routes._build_orchestrator", return_value=_FakeOrchestrator()):
            resp = client.post("/agent/run", json={"user_input": "CCO"})
        assert resp.status_code == 200
        payload = resp.json()
        assert "Synthesized answer for: CCO" in payload["text"]
        assert len(payload["trace"]) == 2
        assert payload["trace"][0]["name"] == "run_bbb_pipeline"
        assert payload["model"] == "stub-model"
        assert payload["finish_reason"] == "complete"

    def test_empty_user_input_422(self) -> None:
        resp = client.post("/agent/run", json={"user_input": ""})
        assert resp.status_code == 422

    def test_missing_user_input_422(self) -> None:
        resp = client.post("/agent/run", json={})
        assert resp.status_code == 422
tests/agents/test_orchestrator.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.agents.orchestrator — agent loop with stubbed LLM client.
2
+
3
+ We do NOT hit OpenRouter here. We construct a fake client that returns
4
+ scripted tool-call responses, then verify the orchestrator dispatches
5
+ tools and assembles the trace correctly.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from typing import Any
11
+ from unittest.mock import MagicMock
12
+
13
+ import pytest
14
+ from pydantic import BaseModel
15
+
16
+ from src.agents.orchestrator import Orchestrator
17
+ from src.agents.tools import Tool
18
+
19
+
20
+ # --- Helpers ----------------------------------------------------------------
21
+
22
+
23
+ def _fake_choice_with_tool_call(name: str, args: dict[str, Any], call_id: str = "c1") -> Any:
24
+ msg = MagicMock()
25
+ msg.content = None
26
+ tc = MagicMock()
27
+ tc.id = call_id
28
+ tc.function.name = name
29
+ tc.function.arguments = json.dumps(args)
30
+ tc.model_dump = MagicMock(return_value={"id": call_id, "type": "function",
31
+ "function": {"name": name,
32
+ "arguments": json.dumps(args)}})
33
+ msg.tool_calls = [tc]
34
+ choice = MagicMock()
35
+ choice.message = msg
36
+ response = MagicMock()
37
+ response.choices = [choice]
38
+ return response
39
+
40
+
41
+ def _fake_choice_with_text(text: str) -> Any:
42
+ msg = MagicMock()
43
+ msg.content = text
44
+ msg.tool_calls = None
45
+ choice = MagicMock()
46
+ choice.message = msg
47
+ response = MagicMock()
48
+ response.choices = [choice]
49
+ return response
50
+
51
+
52
class _PingInput(BaseModel):
    """Input schema for the test-only ping tool."""
    msg: str


class _PingOutput(BaseModel):
    """Output schema for the test-only ping tool."""
    echo: str
58
+
59
+
60
def _make_ping_tool() -> Tool:
    """A trivial echo tool used to exercise the dispatch path."""
    def _execute(inp: _PingInput) -> _PingOutput:
        return _PingOutput(echo=f"pong:{inp.msg}")

    return Tool(
        name="ping",
        description="Echo a string back.",
        input_model=_PingInput,
        output_model=_PingOutput,
        execute=_execute,
    )
68
+
69
+
70
+ # --- Tests ------------------------------------------------------------------
71
+
72
+
73
class TestOrchestrator:
    @staticmethod
    def _build(llm: Any, **extra: Any):
        """Assemble an orchestrator over the stubbed client and the ping tool."""
        return Orchestrator(
            llm_client=llm,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            **extra,
        )

    def test_single_tool_then_text_response(self) -> None:
        llm = MagicMock()
        llm.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": "hello"}),
            _fake_choice_with_text("All done."),
        ]
        outcome = self._build(llm, max_steps=4).run("test input")
        assert outcome.text == "All done."
        assert outcome.finish_reason == "complete"
        assert len(outcome.trace) == 1
        step = outcome.trace[0]
        assert step.name == "ping"
        assert step.args == {"msg": "hello"}
        assert step.result == {"echo": "pong:hello"}

    def test_unknown_tool_recorded_as_error(self) -> None:
        llm = MagicMock()
        llm.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("nonexistent_tool", {"x": 1}),
            _fake_choice_with_text("Done."),
        ]
        outcome = self._build(llm, max_steps=4).run("test")
        assert outcome.trace[0].error is not None
        assert "unknown tool" in outcome.trace[0].error
        assert outcome.text == "Done."

    def test_invalid_tool_args_recorded_as_error(self) -> None:
        llm = MagicMock()
        llm.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"wrong_field": "x"}),
            _fake_choice_with_text("Recovered."),
        ]
        outcome = self._build(llm, max_steps=4).run("test")
        assert outcome.trace[0].error is not None
        assert outcome.text == "Recovered."

    def test_max_steps_exhausted_returns_finish_reason(self) -> None:
        llm = MagicMock()
        # The stub never yields a text turn, so the loop must hit the cap.
        llm.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": f"{i}"}, call_id=f"c{i}")
            for i in range(10)
        ]
        outcome = self._build(llm, max_steps=3).run("test")
        assert outcome.finish_reason == "max_steps"
        assert len(outcome.trace) == 3

    def test_first_response_is_text_no_tools(self) -> None:
        llm = MagicMock()
        llm.chat.completions.create.side_effect = [
            _fake_choice_with_text("Direct answer."),
        ]
        # No max_steps override — exercises the orchestrator's default.
        outcome = self._build(llm).run("trivial input")
        assert outcome.text == "Direct answer."
        assert outcome.trace == []
tests/agents/test_orchestrator_live.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Live integration test — hits real OpenRouter, picks pipeline, retrieves chunks.
2
+
3
+ Skipped unless BOTH OPENROUTER_API_KEY is set AND the BBB model artifact
4
+ is built (the `run_bbb_pipeline` tool can't run without it). Marked `slow`
5
+ (network round-trips).
6
+
7
+ The dual gate matters because src/llm/explainer.py auto-loads .env at
8
+ import time; without the model-artifact gate, this test would attempt a
9
+ real OpenRouter call in CI/dev and then fail because the BBB tool can't
10
+ execute. In the deployed Docker image both conditions are satisfied
11
+ (secret + build-time training).
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import os
16
+ from pathlib import Path
17
+
18
+ import pytest
19
+ from openai import OpenAI
20
+
21
+ from src.agents.orchestrator import Orchestrator
22
+ from src.agents.prompts import ORCHESTRATOR_SYSTEM_PROMPT
23
+ from src.agents.tools import build_default_tools
24
+ from src.rag.ingest import ingest_directory
25
+
26
+
27
+ _FIXTURE_KB = Path(__file__).parent.parent / "fixtures" / "kb_sample"
28
+ _DEFAULT_MODEL = "google/gemini-2.0-flash-exp:free"
29
+ _FALLBACK_MODEL = "anthropic/claude-haiku-4-5"
30
+ _BBB_MODEL_PATH = Path(
31
+ os.environ.get("BBB_MODEL_PATH", "data/processed/bbb_model.joblib")
32
+ )
33
+
34
+
35
@pytest.mark.slow
@pytest.mark.skipif(
    not os.environ.get("OPENROUTER_API_KEY"),
    reason="OPENROUTER_API_KEY not set",
)
@pytest.mark.skipif(
    not _BBB_MODEL_PATH.exists(),
    reason=f"BBB model artifact missing at {_BBB_MODEL_PATH} — run python -m src.models.bbb_model",
)
class TestOrchestratorLive:
    @pytest.fixture(scope="class")
    def rag_dir(self, tmp_path_factory: pytest.TempPathFactory) -> Path:
        """Build a throwaway FAISS index from the seed KB fixtures."""
        index_dir = tmp_path_factory.mktemp("rag_live")
        ingest_directory(_FIXTURE_KB, index_dir)
        return index_dir

    @pytest.fixture(scope="class")
    def client(self) -> OpenAI:
        """Real OpenRouter client; only constructed when the key gate passes."""
        return OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ["OPENROUTER_API_KEY"],
            timeout=30.0,
        )

    def test_smiles_input_picks_bbb_then_retrieves(self, client: OpenAI, rag_dir: Path) -> None:
        orchestrator = Orchestrator(
            llm_client=client,
            tools=build_default_tools(rag_index_dir=rag_dir),
            system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
            model=os.environ.get("NEUROBRIDGE_AGENT_MODEL", _DEFAULT_MODEL),
            max_steps=5,
        )
        outcome = orchestrator.run("CCO")
        # Soft assertions — model behavior varies but the workflow shape is fixed.
        assert outcome.finish_reason == "complete", f"got {outcome.finish_reason}, trace={outcome.trace}"
        called = [step.name for step in outcome.trace]
        assert "run_bbb_pipeline" in called, f"BBB pipeline not called; trace={called}"
        assert "retrieve_context" in called, f"RAG not called; trace={called}"
        assert outcome.text, "empty final text"
tests/agents/test_tools.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.agents.tools — Tool dataclass + registry + 4 tool wrappers."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+ from pydantic import BaseModel
8
+
9
+ from src.agents.tools import (
10
+ Tool,
11
+ build_default_tools,
12
+ BBBPipelineInput,
13
+ EEGPipelineInput,
14
+ MRIPipelineInput,
15
+ RetrieveContextInput,
16
+ )
17
+
18
+
19
class _DummyInput(BaseModel):
    """Minimal input schema: one required field, one with a default."""
    x: int
    y: str = "default"


class _DummyOutput(BaseModel):
    """Minimal output schema wrapping a single int."""
    result: int
26
+
27
+
28
class TestTool:
    @staticmethod
    def _dummy_tool(description: str = "A dummy tool"):
        """Doubling tool over the dummy schemas."""
        return Tool(
            name="dummy",
            description=description,
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )

    def test_openai_schema_shape(self) -> None:
        schema = self._dummy_tool().openai_schema()
        assert schema["type"] == "function"
        assert schema["function"]["name"] == "dummy"
        assert schema["function"]["description"] == "A dummy tool"
        params = schema["function"]["parameters"]
        assert params["type"] == "object"
        assert "x" in params["properties"]
        assert "x" in params["required"]
        assert "y" not in params["required"]  # has default

    def test_invoke_validates_and_returns_dict(self) -> None:
        assert self._dummy_tool("d").invoke({"x": 5}) == {"result": 10}

    def test_invoke_invalid_input_raises(self) -> None:
        with pytest.raises(ValueError, match="invalid input"):
            self._dummy_tool("d").invoke({"y": "missing-x"})
68
+
69
+
70
class TestBuildDefaultTools:
    def test_default_set_has_four_tools(self, tmp_path: Path) -> None:
        # build with placeholder paths; tools won't be invoked here
        built = build_default_tools(rag_index_dir=None)
        assert {t.name for t in built} == {
            "run_bbb_pipeline",
            "run_eeg_pipeline",
            "run_mri_pipeline",
            "retrieve_context",
        }

    def test_each_tool_has_pydantic_input_model(self) -> None:
        for tool in build_default_tools(rag_index_dir=None):
            assert issubclass(tool.input_model, BaseModel)
            assert issubclass(tool.output_model, BaseModel)

    def test_input_models_have_smiles_paths(self) -> None:
        # verify the field names the downstream system prompt depends on
        expectations = [
            (BBBPipelineInput, "smiles"),
            (EEGPipelineInput, "input_path"),
            (MRIPipelineInput, "input_dir"),
            (MRIPipelineInput, "sites_csv"),
            (RetrieveContextInput, "query"),
            (RetrieveContextInput, "k"),
        ]
        for model_cls, field_name in expectations:
            assert field_name in model_cls.model_fields

    def test_retrieve_context_short_circuits_when_no_index(self) -> None:
        retrieve = next(
            t for t in build_default_tools(rag_index_dir=None) if t.name == "retrieve_context"
        )
        assert retrieve.invoke({"query": "anything", "k": 3}) == {"query": "anything", "chunks": []}

    def test_processed_dir_parameter_threads_to_executors(self, tmp_path: Path) -> None:
        # build_default_tools should accept processed_dir; executors should
        # eventually write under it (we don't invoke the pipelines here, just
        # verify the parameter is accepted and tools are built).
        names = {t.name for t in build_default_tools(rag_index_dir=None, processed_dir=tmp_path)}
        assert "run_eeg_pipeline" in names
        assert "run_mri_pipeline" in names

    def test_default_processed_dir_when_omitted(self) -> None:
        # backwards-compat: omitting processed_dir keeps existing behavior
        assert len(build_default_tools(rag_index_dir=None)) == 4

    def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
        from unittest.mock import patch
        from fastapi import HTTPException

        bbb = next(
            t for t in build_default_tools(rag_index_dir=None) if t.name == "run_bbb_pipeline"
        )
        with patch("src.api.routes.predict_bbb",
                   side_effect=HTTPException(status_code=503, detail="model missing")):
            with pytest.raises(ValueError, match="bbb tool failed"):
                bbb.invoke({"smiles": "CCO"})
tests/fixtures/kb_sample/combat_harmonization_primer.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComBat Harmonization for Multi-Site Neuroimaging
2
+
3
+ ComBat (Johnson et al. 2007, adapted to MRI by Fortin et al. 2017, 2018)
4
+ is the de-facto standard for removing scanner / acquisition-site bias
5
+ from multi-center neuroimaging studies.
6
+
7
+ ## How it works
8
+
9
+ ComBat models per-site location (mean) and scale (variance) parameters
10
+ using an empirical-Bayes hierarchical framework. It estimates these
11
+ parameters jointly across all sites and shrinks them toward a global
12
+ prior — small-N sites are pulled toward the global mean, preventing
13
+ overfitting.
14
+
15
+ ## Site-gap reduction
16
+
17
+ A typical demonstration: the per-site mean of a hippocampus volume
18
+ feature can vary by 5+ standard deviations across hospitals. ComBat
19
+ typically collapses this gap to <0.005 — a 1000x+ reduction — while
20
+ preserving within-site biological variance (age, sex, diagnosis).
21
+
22
+ ## When it fails
23
+
24
+ ComBat requires at least 2 sites with overlapping covariate
25
+ distributions. Single-site data, or sites with completely disjoint
26
+ populations (e.g., one site pediatric-only, another elderly-only),
27
+ produce unreliable harmonization.
tests/fixtures/kb_sample/lipinski_rule_of_five.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lipinski's Rule of Five — BBB Permeability Heuristic
2
+
3
+ Lipinski's Rule of Five (Lipinski 1997, 2001) is the foundational
4
+ medicinal-chemistry rule for predicting whether a small molecule will
5
+ cross the blood-brain barrier (BBB) by passive diffusion.
6
+
7
+ ## The four criteria
8
+
9
+ A molecule is likely BBB-permeable if it satisfies all four:
10
+
11
+ 1. Molecular weight (MW) <= 500 Daltons
12
+ 2. Octanol-water partition coefficient (logP) <= 5
13
+ 3. Hydrogen-bond donors <= 5
14
+ 4. Hydrogen-bond acceptors <= 10
15
+
16
+ Molecules violating two or more criteria are typically poorly absorbed
17
+ or impermeant.
18
+
19
+ ## Why ethanol crosses
20
+
21
+ Ethanol (CCO) has MW=46 Da, logP=-0.31, 1 H-bond donor, 1 H-bond
22
+ acceptor — well within all four thresholds. This explains its rapid
23
+ CNS penetration despite hydrophilicity.
24
+
25
+ ## SHAP attribution interpretation
26
+
27
+ When a Random Forest BBB classifier flags Morgan fingerprint bits with
28
+ positive SHAP values toward a "permeable" label, the bit usually
29
+ corresponds to a small lipophilic substructure (CH3-, -OCH3-, aromatic
30
+ ring) consistent with Lipinski compliance.
tests/fixtures/kb_sample/mne_ica_basics.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MNE-Python ICA for EEG Artifact Removal
2
+
3
+ Independent Component Analysis (ICA, Hyvärinen 1999) decomposes a
4
+ multi-channel EEG recording into statistically independent source
5
+ components. It is the de-facto method for removing eye-blink and
6
+ heartbeat artifacts before downstream analysis.
7
+
8
+ ## Why ICA, not PCA
9
+
10
+ PCA decomposes signals into orthogonal components — but neural sources
11
+ are not orthogonal in scalp space, they are statistically independent.
12
+ ICA's independence assumption matches the physics: the eye, the heart,
13
+ and cortical sources fire on uncorrelated schedules.
14
+
15
+ ## The standard workflow
16
+
17
+ 1. Bandpass the raw recording at 0.5-40 Hz to remove DC drift and line
18
+ noise (50/60 Hz).
19
+ 2. Fit ICA with N components (typically 15-30, less than channel count).
20
+ 3. Identify artifact components by correlating each ICA source with the
21
+ EOG (eye) channel; reject components with |correlation| > 0.5.
22
+ 4. Reconstruct the cleaned signal by zeroing out the rejected
23
+ components and inverse-transforming.
24
+
25
+ ## Quality check
26
+
27
+ Post-ICA, the EOG channel should show minimal residual correlation
28
+ with frontal channels (Fp1/Fp2). If it doesn't, the ICA fit was likely
29
+ unstable — re-run with a different random seed or more components.
tests/rag/__init__.py ADDED
File without changes
tests/rag/test_chunker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.chunker — paragraph-aware character splitter."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+
6
+ from src.rag.chunker import chunk_text
7
+
8
+
9
class TestChunkText:
    """Behavioral tests for the paragraph-aware character splitter."""

    def test_short_text_returns_single_chunk(self) -> None:
        """Text already under max_chars comes back untouched as one chunk."""
        assert chunk_text("hello world", max_chars=100, overlap=10) == ["hello world"]

    def test_empty_text_returns_empty_list(self) -> None:
        """Empty or whitespace-only input yields no chunks at all."""
        for blank in ("", " \n\n "):
            assert chunk_text(blank, max_chars=100, overlap=10) == []

    def test_long_text_splits_into_multiple_chunks(self) -> None:
        """250 chars at max_chars=100 must split into 3+ pieces, all in budget."""
        pieces = chunk_text("a" * 250, max_chars=100, overlap=10)
        assert len(pieces) >= 3
        # every chunk respects max_chars
        assert all(len(piece) <= 100 for piece in pieces)

    def test_overlap_between_chunks(self) -> None:
        """Consecutive chunks share characters when overlap > 0."""
        unbroken = "abcdefghij" * 30  # 300 chars, no natural break
        pieces = chunk_text(unbroken, max_chars=100, overlap=20)
        for left, right in zip(pieces, pieces[1:]):
            assert left[-10:] in right or right[:10] in left

    def test_paragraph_boundary_preferred(self) -> None:
        """When the first paragraph fits and the second doesn't, split at \\n\\n."""
        para_a = "First paragraph content."
        para_b = "Second paragraph content " * 10
        pieces = chunk_text(f"{para_a}\n\n{para_b}", max_chars=100, overlap=10)
        # the first chunk should end at the paragraph boundary, not mid-word
        assert para_a in pieces[0]
tests/rag/test_embed.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.embed — fastembed wrapper."""
2
+ from __future__ import annotations
3
+
4
+ import numpy as np
5
+ import pytest
6
+
7
+ from src.rag.embed import Embedder, EMBEDDING_DIM
8
+
9
+
10
class TestEmbedder:
    """Shape / dtype / semantic-similarity checks for the fastembed wrapper."""

    @pytest.fixture(scope="class")
    def embedder(self) -> Embedder:
        # One model load shared across the class — loading is the slow part.
        return Embedder()

    def test_dim_constant_matches_model(self, embedder: Embedder) -> None:
        """A single string must encode to a (1, EMBEDDING_DIM) matrix."""
        assert embedder.encode(["hello"]).shape == (1, EMBEDDING_DIM)

    def test_batch_encoding(self, embedder: Embedder) -> None:
        """Batches encode one row per input, as float32."""
        batch = embedder.encode(["hello", "world", "blood-brain barrier"])
        assert batch.shape == (3, EMBEDDING_DIM)
        assert batch.dtype == np.float32

    def test_empty_list_returns_empty_array(self, embedder: Embedder) -> None:
        """Encoding nothing yields a (0, EMBEDDING_DIM) array, not an error."""
        assert embedder.encode([]).shape == (0, EMBEDDING_DIM)

    def test_similar_strings_have_higher_similarity_than_dissimilar(
        self, embedder: Embedder
    ) -> None:
        """Two BBB-themed strings must be closer than a BBB/MRI pair."""
        vecs = embedder.encode([
            "blood-brain barrier permeability",
            "BBB drug penetration",
            "MRI multi-site harmonization",
        ])
        # cosine similarity (vectors should be normalized for stable comparison)
        from numpy.linalg import norm

        def cosine(a, b) -> float:
            return float(np.dot(a, b) / (norm(a) * norm(b)))

        sim_ab = cosine(vecs[0], vecs[1])
        sim_ac = cosine(vecs[0], vecs[2])
        assert sim_ab > sim_ac, f"Expected BBB-related strings closer; got {sim_ab=} vs {sim_ac=}"
tests/rag/test_ingest.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.ingest — walk a directory, chunk, embed, persist."""
2
+ from __future__ import annotations
3
+
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+ from src.rag.ingest import ingest_directory
10
+ from src.rag.store import FAISSStore
11
+
12
+
13
+ _FIXTURE_KB = Path(__file__).parent.parent / "fixtures" / "kb_sample"
14
+
15
+
16
class TestIngestDirectory:
    """End-to-end tests: walk a KB directory, chunk, embed, persist."""

    def test_ingests_markdown_files(self, tmp_path: Path) -> None:
        """Ingesting the fixture KB writes a non-empty index + chunk store."""
        target = tmp_path / "idx"
        chunk_count = ingest_directory(_FIXTURE_KB, target)
        # at least one chunk per fixture file
        assert chunk_count > 0
        for artifact in ("index.bin", "chunks.json"):
            assert (target / artifact).exists()

    def test_loaded_store_is_searchable(self, tmp_path: Path) -> None:
        """A freshly ingested index reloads with chunk metadata intact."""
        target = tmp_path / "idx"
        ingest_directory(_FIXTURE_KB, target)
        from src.rag.embed import EMBEDDING_DIM

        store = FAISSStore.load(target, dim=EMBEDDING_DIM)
        assert len(store) > 0
        # NOTE(review): reaches into the private _chunks attribute — fine for
        # a white-box test, but it will break if the store's internals change.
        assert all("source" in chunk for chunk in store._chunks)
        assert all("text" in chunk for chunk in store._chunks)

    def test_empty_directory_creates_empty_index(self, tmp_path: Path) -> None:
        """An empty KB still produces an index file, just with zero chunks."""
        bare_kb = tmp_path / "empty_kb"
        bare_kb.mkdir()
        target = tmp_path / "idx"
        assert ingest_directory(bare_kb, target) == 0
        assert (target / "index.bin").exists()
tests/rag/test_retrieve.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.retrieve — query → top-k chunks."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.rag.ingest import ingest_directory
9
+ from src.rag.retrieve import RAGRetriever
10
+
11
+
12
+ _FIXTURE_KB = Path(__file__).parent.parent / "fixtures" / "kb_sample"
13
+
14
+
15
class TestRAGRetriever:
    """Query → top-k chunk retrieval against the fixture knowledge base."""

    @pytest.fixture(scope="class")
    def retriever(self, tmp_path_factory: pytest.TempPathFactory) -> RAGRetriever:
        # Build one shared index for the whole class — ingest is the slow part.
        index_dir = tmp_path_factory.mktemp("rag_idx")
        ingest_directory(_FIXTURE_KB, index_dir)
        return RAGRetriever.load(index_dir)

    def test_bbb_query_returns_lipinski_chunk(self, retriever: RAGRetriever) -> None:
        """A BBB question surfaces the Lipinski doc, and as the top hit."""
        hits = retriever.search("Why does ethanol cross the blood-brain barrier?", k=3)
        assert len(hits) == 3
        assert "lipinski_rule_of_five.md" in [hit["source"] for hit in hits]
        # the best-scoring hit should come from the Lipinski doc
        assert hits[0]["source"] == "lipinski_rule_of_five.md"

    def test_combat_query_returns_combat_chunk(self, retriever: RAGRetriever) -> None:
        """A harmonization question ranks the ComBat primer first."""
        top = retriever.search("How does ComBat remove scanner bias from MRI data?", k=2)[0]
        assert top["source"] == "combat_harmonization_primer.md"

    def test_eeg_query_returns_ica_chunk(self, retriever: RAGRetriever) -> None:
        """An artifact-removal question ranks the MNE/ICA doc first."""
        top = retriever.search("How do you remove eye blink artifacts from EEG?", k=2)[0]
        assert top["source"] == "mne_ica_basics.md"

    def test_search_includes_score_and_text(self, retriever: RAGRetriever) -> None:
        """Every hit carries text, source, and a float score in [0, 1]."""
        hit = retriever.search("BBB permeability", k=1)[0]
        for key in ("text", "source", "score"):
            assert key in hit
        assert isinstance(hit["score"], float)
        assert 0.0 <= hit["score"] <= 1.0
tests/rag/test_store.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.store — FAISS vector store with metadata."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from src.rag.store import FAISSStore
10
+
11
+
12
+ def _rand_vecs(n: int, d: int = 4, seed: int = 0) -> np.ndarray:
13
+ rng = np.random.default_rng(seed)
14
+ return rng.standard_normal((n, d), dtype=np.float32)
15
+
16
+
17
+ class TestFAISSStore:
18
+ def test_add_then_search(self) -> None:
19
+ store = FAISSStore(dim=4)
20
+ vecs = _rand_vecs(3)
21
+ chunks = [{"text": f"chunk-{i}", "source": "test.md"} for i in range(3)]
22
+ store.add(vecs, chunks)
23
+ results = store.search(vecs[0], k=2)
24
+ assert len(results) == 2
25
+ # the closest hit is the chunk we used as the query (cosine ~1.0)
26
+ top_chunk, top_score = results[0]
27
+ assert top_chunk["text"] == "chunk-0"
28
+ assert top_score > 0.99
29
+
30
+ def test_add_size_mismatch_raises(self) -> None:
31
+ store = FAISSStore(dim=4)
32
+ with pytest.raises(ValueError, match="size mismatch"):
33
+ store.add(_rand_vecs(3), [{"text": "only-one"}])
34
+
35
+ def test_search_k_larger_than_corpus(self) -> None:
36
+ store = FAISSStore(dim=4)
37
+ store.add(_rand_vecs(2), [{"text": f"c{i}"} for i in range(2)])
38
+ results = store.search(_rand_vecs(1)[0], k=10)
39
+ assert len(results) == 2
40
+
41
+ def test_save_load_roundtrip(self, tmp_path: Path) -> None:
42
+ store = FAISSStore(dim=4)
43
+ vecs = _rand_vecs(3)
44
+ chunks = [{"text": f"chunk-{i}", "source": "test.md"} for i in range(3)]
45
+ store.add(vecs, chunks)
46
+ store.save(tmp_path / "idx")
47
+
48
+ restored = FAISSStore.load(tmp_path / "idx", dim=4)
49
+ results = restored.search(vecs[0], k=1)
50
+ assert results[0][0]["text"] == "chunk-0"
51
+
52
+ def test_search_on_empty_store_returns_empty(self) -> None:
53
+ store = FAISSStore(dim=4)
54
+ assert store.search(_rand_vecs(1)[0], k=5) == []
55
+
56
+ def test_add_does_not_mutate_caller_vectors(self) -> None:
57
+ store = FAISSStore(dim=4)
58
+ vecs = _rand_vecs(3)
59
+ original = vecs.copy()
60
+ store.add(vecs, [{"text": f"c{i}"} for i in range(3)])
61
+ # Caller's array must be unchanged after add() (faiss.normalize_L2 is in-place)
62
+ assert np.allclose(vecs, original), "store.add() mutated caller's vectors"
63
+
64
+ def test_search_does_not_mutate_caller_query(self) -> None:
65
+ store = FAISSStore(dim=4)
66
+ store.add(_rand_vecs(3), [{"text": f"c{i}"} for i in range(3)])
67
+ query = _rand_vecs(1)[0]
68
+ original_query = query.copy()
69
+ store.search(query, k=2)
70
+ assert np.allclose(query, original_query), "store.search() mutated caller's query"