File size: 12,888 Bytes
e1624f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
Corrective RAG Node — Graded retrieval with query rewriting.

Design pattern: Corrective RAG (CRAG) from Yan et al. 2024
  1. Retrieve top-K documents from ChromaDB
  2. Grade each document for relevance (binary: RELEVANT / IRRELEVANT)
  3. If insufficient relevant docs → rewrite query and re-retrieve
  4. If still insufficient after max retries → route to fallback

Also implements parallelised evidence gathering:
  - ChromaDB (clinical guidelines)
  - CIViC API (genomic evidence)
  - ClinicalTrials.gov (active trials)
"""

import logging
import re
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from .state import AgentState
from .tools import call_tier_model

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Lazy-loaded retriever singleton
# ---------------------------------------------------------------------------

# Process-wide retriever cache; populated on the first _get_retriever() call.
_retriever_instance = None


def _get_retriever():
    """Return the shared OncoRAGRetriever, building it on first use.

    The retriever class is imported inside the function, so the import only
    happens when a retriever is first requested.

    Returns:
        The cached OncoRAGRetriever instance.

    Raises:
        Exception: Whatever the import or the constructor raised; the error
            is logged and then re-raised unchanged.
    """
    global _retriever_instance

    # Fast path: an instance was already built by an earlier call.
    if _retriever_instance is not None:
        return _retriever_instance

    try:
        from rag_engine.retriever import OncoRAGRetriever
        _retriever_instance = OncoRAGRetriever()
        logger.info("OncoRAGRetriever initialised successfully.")
    except Exception as init_error:
        logger.error("Failed to initialise OncoRAGRetriever: %s", init_error)
        raise

    return _retriever_instance


# ---------------------------------------------------------------------------
# Document Grading (CRAG core)
# ---------------------------------------------------------------------------

# Filler words that carry no case-specific meaning. The set covers every
# template word this module's query builders append ("treatment recommendation
# guidelines triage", "treatment guidelines evidence-based recommendations",
# "oncology clinical guidelines management", and the "mutations" label).
_GENERIC_QUERY_TERMS = frozenset({
    "treatment", "recommendation", "recommendations", "guidelines", "triage",
    "management", "clinical", "oncology", "evidence-based", "mutations",
})

# Matches the first verdict token in an upper-cased grader response. The
# negative forms are listed first and the whole token is word-bounded, so the
# "RELEVANT" inside "IRRELEVANT" can never be read as a positive verdict.
_VERDICT_PATTERN = re.compile(r"\b(IRRELEVANT|NO[NT][\s\-]*RELEVANT|RELEVANT)\b")


def _core_query_terms(query: str) -> List[str]:
    """Return the case-specific tokens of *query* used for keyword matching."""
    # Tokenise on word characters (keeping internal hyphens) so trailing
    # punctuation such as "carcinoma," or "mutations:" does not defeat the
    # substring match or the generic-term filter.
    tokens = re.findall(r"\w+(?:-\w+)*", query.lower())
    return [t for t in tokens if len(t) > 4 and t not in _GENERIC_QUERY_TERMS]


def _is_relevant_verdict(response: str) -> bool:
    """Return True only when the grader's first verdict token is RELEVANT."""
    verdict = _VERDICT_PATTERN.search(response.upper())
    return verdict is not None and verdict.group(1) == "RELEVANT"


def _grade_document(document_text: str, query: str, tier: int = 1) -> bool:
    """Grade a single retrieved document for relevance.

    The document is sent to the tier model for a binary RELEVANT / IRRELEVANT
    verdict. A keyword-overlap check runs alongside it: if the model rejects
    the document but a case-specific query term appears in the text, the
    document is kept anyway to protect recall.

    Args:
        document_text: The document text to evaluate (only the first 2000
            characters are sent to the model; the keyword check uses all of it).
        query: The clinical query.
        tier: Model tier to use for grading (default: 1 for speed).

    Returns:
        True if the document is RELEVANT, False otherwise. Also True when the
        grading call itself fails (fail-open).
    """
    doc_lower = document_text.lower()
    term_match = any(term in doc_lower for term in _core_query_terms(query))

    system_prompt = (
        "You are an expert Oncology Clinical Analyst. "
        "Your task is to evaluate if a retrieved medical document is RELEVANT to a patient's clinical query. "
        "RELEVANCE CRITERIA:\n"
        "1. The document discusses the specific cancer type or its precursors.\n"
        "2. The document mentions treatment protocols, staging, or diagnostic criteria relevant to the case.\n"
        "3. Synonyms are allowed (e.g., 'Uterine Cancer' is relevant to 'Endometrial Adenocarcinoma').\n\n"
        "Output ONLY 'RELEVANT' or 'IRRELEVANT'. No explanation."
    )
    user_prompt = (
        f"Patient Query/Context: {query}\n\n"
        f"Document Snippet:\n--- START ---\n{document_text[:2000]}\n--- END ---\n\n"
        f"Is this document relevant? (RELEVANT/IRRELEVANT):"
    )

    try:
        response = call_tier_model(
            tier=tier,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=10,
            temperature=0.0,
        )
        # A plain `"RELEVANT" in response` test is also true for "IRRELEVANT",
        # so the verdict is parsed as a whole word instead.
        is_relevant = _is_relevant_verdict(response)
        logger.info(
            "Doc Grade: %s (Term Match: %s) -> Query: %s...",
            "RELEVANT" if is_relevant else "IRRELEVANT", term_match, query[:30],
        )

        # Recall boost: keep a rejected document when it explicitly mentions a
        # case-specific query term (e.g. the cancer type).
        if not is_relevant and term_match:
            logger.debug("Model rejected doc but keyword match found. Overriding to RELEVANT for recall.")
            return True

        return is_relevant
    except Exception as exc:
        logger.warning("Document grading failed: %s — defaulting to RELEVANT.", exc)
        return True  # Fail open: include document if grading fails



def _rewrite_query(
    original_query: str,
    entities: Dict[str, Any],
    attempt: int,
) -> str:
    """Broaden the query for a retry attempt.

    Uses deterministic broadening rather than LLM-based rewriting
    for speed and predictability. Stage information is always dropped;
    attempt 1 keeps the cancer type and mutations, any other attempt keeps
    only the cancer type.

    Args:
        original_query: The query that yielded insufficient results
            (used for logging only).
        entities: Extracted clinical entities. Reads "cancer_type" and
            "mutations"; missing or None values are tolerated.
        attempt: The retry attempt number (1-indexed).

    Returns:
        A broadened query string.
    """
    cancer = entities.get("cancer_type") or "Unknown"
    mutations = entities.get("mutations") or []
    if isinstance(mutations, str):
        # A bare string would otherwise be joined character by character.
        mutations = [mutations]

    # "Unknown" is a placeholder, not a search term: embedding it would steer
    # retrieval towards documents that merely contain the word "unknown".
    parts = [] if cancer == "Unknown" else [str(cancer)]

    if attempt == 1:
        # Broadening strategy: remove stage specificity, keep cancer + mutations
        if mutations:
            parts.append(f"mutations {' '.join(str(m) for m in mutations)}")
        parts.append("treatment guidelines evidence-based recommendations")
        rewritten = " ".join(parts)
        logger.info("Query rewrite attempt %d: %s → %s", attempt, original_query, rewritten)
        return rewritten

    # Attempt 2+: maximally broad
    parts.append("oncology clinical guidelines management")
    rewritten = " ".join(parts)
    logger.info("Query rewrite attempt %d (maximal broadening): %s", attempt, rewritten)
    return rewritten


# ---------------------------------------------------------------------------
# Parallel Evidence Gathering
# ---------------------------------------------------------------------------

def _fetch_api_evidence(entities: Dict[str, Any]) -> Dict[str, List[str]]:
    """Collect genomic and clinical-trial evidence from the external APIs.

    The CIViC lookup and the ClinicalTrials.gov lookup are submitted to a
    two-worker thread pool so they run concurrently. Each lookup handles its
    own failures by logging a warning and yielding an empty list, so one
    failing API does not affect the other.

    Args:
        entities: Extracted clinical entities; "mutations" and "cancer_type"
            are read.

    Returns:
        Dict with "genomic_evidence" and "clinical_trials" lists of
        pre-formatted evidence strings.
    """
    mutations = entities.get("mutations", [])
    cancer = entities.get("cancer_type", "Unknown")

    def fetch_civic():
        """Return one formatted line per CIViC hit, across all mutations."""
        try:
            from rag_engine.api_clients import CivicAPIClient
            civic = CivicAPIClient()
            return [
                f"[CIViC] {mutation}: {hit.get('summary', 'No summary available')}"
                for mutation in mutations
                for hit in civic.search_variant_evidence(mutation, cancer)
            ]
        except Exception as exc:
            logger.warning("CIViC API failed: %s", exc)
            return []

    def fetch_trials():
        """Return one formatted line per trial found for the case."""
        try:
            from rag_engine.api_clients import ClinicalTrialsClient
            trials = ClinicalTrialsClient().search_trials(cancer, mutations)
            return [
                f"[ClinicalTrials.gov] {trial.get('title', 'Unknown')}: {trial.get('status', '?')}"
                for trial in trials
            ]
        except Exception as exc:
            logger.warning("ClinicalTrials.gov API failed: %s", exc)
            return []

    # Result key -> fetcher, in submission order.
    jobs = (
        ("genomic_evidence", fetch_civic),
        ("clinical_trials", fetch_trials),
    )
    # Every key starts empty so a failed future still leaves a list behind.
    gathered: Dict[str, List[str]] = {label: [] for label, _ in jobs}

    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {pool.submit(job): label for label, job in jobs}
        for finished in as_completed(pending):
            label = pending[finished]
            try:
                gathered[label] = finished.result()
            except Exception as exc:
                logger.error("Parallel fetch error (%s): %s", label, exc)

    return gathered


# ---------------------------------------------------------------------------
# Corrective RAG Node
# ---------------------------------------------------------------------------

# Minimum number of documents that must be graded RELEVANT within a single
# retrieval attempt for corrective_rag_node to accept that attempt's results.
_MIN_RELEVANT_DOCS = 2
# Maximum number of query rewrites: corrective_rag_node runs at most
# 1 + _MAX_REWRITES retrieve-and-grade attempts.
_MAX_REWRITES = 1


# Number of candidates requested per retrieval attempt; also the size of the
# grading thread pool, so every candidate can be graded concurrently.
_TOP_K_CANDIDATES = 8


def _build_initial_query(entities: Dict[str, Any], clinical_text: str) -> str:
    """Assemble the first retrieval query from the extracted entities."""
    # Missing keys and explicit None values are both treated as "Unknown".
    cancer = entities.get("cancer_type") or "Unknown"
    stage = entities.get("stage") or "Unknown"
    mutation_list = entities.get("mutations") or []
    if isinstance(mutation_list, str):
        # A bare string would otherwise be joined character by character.
        mutation_list = [mutation_list]
    mutations = ", ".join(str(m) for m in mutation_list)

    parts = []
    if cancer != "Unknown":
        parts.append(str(cancer))
    else:
        # Fallback: use first 100 chars of clinical text for vector search
        parts.append(clinical_text[:100].replace("\n", " "))
    if stage != "Unknown":
        parts.append(str(stage))
    if mutations:
        parts.append(f"mutations: {mutations}")
    parts.append("treatment recommendation guidelines triage")
    return " ".join(parts)


def _grade_candidates(candidates: List[Dict[str, Any]], query: str) -> List[Dict[str, Any]]:
    """Grade candidates concurrently; return the relevant ones in input order."""
    def is_relevant(doc: Dict[str, Any]) -> bool:
        return _grade_document(doc.get("text", ""), query, tier=1)

    with ThreadPoolExecutor(max_workers=_TOP_K_CANDIDATES) as pool:
        # Executor.map yields results in input order, so zip pairs each
        # verdict with the document it belongs to.
        verdicts = list(pool.map(is_relevant, candidates))
    return [doc for doc, keep in zip(candidates, verdicts) if keep]


def _format_documents(docs: List[Dict[str, Any]]) -> tuple:
    """Render graded documents as (context_strings, source_strings)."""
    context_strings = []
    source_strings = []
    for doc in docs:
        # .get throughout: one document lacking a metadata key must not
        # raise and discard every other graded document.
        source = doc.get("source", "Unknown")
        page = doc.get("page", "?")
        header = doc.get("header", "Unknown")
        context_strings.append(
            f"[Source: {source}, Page: {page}, "
            f"Section: {header}]\n{doc.get('text', '')}"
        )
        source_strings.append(f"- **{source}** (Page {page}): {header}")
    return context_strings, source_strings


def corrective_rag_node(state: AgentState) -> Dict[str, Any]:
    """Execute the Corrective RAG pipeline.

    Pipeline:
      1. Build structured query from extracted entities.
      2. Retrieve top-K candidates from the retriever.
      3. Grade each document for relevance.
      4. If insufficient relevant docs → rewrite query and retry
         (at most _MAX_REWRITES times).
      5. Fetch API evidence in parallel (CIViC + ClinicalTrials).
      6. Compute confidence metrics.

    Any exception raised during retrieval, grading or formatting is logged
    and results in empty RAG context with zero confidence; the API evidence
    step still runs.

    Args:
        state: Current LangGraph state. Reads "extracted_entities" and
            "clinical_text"; missing or None values are tolerated.

    Returns:
        State update with rag_context, rag_sources, graph_rag_context,
        api_evidence_context, rag_confidence and the retrieval metrics.
    """
    entities: Dict[str, Any] = state.get("extracted_entities") or {}
    clinical_text: str = state.get("clinical_text") or ""

    query = _build_initial_query(entities, clinical_text)

    rewrite_count = 0
    relevant_docs: List[Dict[str, Any]] = []
    context_strings: List[str] = []
    source_strings: List[str] = []
    mean_confidence = 0.0

    try:
        retriever = _get_retriever()

        # --- Retrieve + Grade loop ---
        for attempt in range(1 + _MAX_REWRITES):
            if attempt > 0:
                query = _rewrite_query(query, entities, attempt)
                rewrite_count += 1

            raw_results = retriever.query(query, n_results=_TOP_K_CANDIDATES) or []
            graded = _grade_candidates(raw_results, query)

            logger.info(
                "CRAG attempt %d: %d/%d documents graded RELEVANT (Parallel).",
                attempt + 1, len(graded), len(raw_results),
            )

            # A graded set below the minimum is discarded entirely: if no
            # attempt reaches the threshold, the node returns empty context.
            if len(graded) >= _MIN_RELEVANT_DOCS:
                relevant_docs = graded
                break

        # --- Format results ---
        context_strings, source_strings = _format_documents(relevant_docs)

        # --- Confidence metrics ---
        # Coerce to built-in float so the mean is a plain Python number
        # whatever numeric type the retriever used for the score.
        ce_scores = [
            float(doc["cross_encoder_score"])
            for doc in relevant_docs
            if doc.get("cross_encoder_score") is not None
        ]
        mean_confidence = sum(ce_scores) / len(ce_scores) if ce_scores else 0.0

    except Exception as exc:
        logger.exception("RAG retrieval failed: %s", exc)
        context_strings = []
        source_strings = []
        relevant_docs = []
        mean_confidence = 0.0
        rewrite_count = 0

    # --- Parallel API evidence ---
    api_results = _fetch_api_evidence(entities)

    return {
        "rag_context": context_strings,
        "rag_sources": source_strings,
        "graph_rag_context": [],  # Future: knowledge graph integration
        "api_evidence_context": (
            api_results.get("genomic_evidence", [])
            + api_results.get("clinical_trials", [])
        ),
        "rag_confidence": round(mean_confidence, 4),
        # NOTE(review): this always equals rag_grading_pass_count because both
        # are derived from the graded documents — confirm whether the raw
        # retrieved count was intended here before changing it.
        "rag_retrieval_count": len(context_strings),
        "rag_grading_pass_count": len(relevant_docs),
        "rag_query_rewrites": rewrite_count,
    }