mekosotto committed on
Commit
8eff23e
·
1 Parent(s): 6b2c154

feat(rag): TF-IDF clinical retrieval with Turkish/English query expansion

Browse files
src/rag/clinical/retrieve.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TF-IDF retrieval over the clinical-paper corpus, with Turkish→English query expansion."""
2
+ from __future__ import annotations
3
+
4
+ import re
5
+ from textwrap import shorten
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from src.core.logger import get_logger
11
+ from src.rag.clinical.types import ClinicalEvidence, ClinicalRetrievalResult
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ # Mirrors the table in /Users/mertgungor/Downloads/rag/rag.py so the same
16
+ # Turkish keyword set produces the same expansion in both pipelines.
17
+ _QUERY_EXPANSIONS: dict[str, str] = {
18
+ "alzheimer": "alzheimer dementia cognitive impairment mild cognitive impairment mci memory",
19
+ "demans": "dementia alzheimer cognitive impairment memory cognition",
20
+ "unutkanlik": "memory impairment cognitive decline dementia alzheimer",
21
+ "parkinson": "parkinson disease movement disorder tremor motor symptoms non motor symptoms",
22
+ "titreme": "tremor parkinson motor symptoms movement disorder",
23
+ "egzersiz": "exercise physical activity training aerobic resistance cognition",
24
+ "beslenme": "nutrition diet lifestyle metabolic risk factors",
25
+ "risk": "risk factors lifestyle metabolic nutrition prevention",
26
+ "tani": "diagnosis diagnostic criteria assessment screening",
27
+ "tedavi": "treatment management therapy intervention",
28
+ }
29
+
30
+
31
+ def _expand_query(query: str) -> str:
32
+ normalized = query.casefold()
33
+ extras = [exp for key, exp in _QUERY_EXPANSIONS.items() if key in normalized]
34
+ return f"{query} {' '.join(extras)}" if extras else query
35
+
36
+
37
+ def _split_sentences(text: str) -> list[str]:
38
+ sentences = re.split(r"(?<=[.!?])\s+", text)
39
+ return [s.strip() for s in sentences if len(s.split()) >= 6]
40
+
41
+
42
+ def _query_terms(expanded: str) -> set[str]:
43
+ return {t for t in re.findall(r"[A-Za-z0-9]+", expanded.lower()) if len(t) >= 4}
44
+
45
+
46
def _to_evidence(text: str, chunk: Any, score: float) -> ClinicalEvidence:
    """Wrap *text* from *chunk* in a ClinicalEvidence, shortened to 420 characters."""
    return ClinicalEvidence(
        sentence=shorten(text, width=420, placeholder="..."),
        source=chunk.source,
        page_start=chunk.page_start,
        page_end=chunk.page_end,
        score=score,
    )


def _format_summary(evidence: list[ClinicalEvidence]) -> str:
    """Render *evidence* as a bulleted, citation-annotated text block."""
    lines = ["Clinical RAG evidence (not a medical diagnosis):"]
    for ev in evidence:
        page = (
            f"p.{ev.page_start}" if ev.page_start == ev.page_end
            else f"pp.{ev.page_start}-{ev.page_end}"
        )
        lines.append(f"- {ev.sentence} [{ev.source}, {page} | score={ev.score:.3f}]")
    return "\n".join(lines)


def retrieve_clinical(
    payload: dict[str, Any],
    query: str,
    top_k: int = 5,
    evidence_limit: int = 5,
) -> ClinicalRetrievalResult:
    """Run TF-IDF search over the clinical corpus, return evidence + a feedback summary.

    The query is expanded with ``_expand_query``, scored against every chunk,
    and the best ``top_k`` chunks are mined for sentences that share terms
    with the expanded query. If no sentence qualifies, the top chunks
    themselves (shortened) are used as evidence.

    Args:
        payload: Index with the keys ``"vectorizer"`` (object exposing
            ``transform``), ``"matrix"`` (chunk-by-term matrix) and
            ``"chunks"`` (sequence of objects with ``text``, ``source``,
            ``page_start`` and ``page_end``).
            NOTE(review): presumably a fitted TF-IDF vectorizer and a SciPy
            sparse matrix - confirm against the index loader.
        query: Free-text user query, Turkish or English.
        top_k: Maximum number of chunks to mine for evidence.
        evidence_limit: Maximum number of evidence items to return.

    Returns:
        A ClinicalRetrievalResult carrying the original query, the evidence
        list and a human-readable summary. Evidence is empty and the summary
        is ``""`` when the query is blank, either limit is non-positive, or
        nothing in the corpus scores above zero.

    Raises:
        KeyError: If *payload* lacks one of the required keys.
    """
    # Non-positive limits mean "return nothing". Without this guard a negative
    # top_k would slice from the end and evidence_limit=0 would still yield
    # one item, because the limit is only checked after appending.
    if not query.strip() or top_k <= 0 or evidence_limit <= 0:
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    vectorizer = payload["vectorizer"]
    matrix = payload["matrix"]
    chunks = payload["chunks"]

    expanded = _expand_query(query)
    query_vec = vectorizer.transform([expanded])
    scores = (matrix @ query_vec.T).toarray().ravel()
    if not np.any(scores):
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    # Highest score first; chunks that do not score above zero are dropped.
    ranked = np.argsort(scores)[::-1][:top_k]
    top_chunks = [
        (chunks[int(i)], float(scores[int(i)]))
        for i in ranked
        if scores[int(i)] > 0
    ]
    if not top_chunks:
        # Only reachable when every non-zero score is negative; answer like the
        # other no-match paths instead of emitting a header-only summary.
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    terms = _query_terms(expanded)
    tokenize = re.compile(r"[A-Za-z0-9]+").findall  # hoisted out of the loops
    candidates: list[tuple[float, str, Any, float]] = []
    for chunk, chunk_score in top_chunks:
        for sentence in _split_sentences(chunk.text):
            overlap = len(terms & set(tokenize(sentence.lower())))
            if overlap == 0:
                continue
            # Rank by term overlap, with the chunk score added on top.
            candidates.append((overlap + chunk_score, sentence, chunk, chunk_score))

    # list.sort is stable, so equally ranked sentences keep discovery order.
    candidates.sort(key=lambda item: item[0], reverse=True)

    seen: set[str] = set()
    evidence: list[ClinicalEvidence] = []
    for _, sentence, chunk, chunk_score in candidates:
        # Near-duplicate filter: identical first 120 characters, case-insensitive.
        fingerprint = sentence[:120].lower()
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        evidence.append(_to_evidence(sentence, chunk, chunk_score))
        if len(evidence) >= evidence_limit:
            break

    if not evidence:
        # No sentence shared a term with the query: fall back to the chunks.
        evidence = [
            _to_evidence(chunk.text, chunk, chunk_score)
            for chunk, chunk_score in top_chunks[:evidence_limit]
        ]

    return ClinicalRetrievalResult(
        query=query,
        evidence=evidence,
        summary_text=_format_summary(evidence),
    )
tests/rag/test_clinical_retrieve.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.clinical.retrieve."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.rag.clinical.retrieve import retrieve_clinical
9
+ from src.rag.clinical.loader import load_index
10
+ from tests.fixtures.build_tiny_clinical_index import build as build_tiny
11
+
12
+
13
class TestRetrieve:
    """Checks of retrieve_clinical against the tiny fixture index."""

    @staticmethod
    def _tiny_payload(tmp_path: Path):
        """Build the tiny index under *tmp_path* and load it via load_index."""
        index_location = build_tiny(tmp_path / "tiny.pkl")
        return load_index(index_location)

    def test_alzheimer_query_picks_alzheimer_chunks(self, tmp_path: Path) -> None:
        # An English Alzheimer query must surface at least one Alzheimer source.
        outcome = retrieve_clinical(
            self._tiny_payload(tmp_path), query="exercise and Alzheimer's", top_k=2
        )
        found_sources = {item.source for item in outcome.evidence}
        assert any("alzheimers" in name for name in found_sources)

    def test_parkinson_query_picks_parkinson_chunks(self, tmp_path: Path) -> None:
        # A Parkinson query must surface at least one Parkinson source.
        outcome = retrieve_clinical(
            self._tiny_payload(tmp_path), query="Parkinson levodopa", top_k=2
        )
        found_sources = {item.source for item in outcome.evidence}
        assert any("parkinsons" in name for name in found_sources)

    def test_turkish_keyword_routes_via_expansion(self, tmp_path: Path) -> None:
        # The Turkish keyword "egzersiz" should lead to the lifestyle source.
        outcome = retrieve_clinical(
            self._tiny_payload(tmp_path), query="egzersiz Alzheimer", top_k=2
        )
        assert any("alzheimers_lifestyle" in item.source for item in outcome.evidence)

    def test_summary_text_contains_citations(self, tmp_path: Path) -> None:
        # At least one evidence source must be cited in the summary text.
        outcome = retrieve_clinical(
            self._tiny_payload(tmp_path), query="diet and Parkinson", top_k=2
        )
        assert any(item.source in outcome.summary_text for item in outcome.evidence)

    def test_empty_query_returns_empty_evidence(self, tmp_path: Path) -> None:
        # An empty query yields no evidence.
        outcome = retrieve_clinical(self._tiny_payload(tmp_path), query="", top_k=2)
        assert outcome.evidence == []