| """ |
| Corrective RAG Node — Graded retrieval with query rewriting. |
| |
| Design pattern: Corrective RAG (CRAG) from Yan et al. 2024 |
| 1. Retrieve top-K documents from ChromaDB |
| 2. Grade each document for relevance (binary: RELEVANT / IRRELEVANT) |
| 3. If insufficient relevant docs → rewrite query and re-retrieve |
| 4. If still insufficient after max retries → route to fallback |
| |
| Also implements parallelised evidence gathering: |
| - ChromaDB (clinical guidelines) |
| - CIViC API (genomic evidence) |
| - ClinicalTrials.gov (active trials) |
| """ |
|
|
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional

from .state import AgentState
from .tools import call_tier_model
|
|
logger = logging.getLogger(__name__)

# Shared retriever: built on the first _get_retriever() call, reused afterwards.
_retriever_instance = None


def _get_retriever():
    """Return the shared ``OncoRAGRetriever``, constructing it on first use.

    The retriever class is imported inside the function, so loading it is
    deferred until the first call. If the import or construction fails, the
    error is logged and re-raised and the cache stays empty, so a later call
    tries again.

    Returns:
        The cached ``OncoRAGRetriever`` instance.

    Raises:
        Exception: Whatever the import or ``OncoRAGRetriever()`` raised.
    """
    global _retriever_instance

    # Fast path: already built.
    if _retriever_instance is not None:
        return _retriever_instance

    try:
        from rag_engine.retriever import OncoRAGRetriever

        _retriever_instance = OncoRAGRetriever()
        logger.info("OncoRAGRetriever initialised successfully.")
    except Exception as exc:
        logger.error("Failed to initialise OncoRAGRetriever: %s", exc)
        raise
    return _retriever_instance
|
|
|
|
| |
| |
| |
|
|
# Generic words that this module's query builders (corrective_rag_node and
# _rewrite_query) append to every retrieval query. They carry no case-specific
# signal, so they must not trigger the keyword-recall override below.
_GRADER_IGNORE_TERMS = frozenset({
    "treatment", "recommendation", "recommendations", "guidelines",
    "triage", "management", "clinical", "oncology",
    "mutations", "evidence-based",
})

# Word-like query tokens: drops surrounding punctuation (the ":" in
# "mutations:", the "," in "carcinoma,") but keeps hyphenated terms whole.
_QUERY_TOKEN_RE = re.compile(r"\w[\w-]*")

# "IRRELEVANT" contains "RELEVANT" as a substring, so verdicts are matched on
# word boundaries and the negative forms are checked explicitly.
_RELEVANT_VERDICT_RE = re.compile(r"\bRELEVANT\b")
_IRRELEVANT_VERDICT_RE = re.compile(r"\b(?:IRRELEVANT|NOT\s+RELEVANT)\b")


def _grade_document(
    document_text: str,
    query: str,
    tier: int = 1,
    *,
    model_call: Optional[Callable[..., str]] = None,
) -> bool:
    """Grade a single retrieved document for relevance to a clinical query.

    The model at ``tier`` (default 1, the fast tier) is asked for a binary
    RELEVANT / IRRELEVANT verdict. The verdict is parsed on whole words,
    because a plain substring test would read "IRRELEVANT" as relevant.

    A keyword override then favours recall. The document is kept even when
    the model did not answer RELEVANT if it contains a distinctive query term.
    A distinctive term is longer than four characters and is not query
    boilerplate.

    Args:
        document_text: The document text to evaluate. Only the first 2000
            characters are sent to the model, and ``None`` is treated as
            empty.
        query: The clinical query.
        tier: Model tier to use for grading (default: 1 for speed).
        model_call: Optional stand-in for ``call_tier_model``. It receives
            the same keyword arguments and returns the raw model text.

    Returns:
        True if the document is graded RELEVANT, is kept by the keyword
        override, or grading raised an error (recall-biased default).
        False otherwise.
    """
    document_text = document_text or ""
    doc_lower = document_text.lower()

    core_terms = [
        term
        for term in _QUERY_TOKEN_RE.findall(query.lower())
        if len(term) > 4 and term not in _GRADER_IGNORE_TERMS
    ]
    term_match = any(term in doc_lower for term in core_terms)

    system_prompt = (
        "You are an expert Oncology Clinical Analyst. "
        "Your task is to evaluate if a retrieved medical document is RELEVANT to a patient's clinical query. "
        "RELEVANCE CRITERIA:\n"
        "1. The document discusses the specific cancer type or its precursors.\n"
        "2. The document mentions treatment protocols, staging, or diagnostic criteria relevant to the case.\n"
        "3. Synonyms are allowed (e.g., 'Uterine Cancer' is relevant to 'Endometrial Adenocarcinoma').\n\n"
        "Output ONLY 'RELEVANT' or 'IRRELEVANT'. No explanation."
    )
    user_prompt = (
        f"Patient Query/Context: {query}\n\n"
        f"Document Snippet:\n--- START ---\n{document_text[:2000]}\n--- END ---\n\n"
        f"Is this document relevant? (RELEVANT/IRRELEVANT):"
    )

    try:
        # Resolved inside the try so a missing/broken backend takes the
        # recall-biased failure path below instead of crashing the caller.
        call = model_call if model_call is not None else call_tier_model
        response = call(
            tier=tier,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=10,
            temperature=0.0,
        )
        verdict = response.strip().upper()
        is_relevant = (
            _RELEVANT_VERDICT_RE.search(verdict) is not None
            and _IRRELEVANT_VERDICT_RE.search(verdict) is None
        )
        logger.info("Doc Grade: %s (Term Match: %s) -> Query: %s...", "RELEVANT" if is_relevant else "IRRELEVANT", term_match, query[:30])

        # Recall override: a distinctive query term in the document keeps it
        # even when the model rejected it.
        if not is_relevant and term_match:
            logger.debug("Model rejected doc but keyword match found. Overriding to RELEVANT for recall.")
            return True

        return is_relevant
    except Exception as exc:
        logger.warning("Document grading failed: %s — defaulting to RELEVANT.", exc)
        return True
|
|
|
|
|
|
def _rewrite_query(
    original_query: str,
    entities: Dict[str, Any],
    attempt: int,
) -> str:
    """Broaden the retrieval query for a retry attempt.

    Uses deterministic broadening rather than LLM-based rewriting for speed
    and predictability.

    - Attempt 1 keeps the cancer type and mutations and appends generic
      guideline terms.
    - Any other attempt number keeps only the cancer type plus generic
      oncology-guideline terms (maximal broadening).

    Missing, ``None``/empty and ``"Unknown"`` entity values are left out of
    the query rather than searched for literally. This matches how
    corrective_rag_node skips an unknown cancer type.

    Args:
        original_query: The query that yielded insufficient results (only
            logged).
        entities: Extracted clinical entities (``cancer_type`` and
            ``mutations`` are read). ``None`` is treated as empty.
        attempt: The retry attempt number (1-indexed).

    Returns:
        A broadened query string.
    """
    entities = entities or {}
    cancer = entities.get("cancer_type") or "Unknown"
    mutations = entities.get("mutations") or []

    # The "Unknown" placeholder is noise for a similarity search, so omit it.
    subject = [] if cancer == "Unknown" else [cancer]

    if attempt == 1:
        parts = list(subject)
        if mutations:
            parts.append(f"mutations {' '.join(mutations)}")
        parts.append("treatment guidelines evidence-based recommendations")
        rewritten = " ".join(parts)
        logger.info("Query rewrite attempt %d: %s → %s", attempt, original_query, rewritten)
        return rewritten

    # Later attempts: drop mutations too, keep only the disease context.
    rewritten = " ".join(subject + ["oncology clinical guidelines management"])
    logger.info("Query rewrite attempt %d (maximal broadening): %s", attempt, rewritten)
    return rewritten
|
|
|
|
| |
| |
| |
|
|
def _fetch_api_evidence(entities: Dict[str, Any]) -> Dict[str, List[str]]:
    """Gather CIViC and ClinicalTrials.gov evidence concurrently.

    The two lookups run on a two-thread pool. Each lookup is best-effort.
    Any failure inside it, including the client module not being importable,
    is logged as a warning and yields an empty list for that source.

    Args:
        entities: Extracted clinical entities. ``cancer_type`` and
            ``mutations`` are read.

    Returns:
        ``{"genomic_evidence": [...], "clinical_trials": [...]}``. Each entry
        is a human-readable line prefixed with its source tag.
    """
    cancer = entities.get("cancer_type", "Unknown")
    mutations = entities.get("mutations", [])

    def civic_lines() -> List[str]:
        # One line per CIViC evidence item, for every extracted mutation.
        try:
            from rag_engine.api_clients import CivicAPIClient

            civic = CivicAPIClient()
            return [
                f"[CIViC] {mutation}: {item.get('summary', 'No summary available')}"
                for mutation in mutations
                for item in civic.search_variant_evidence(mutation, cancer)
            ]
        except Exception as exc:
            logger.warning("CIViC API failed: %s", exc)
            return []

    def trial_lines() -> List[str]:
        # One line per trial returned for this cancer / mutation set.
        try:
            from rag_engine.api_clients import ClinicalTrialsClient

            trials = ClinicalTrialsClient().search_trials(cancer, mutations)
            return [
                f"[ClinicalTrials.gov] {trial.get('title', 'Unknown')}: "
                f"{trial.get('status', '?')}"
                for trial in trials
            ]
        except Exception as exc:
            logger.warning("ClinicalTrials.gov API failed: %s", exc)
            return []

    collected: Dict[str, List[str]] = {
        "genomic_evidence": [],
        "clinical_trials": [],
    }
    jobs = (
        ("genomic_evidence", civic_lines),
        ("clinical_trials", trial_lines),
    )

    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {pool.submit(job): slot for slot, job in jobs}
        for finished in as_completed(pending):
            slot = pending[finished]
            try:
                collected[slot] = finished.result()
            except Exception as exc:
                logger.error("Parallel fetch error (%s): %s", slot, exc)

    return collected
|
|
|
|
| |
| |
| |
|
|
| |
# Minimum number of RELEVANT-graded documents for a retrieval attempt to be
# accepted. Below this the query is rewritten, and after the last rewrite no
# guideline context is returned.
_MIN_RELEVANT_DOCS = 2
# Query rewrites allowed after the initial retrieval attempt.
_MAX_REWRITES = 1
# Candidates requested from the retriever per attempt.
_TOP_K = 8
# Upper bound on concurrent grading calls per attempt.
_MAX_GRADING_WORKERS = 8


def _build_retrieval_query(entities: Dict[str, Any], clinical_text: str) -> str:
    """Compose the initial retrieval query from the extracted entities.

    Missing, ``None``/empty and ``"Unknown"`` entity values are skipped. When
    no cancer type is available, the first 100 characters of the raw clinical
    text (newlines flattened) stand in for it, so the query still carries
    case-specific content.

    Args:
        entities: Extracted clinical entities (``cancer_type``, ``stage`` and
            ``mutations`` are read).
        clinical_text: Raw clinical note, used as the fallback subject.

    Returns:
        A space-joined query string ending in generic guideline terms.
    """
    cancer = entities.get("cancer_type") or "Unknown"
    stage = entities.get("stage") or "Unknown"
    mutations = ", ".join(entities.get("mutations") or [])

    if cancer != "Unknown":
        parts = [cancer]
    else:
        parts = [clinical_text[:100].replace("\n", " ")]
    if stage != "Unknown":
        parts.append(stage)
    if mutations:
        parts.append(f"mutations: {mutations}")
    parts.append("treatment recommendation guidelines triage")
    return " ".join(parts)


def _grade_batch(candidates: List[Dict[str, Any]], query: str) -> List[Dict[str, Any]]:
    """Grade retrieved candidates concurrently and keep the RELEVANT ones.

    Args:
        candidates: Retriever hits. Each hit's ``text`` field is graded.
        query: The query the candidates were retrieved for.

    Returns:
        The candidates graded RELEVANT, in their original retrieval order.
    """
    if not candidates:
        # Nothing to grade, and ThreadPoolExecutor rejects max_workers=0.
        return []

    def grade(doc: Dict[str, Any]) -> bool:
        return _grade_document(doc.get("text") or "", query, tier=1)

    workers = min(_MAX_GRADING_WORKERS, len(candidates))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        verdicts = list(pool.map(grade, candidates))
    return [doc for doc, keep in zip(candidates, verdicts) if keep]


def corrective_rag_node(state: AgentState) -> Dict[str, Any]:
    """Execute the Corrective RAG pipeline.

    Pipeline:
        1. Build a structured query from the extracted entities.
        2. Retrieve the top ``_TOP_K`` candidates from the retriever.
        3. Grade every candidate for relevance, concurrently.
        4. If fewer than ``_MIN_RELEVANT_DOCS`` pass, rewrite the query and
           retry, up to ``_MAX_REWRITES`` times.
        5. Fetch API evidence (CIViC + ClinicalTrials) in parallel.
        6. Compute the confidence metric (mean cross-encoder score).

    Any exception during steps 2-6 is logged with its traceback. It yields an
    empty guideline context and zeroed metrics, and API evidence is still
    gathered in that case.

    Args:
        state: Current LangGraph state. ``extracted_entities`` and
            ``clinical_text`` are read.

    Returns:
        State update with rag_context, sources, confidence, and metrics.
    """
    # ``or`` guards: keys may be present but hold None.
    entities: Dict[str, Any] = state.get("extracted_entities") or {}
    clinical_text: str = state.get("clinical_text") or ""

    query = _build_retrieval_query(entities, clinical_text)
    rewrite_count = 0
    relevant_docs: List[Dict[str, Any]] = []
    context_strings: List[str] = []
    source_strings: List[str] = []
    mean_confidence = 0.0

    try:
        retriever = _get_retriever()

        for attempt in range(1 + _MAX_REWRITES):
            if attempt > 0:
                query = _rewrite_query(query, entities, attempt)
                rewrite_count += 1

            # Materialised so it can be both graded and counted.
            raw_results = list(retriever.query(query, n_results=_TOP_K) or [])
            graded = _grade_batch(raw_results, query)
            logger.info(
                "CRAG attempt %d: %d/%d documents graded RELEVANT (Parallel).",
                attempt + 1, len(graded), len(raw_results),
            )

            if len(graded) >= _MIN_RELEVANT_DOCS:
                relevant_docs = graded
                break
        # NOTE(review): if no attempt reaches _MIN_RELEVANT_DOCS, relevant_docs
        # stays empty and any partially relevant hits are dropped. This is
        # presumably so the graph can fall back; confirm.

        for doc in relevant_docs:
            source = doc.get("source", "Unknown")
            page = doc.get("page", "?")
            header = doc.get("header", "Unknown")
            context_strings.append(
                f"[Source: {source}, Page: {page}, Section: {header}]\n"
                f"{doc.get('text', '')}"
            )
            source_strings.append(f"- **{source}** (Page {page}): {header}")

        ce_scores = [
            doc["cross_encoder_score"]
            for doc in relevant_docs
            if doc.get("cross_encoder_score") is not None
        ]
        mean_confidence = sum(ce_scores) / len(ce_scores) if ce_scores else 0.0

    except Exception as exc:
        logger.exception("RAG retrieval failed: %s", exc)
        # Discard any partially built output so the update stays consistent.
        context_strings, source_strings, relevant_docs = [], [], []
        mean_confidence = 0.0
        rewrite_count = 0

    api_results = _fetch_api_evidence(entities)

    return {
        "rag_context": context_strings,
        "rag_sources": source_strings,
        "graph_rag_context": [],
        "api_evidence_context": (
            api_results.get("genomic_evidence", [])
            + api_results.get("clinical_trials", [])
        ),
        "rag_confidence": round(mean_confidence, 4),
        # NOTE(review): this equals the number of accepted documents, not the
        # raw retrieval count. It is kept as-is in case routing depends on it.
        "rag_retrieval_count": len(context_strings),
        "rag_grading_pass_count": len(relevant_docs),
        "rag_query_rewrites": rewrite_count,
    }
|
|