File size: 4,491 Bytes
8eff23e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""TF-IDF retrieval over the clinical-paper corpus, with Turkish→English query expansion."""
from __future__ import annotations

import re
from textwrap import shorten
from typing import Any

import numpy as np

from src.core.logger import get_logger
from src.rag.clinical.types import ClinicalEvidence, ClinicalRetrievalResult

logger = get_logger(__name__)

# Mirrors the table in /Users/mertgungor/Downloads/rag/rag.py so the same
# Turkish keyword set produces the same expansion in both pipelines.
_QUERY_EXPANSIONS: dict[str, str] = {
    "alzheimer": "alzheimer dementia cognitive impairment mild cognitive impairment mci memory",
    "demans": "dementia alzheimer cognitive impairment memory cognition",
    "unutkanlik": "memory impairment cognitive decline dementia alzheimer",
    "parkinson": "parkinson disease movement disorder tremor motor symptoms non motor symptoms",
    "titreme": "tremor parkinson motor symptoms movement disorder",
    "egzersiz": "exercise physical activity training aerobic resistance cognition",
    "beslenme": "nutrition diet lifestyle metabolic risk factors",
    "risk": "risk factors lifestyle metabolic nutrition prevention",
    "tani": "diagnosis diagnostic criteria assessment screening",
    "tedavi": "treatment management therapy intervention",
}


def _expand_query(query: str) -> str:
    normalized = query.casefold()
    extras = [exp for key, exp in _QUERY_EXPANSIONS.items() if key in normalized]
    return f"{query} {' '.join(extras)}" if extras else query


def _split_sentences(text: str) -> list[str]:
    sentences = re.split(r"(?<=[.!?])\s+", text)
    return [s.strip() for s in sentences if len(s.split()) >= 6]


def _query_terms(expanded: str) -> set[str]:
    return {t for t in re.findall(r"[A-Za-z0-9]+", expanded.lower()) if len(t) >= 4}


def retrieve_clinical(
    payload: dict[str, Any],
    query: str,
    top_k: int = 5,
    evidence_limit: int = 5,
) -> ClinicalRetrievalResult:
    """Run TF-IDF search over the clinical corpus, return evidence + a feedback summary.

    The query is expanded with ``_expand_query``, scored against every chunk,
    and the best ``top_k`` chunks with a positive score are mined for
    sentences that share at least one query term. Sentences are ranked by
    ``(number of shared terms) + (chunk score)`` and de-duplicated on their
    first 120 lowercased characters. If no sentence qualifies, the shortened
    text of the top chunks is used as evidence instead.

    Args:
        payload: Index bundle. Must provide ``"vectorizer"`` (its
            ``transform`` is called on the expanded query), ``"matrix"``
            (multiplied with the transposed query vector; the product must
            expose ``toarray()``) and ``"chunks"`` (indexed by matrix row;
            each chunk exposes ``text``, ``source``, ``page_start`` and
            ``page_end``).
        query: Raw user query.
        top_k: Maximum number of best-scoring chunks to inspect.
        evidence_limit: Maximum number of evidence items to return.

    Returns:
        A ``ClinicalRetrievalResult`` carrying the original query, the
        evidence list and a bullet-list summary. When the query is blank,
        either limit is not positive, or nothing scores above zero, the
        evidence list is empty and the summary is the empty string.

    Raises:
        KeyError: If ``payload`` lacks one of the required keys.
    """
    # A blank query or a non-positive limit can never yield evidence. Bail out
    # early so a zero limit does not still return one item and a negative
    # ``top_k`` is not interpreted as "all but the last k" by the slice below.
    if not query.strip() or top_k <= 0 or evidence_limit <= 0:
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    vectorizer = payload["vectorizer"]
    matrix = payload["matrix"]
    chunks = payload["chunks"]

    expanded = _expand_query(query)
    query_vector = vectorizer.transform([expanded])
    scores = (matrix @ query_vector.T).toarray().ravel()
    if not np.any(scores):
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    # Best-first chunk indices; only strictly positive scores count as matches.
    top_indices = np.argsort(scores)[::-1][:top_k]
    top_chunks = [
        (chunks[int(i)], float(scores[int(i)])) for i in top_indices if scores[int(i)] > 0
    ]
    if not top_chunks:
        # Avoid emitting a header-only summary with no evidence behind it.
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    terms = _query_terms(expanded)
    token_pattern = re.compile(r"[A-Za-z0-9]+")
    # Each candidate is (rank, sentence, chunk, chunk_score).
    candidates: list[tuple[float, str, Any, float]] = []
    for chunk, chunk_score in top_chunks:
        for sentence in _split_sentences(chunk.text):
            sentence_terms = set(token_pattern.findall(sentence.lower()))
            overlap = len(terms & sentence_terms)
            if overlap == 0:
                continue
            candidates.append((overlap + chunk_score, sentence, chunk, chunk_score))

    # The sort is stable, so equally ranked sentences keep chunk/sentence order.
    candidates.sort(key=lambda item: item[0], reverse=True)
    seen: set[str] = set()
    evidence: list[ClinicalEvidence] = []
    for _, sentence, chunk, chunk_score in candidates:
        # The first 120 lowercased characters act as a near-duplicate fingerprint.
        fingerprint = sentence[:120].lower()
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        evidence.append(ClinicalEvidence(
            sentence=shorten(sentence, width=420, placeholder="..."),
            source=chunk.source,
            page_start=chunk.page_start,
            page_end=chunk.page_end,
            score=chunk_score,  # the reported score is the chunk's, not the sentence rank
        ))
        if len(evidence) >= evidence_limit:
            break

    if not evidence:
        # No sentence shared a query term: fall back to the chunk text itself.
        for chunk, chunk_score in top_chunks[:evidence_limit]:
            evidence.append(ClinicalEvidence(
                sentence=shorten(chunk.text, width=420, placeholder="..."),
                source=chunk.source,
                page_start=chunk.page_start,
                page_end=chunk.page_end,
                score=chunk_score,
            ))

    lines = ["Clinical RAG evidence (not a medical diagnosis):"]
    for ev in evidence:
        page = (
            f"p.{ev.page_start}" if ev.page_start == ev.page_end
            else f"pp.{ev.page_start}-{ev.page_end}"
        )
        lines.append(f"- {ev.sentence} [{ev.source}, {page} | score={ev.score:.3f}]")
    summary = "\n".join(lines)

    return ClinicalRetrievalResult(query=query, evidence=evidence, summary_text=summary)