mekosotto commited on
Commit
91dde0d
·
1 Parent(s): 8eff23e

feat(agents): retrieve_context corpus dispatch (reference vs clinical)

Browse files
src/agents/schemas.py CHANGED
@@ -6,7 +6,7 @@ names lowercase + snake_case so prompts and JSON outputs align.
6
  """
7
  from __future__ import annotations
8
 
9
- from typing import Any
10
 
11
  from pydantic import BaseModel, Field
12
 
@@ -38,6 +38,14 @@ class RetrieveContextInput(BaseModel):
38
  """Input for `retrieve_context` — natural-language query into the KB."""
39
  query: str = Field(..., min_length=2, description="Search query for the knowledge base")
40
  k: int = Field(4, ge=1, le=10, description="Number of chunks to return")
 
 
 
 
 
 
 
 
41
 
42
 
43
  # --- Pipeline tool outputs --------------------------------------------------
 
6
  """
7
  from __future__ import annotations
8
 
9
+ from typing import Any, Literal
10
 
11
  from pydantic import BaseModel, Field
12
 
 
38
  """Input for `retrieve_context` — natural-language query into the KB."""
39
  query: str = Field(..., min_length=2, description="Search query for the knowledge base")
40
  k: int = Field(4, ge=1, le=10, description="Number of chunks to return")
41
+ corpus: Literal["reference", "clinical"] = Field(
42
+ "reference",
43
+ description=(
44
+ "Which corpus to query. 'reference' = curated FAISS index (default). "
45
+ "'clinical' = TF-IDF index over peer-reviewed Alzheimer's/Parkinson's "
46
+ "papers with Turkish+English query expansion."
47
+ ),
48
+ )
49
 
50
 
51
  # --- Pipeline tool outputs --------------------------------------------------
src/agents/tools.py CHANGED
@@ -157,11 +157,41 @@ def _make_mri_executor(processed_dir: Path) -> Callable[[MRIPipelineInput], MRIP
157
  return execute
158
 
159
 
160
- def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
161
- """Closure: capture the index dir; lazy-load the retriever on first call."""
162
- state: dict[str, Any] = {"retriever": None}
 
 
 
163
 
164
  def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  if rag_index_dir is None or not (rag_index_dir / "index.bin").exists():
166
  return RetrieveContextOutput(query=inp.query, chunks=[])
167
  if state["retriever"] is None:
@@ -176,6 +206,7 @@ def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveCon
176
  def build_default_tools(
177
  rag_index_dir: Path | None,
178
  processed_dir: Path = Path("data/processed"),
 
179
  ) -> list[Tool]:
180
  """Return the 5 tools the orchestrator gets by default."""
181
  return [
@@ -217,15 +248,16 @@ def build_default_tools(
217
  Tool(
218
  name="retrieve_context",
219
  description=(
220
- "Retrieve up to k passages from the curated reference knowledge "
221
- "base. Use AFTER a pipeline tool returns, to ground your final "
222
- "synthesis in cited literature. Formulate a focused query "
223
- "based on the pipeline output (e.g., 'BBB permeability of "
224
- "small lipophilic molecules' or 'ComBat site harmonization')."
 
225
  ),
226
  input_model=RetrieveContextInput,
227
  output_model=RetrieveContextOutput,
228
- execute=_make_retrieve_executor(rag_index_dir),
229
  ),
230
  Tool(
231
  name="run_fusion",
 
157
  return execute
158
 
159
 
160
+ def _make_retrieve_executor(
161
+ rag_index_dir: Path | None,
162
+ clinical_rag_index_path: Path | None = None,
163
+ ) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
164
+ """Closure: capture both index sources; lazy-load each on first use."""
165
+ state: dict[str, Any] = {"retriever": None, "clinical_payload": None}
166
 
167
  def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
168
+ if inp.corpus == "clinical":
169
+ if clinical_rag_index_path is None or not Path(clinical_rag_index_path).exists():
170
+ logger.warning(
171
+ "retrieve_context corpus=clinical but no index path configured (path=%s)",
172
+ clinical_rag_index_path,
173
+ )
174
+ return RetrieveContextOutput(query=inp.query, chunks=[])
175
+ if state["clinical_payload"] is None:
176
+ from src.rag.clinical.loader import load_index
177
+ state["clinical_payload"] = load_index(Path(clinical_rag_index_path))
178
+ from src.rag.clinical.retrieve import retrieve_clinical
179
+ result = retrieve_clinical(state["clinical_payload"], inp.query, top_k=inp.k)
180
+ return RetrieveContextOutput(
181
+ query=inp.query,
182
+ chunks=[
183
+ {
184
+ "source": ev.source,
185
+ "page_start": ev.page_start,
186
+ "page_end": ev.page_end,
187
+ "text": ev.sentence,
188
+ "score": ev.score,
189
+ }
190
+ for ev in result.evidence
191
+ ],
192
+ )
193
+
194
+ # corpus == "reference" — existing FAISS path.
195
  if rag_index_dir is None or not (rag_index_dir / "index.bin").exists():
196
  return RetrieveContextOutput(query=inp.query, chunks=[])
197
  if state["retriever"] is None:
 
206
  def build_default_tools(
207
  rag_index_dir: Path | None,
208
  processed_dir: Path = Path("data/processed"),
209
+ clinical_rag_index_path: Path | None = None,
210
  ) -> list[Tool]:
211
  """Return the 5 tools the orchestrator gets by default."""
212
  return [
 
248
  Tool(
249
  name="retrieve_context",
250
  description=(
251
+ "Retrieve up to k passages from a knowledge base. corpus='clinical' "
252
+ "queries the peer-reviewed Alzheimer's/Parkinson's papers (TF-IDF, "
253
+ "supports Turkish keywords like 'egzersiz', 'beslenme', 'unutkanlik'); "
254
+ "default corpus='reference' queries the curated FAISS index. Use "
255
+ "AFTER a pipeline tool returns, to ground your final synthesis in "
256
+ "cited literature."
257
  ),
258
  input_model=RetrieveContextInput,
259
  output_model=RetrieveContextOutput,
260
+ execute=_make_retrieve_executor(rag_index_dir, clinical_rag_index_path),
261
  ),
262
  Tool(
263
  name="run_fusion",
src/api/routes.py CHANGED
@@ -616,7 +616,14 @@ def _build_orchestrator():
616
  timeout=30.0,
617
  )
618
  rag_dir = _DEFAULT_RAG_INDEX_DIR if _DEFAULT_RAG_INDEX_DIR.exists() else None
619
- tools = build_default_tools(rag_index_dir=rag_dir)
 
 
 
 
 
 
 
620
  model = os.environ.get(_AGENT_MODEL_ENV, _AGENT_DEFAULT_MODEL)
621
  return Orchestrator(
622
  llm_client=client,
 
616
  timeout=30.0,
617
  )
618
  rag_dir = _DEFAULT_RAG_INDEX_DIR if _DEFAULT_RAG_INDEX_DIR.exists() else None
619
+ clinical_idx = Path(os.environ.get(
620
+ "CLINICAL_RAG_INDEX_PATH",
621
+ "data/external_rag/index/rag_index.pkl",
622
+ ))
623
+ tools = build_default_tools(
624
+ rag_index_dir=rag_dir,
625
+ clinical_rag_index_path=clinical_idx if clinical_idx.exists() else None,
626
+ )
627
  model = os.environ.get(_AGENT_MODEL_ENV, _AGENT_DEFAULT_MODEL)
628
  return Orchestrator(
629
  llm_client=client,
tests/agents/test_tools_clinical_corpus.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests: retrieve_context tool dispatches by `corpus`."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ from src.agents.tools import build_default_tools
7
+ from tests.fixtures.build_tiny_clinical_index import build as build_tiny
8
+
9
+
10
+ class TestClinicalCorpus:
11
+ def test_default_corpus_is_reference(self, tmp_path: Path) -> None:
12
+ clinical_idx = build_tiny(tmp_path / "tiny.pkl")
13
+ tools = {t.name: t for t in build_default_tools(
14
+ rag_index_dir=None,
15
+ clinical_rag_index_path=clinical_idx,
16
+ )}
17
+ tool = tools["retrieve_context"]
18
+ out = tool.execute(tool.input_model.model_validate({"query": "test query"}))
19
+ assert hasattr(out, "chunks")
20
+ # rag_index_dir=None means reference returns empty.
21
+ assert out.chunks == []
22
+
23
+ def test_clinical_corpus_returns_evidence(self, tmp_path: Path) -> None:
24
+ clinical_idx = build_tiny(tmp_path / "tiny.pkl")
25
+ tools = {t.name: t for t in build_default_tools(
26
+ rag_index_dir=None,
27
+ clinical_rag_index_path=clinical_idx,
28
+ )}
29
+ tool = tools["retrieve_context"]
30
+ out = tool.execute(tool.input_model.model_validate({
31
+ "query": "exercise and Alzheimer",
32
+ "corpus": "clinical",
33
+ }))
34
+ assert len(out.chunks) > 0
35
+ for c in out.chunks:
36
+ assert "source" in c and "text" in c
37
+
38
+ def test_clinical_corpus_without_index_returns_empty(self, tmp_path: Path) -> None:
39
+ # No clinical index path configured.
40
+ tools = {t.name: t for t in build_default_tools(
41
+ rag_index_dir=None,
42
+ clinical_rag_index_path=None,
43
+ )}
44
+ tool = tools["retrieve_context"]
45
+ out = tool.execute(tool.input_model.model_validate({
46
+ "query": "egzersiz Alzheimer",
47
+ "corpus": "clinical",
48
+ }))
49
+ assert out.chunks == []