File size: 12,888 Bytes
e1624f5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 | """
Corrective RAG Node — Graded retrieval with query rewriting.
Design pattern: Corrective RAG (CRAG) from Yan et al. 2024
1. Retrieve top-K documents from ChromaDB
2. Grade each document for relevance (binary: RELEVANT / IRRELEVANT)
3. If insufficient relevant docs → rewrite query and re-retrieve
4. If still insufficient after max retries → route to fallback
Also implements parallelised evidence gathering:
- ChromaDB (clinical guidelines)
- CIViC API (genomic evidence)
- ClinicalTrials.gov (active trials)
"""
import logging
import re
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from .state import AgentState
from .tools import call_tier_model
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy-loaded retriever singleton
# ---------------------------------------------------------------------------
_retriever_instance = None


def _get_retriever():
    """Lazily build and cache the shared OncoRAGRetriever.

    The retriever (and the ``rag_engine`` import) is only created on the first
    call; later calls return the same cached object.

    Returns:
        The cached ``OncoRAGRetriever`` instance.

    Raises:
        Exception: Whatever the import or the ``OncoRAGRetriever()``
            constructor raised. It is logged and re-raised unchanged, and the
            cache stays empty so the next call tries again.
    """
    global _retriever_instance
    if _retriever_instance is not None:
        return _retriever_instance

    try:
        from rag_engine.retriever import OncoRAGRetriever

        built = OncoRAGRetriever()
    except Exception as err:
        logger.error("Failed to initialise OncoRAGRetriever: %s", err)
        raise

    _retriever_instance = built
    logger.info("OncoRAGRetriever initialised successfully.")
    return _retriever_instance
# ---------------------------------------------------------------------------
# Document Grading (CRAG core)
# ---------------------------------------------------------------------------
def _grade_document(document_text: str, query: str, tier: int = 1) -> bool:
    """Grade a single retrieved document for relevance to the clinical query.

    The document (truncated to its first 2000 characters) is sent to the model
    tier given by ``tier`` with a prompt asking for a one-word RELEVANT /
    IRRELEVANT verdict. A keyword heuristic then biases the result toward
    recall. If the model rejects the document but any non-generic query term
    longer than four characters occurs in it, the document is kept.

    Args:
        document_text: The document text to evaluate (``None`` is treated as
            empty).
        query: The clinical query the document was retrieved for (``None`` is
            treated as empty).
        tier: Model tier to use for grading (default: 1, the fast tier).

    Returns:
        True if the document is RELEVANT (per the model or the keyword
        override), or if grading fails (fail open). False otherwise.
    """
    query_text = query or ""
    doc_text = document_text or ""
    doc_lower = doc_text.lower()

    # Keyword-override terms. Tokenise on word characters and hyphens so
    # punctuation ("mutations:", "EGFR,") does not stick to a term. Skip the
    # generic boilerplate words used when building retrieval queries in this
    # module. Words such as "recommendations" or "mutations" appear in almost
    # every guideline and would otherwise override every IRRELEVANT verdict.
    ignore_terms = {
        "treatment", "recommendation", "recommendations", "guideline",
        "guidelines", "evidence-based", "triage", "management", "clinical",
        "oncology", "mutation", "mutations", "unknown",
    }
    core_terms = [
        term
        for term in re.findall(r"[\w-]+", query_text.lower())
        if len(term) > 4 and term not in ignore_terms
    ]
    term_match = any(term in doc_lower for term in core_terms)

    system_prompt = (
        "You are an expert Oncology Clinical Analyst. "
        "Your task is to evaluate if a retrieved medical document is RELEVANT to a patient's clinical query. "
        "RELEVANCE CRITERIA:\n"
        "1. The document discusses the specific cancer type or its precursors.\n"
        "2. The document mentions treatment protocols, staging, or diagnostic criteria relevant to the case.\n"
        "3. Synonyms are allowed (e.g., 'Uterine Cancer' is relevant to 'Endometrial Adenocarcinoma').\n\n"
        "Output ONLY 'RELEVANT' or 'IRRELEVANT'. No explanation."
    )
    user_prompt = (
        f"Patient Query/Context: {query_text}\n\n"
        f"Document Snippet:\n--- START ---\n{doc_text[:2000]}\n--- END ---\n\n"
        f"Is this document relevant? (RELEVANT/IRRELEVANT):"
    )
    try:
        response = call_tier_model(
            tier=tier,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=10,
            temperature=0.0,
        )
        verdict = response.strip().upper()
        # "IRRELEVANT" contains "RELEVANT" as a substring, so a bare substring
        # test for RELEVANT would accept every answer. Rule out the negative
        # labels explicitly.
        is_relevant = (
            "RELEVANT" in verdict
            and "IRRELEVANT" not in verdict
            and "NOT RELEVANT" not in verdict
        )
        logger.info("Doc Grade: %s (Term Match: %s) -> Query: %s...", "RELEVANT" if is_relevant else "IRRELEVANT", term_match, query_text[:30])
        # Recall boost: keep documents that explicitly mention a specific
        # query term (e.g. the cancer type) even if the model rejected them.
        if not is_relevant and term_match:
            logger.debug("Model rejected doc but keyword match found. Overriding to RELEVANT for recall.")
            return True
        return is_relevant
    except Exception as exc:
        # Any failure of the model call or of parsing its reply lands here.
        logger.warning("Document grading failed: %s — defaulting to RELEVANT.", exc)
        return True  # Fail open: include document if grading fails
def _rewrite_query(
    original_query: str,
    entities: Dict[str, Any],
    attempt: int,
) -> str:
    """Broaden the retrieval query for a CRAG retry attempt.

    Uses deterministic broadening rather than LLM-based rewriting, for speed
    and predictability. The new query is rebuilt from ``entities``, and
    ``original_query`` is used only for logging:

    * attempt 1: cancer type + mutations + generic guideline terms (the stage
      is deliberately dropped);
    * any other attempt number: cancer type + generic oncology terms.

    A missing, empty or ``"Unknown"`` cancer type is left out rather than
    inserted literally, so the vector search is not steered toward the word
    "Unknown" (or "None").

    Args:
        original_query: The query that yielded insufficient results.
        entities: Extracted clinical entities; ``"cancer_type"`` and
            ``"mutations"`` are read. ``None`` is treated as empty.
        attempt: The retry attempt number (1-indexed).

    Returns:
        A broadened query string.
    """
    entities = entities or {}
    cancer = str(entities.get("cancer_type") or "").strip()
    if cancer.lower() == "unknown":
        cancer = ""
    # Coerce to str so non-string entries cannot break the join below.
    mutations = [str(m) for m in (entities.get("mutations") or []) if m]

    if attempt == 1:
        # Broadening strategy: remove stage specificity, keep cancer + mutations
        parts = [cancer] if cancer else []
        if mutations:
            parts.append(f"mutations {' '.join(mutations)}")
        parts.append("treatment guidelines evidence-based recommendations")
        rewritten = " ".join(parts)
        logger.info("Query rewrite attempt %d: %s → %s", attempt, original_query, rewritten)
        return rewritten

    # Attempt 2+: maximally broad
    rewritten = f"{cancer} oncology clinical guidelines management".strip()
    logger.info("Query rewrite attempt %d (maximal broadening): %s", attempt, rewritten)
    return rewritten
# ---------------------------------------------------------------------------
# Parallel Evidence Gathering
# ---------------------------------------------------------------------------
def _fetch_api_evidence(entities: Dict[str, Any]) -> Dict[str, List[str]]:
    """Gather external (non-vector-store) evidence for the extracted entities.

    Two lookups run concurrently on a two-thread pool:

    * CIViC variant evidence: one ``search_variant_evidence`` call per
      mutation, each hit rendered as ``"[CIViC] <mutation>: <summary>"``;
    * ClinicalTrials.gov: one ``search_trials`` call for the cancer type and
      mutations, each trial rendered as
      ``"[ClinicalTrials.gov] <title>: <status>"``.

    Each lookup is best-effort. Any exception, including failure to import
    its API client, is logged as a warning and yields an empty list.

    Args:
        entities: Extracted clinical entities. ``"mutations"`` (default
            ``[]``) and ``"cancer_type"`` (default ``"Unknown"``) are read.

    Returns:
        ``{"genomic_evidence": [...], "clinical_trials": [...]}``. A key keeps
        its empty list if its lookup fails.
    """
    variants = entities.get("mutations", [])
    cancer_type = entities.get("cancer_type", "Unknown")

    def civic_lookup() -> List[str]:
        """Collect formatted CIViC evidence lines for every mutation."""
        try:
            from rag_engine.api_clients import CivicAPIClient

            civic = CivicAPIClient()
            return [
                f"[CIViC] {variant}: {hit.get('summary', 'No summary available')}"
                for variant in variants
                for hit in civic.search_variant_evidence(variant, cancer_type)
            ]
        except Exception as err:
            logger.warning("CIViC API failed: %s", err)
            return []

    def trials_lookup() -> List[str]:
        """Collect formatted ClinicalTrials.gov lines for the case."""
        try:
            from rag_engine.api_clients import ClinicalTrialsClient

            found = ClinicalTrialsClient().search_trials(cancer_type, variants)
            return [
                f"[ClinicalTrials.gov] {trial.get('title', 'Unknown')}: {trial.get('status', '?')}"
                for trial in found
            ]
        except Exception as err:
            logger.warning("ClinicalTrials.gov API failed: %s", err)
            return []

    gathered: Dict[str, List[str]] = {
        "genomic_evidence": [],
        "clinical_trials": [],
    }
    jobs = (
        ("genomic_evidence", civic_lookup),
        ("clinical_trials", trials_lookup),
    )
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {pool.submit(job): label for label, job in jobs}
        for done in as_completed(pending):
            label = pending[done]
            try:
                gathered[label] = done.result()
            except Exception as err:
                logger.error("Parallel fetch error (%s): %s", label, err)
    return gathered
# ---------------------------------------------------------------------------
# Corrective RAG Node
# ---------------------------------------------------------------------------
# Number of documents that must be graded RELEVANT in one retrieval attempt
# for that attempt's results to be accepted; otherwise the query is rewritten.
_MIN_RELEVANT_DOCS: int = 2
# Query rewrites allowed after the initial retrieval; the retrieve/grade loop
# makes at most 1 + _MAX_REWRITES attempts.
_MAX_REWRITES: int = 1
def _build_initial_query(entities: Dict[str, Any], clinical_text: str) -> str:
    """Build the first-pass retrieval query for ``corrective_rag_node``.

    Combines the following, in order:

    * the cancer type, or, when it is missing or ``"Unknown"``, the first 100
      characters of ``clinical_text`` with newlines flattened to spaces;
    * the stage, when known;
    * the mutation list;
    * fixed guideline keywords.
    """
    cancer = entities.get("cancer_type") or "Unknown"
    stage = entities.get("stage") or "Unknown"
    # Coerce to str so a None list or non-string entries cannot crash the join.
    mutations = ", ".join(str(m) for m in (entities.get("mutations") or []))

    query_parts: List[str] = []
    if cancer != "Unknown":
        query_parts.append(str(cancer))
    else:
        # Fallback: use first 100 chars of clinical text for vector search
        query_parts.append(clinical_text[:100].replace("\n", " "))
    if stage != "Unknown":
        query_parts.append(str(stage))
    if mutations:
        query_parts.append(f"mutations: {mutations}")
    query_parts.append("treatment recommendation guidelines triage")
    return " ".join(query_parts)


def _format_relevant_docs(docs: List[Dict[str, Any]]) -> tuple:
    """Render graded documents as context blocks and Markdown source lines.

    Missing keys get placeholders so one malformed document cannot abort
    formatting of the rest:

    * ``source`` and ``header`` render as "Unknown";
    * ``page`` renders as "?";
    * ``text`` renders as an empty string.

    Returns:
        ``(context_strings, source_strings)``, with one entry per document in
        each list.
    """
    context_strings: List[str] = []
    source_strings: List[str] = []
    for doc in docs:
        source = doc.get("source", "Unknown")
        page = doc.get("page", "?")
        header = doc.get("header", "Unknown")
        context_strings.append(
            f"[Source: {source}, Page: {page}, "
            f"Section: {header}]\n{doc.get('text', '')}"
        )
        source_strings.append(f"- **{source}** (Page {page}): {header}")
    return context_strings, source_strings


def corrective_rag_node(state: AgentState) -> Dict[str, Any]:
    """Execute the Corrective RAG pipeline.

    Pipeline:
        1. Build a structured query from the extracted entities.
        2. Retrieve the top 8 candidates via the cached retriever.
        3. Grade every candidate for relevance concurrently (tier-1 model).
        4. If fewer than ``_MIN_RELEVANT_DOCS`` pass, rewrite the query and
           retry, up to ``_MAX_REWRITES`` times.
        5. Fetch CIViC / ClinicalTrials.gov evidence in parallel.
        6. Compute a confidence score from cross-encoder scores.

    If no attempt reaches ``_MIN_RELEVANT_DOCS`` relevant documents, the RAG
    context is left empty; documents from under-threshold attempts are not
    kept. (Presumably this lets a downstream node route to the fallback
    path; confirm against the graph wiring.)

    Any exception during retrieval, grading or formatting is logged with its
    traceback. The RAG output is then empty, with zero confidence and a zero
    rewrite count. API evidence is fetched in either case.

    Args:
        state: Current LangGraph state. ``extracted_entities`` and
            ``clinical_text`` are read; missing or ``None`` values are treated
            as empty.

    Returns:
        State update with rag_context, rag_sources, graph_rag_context,
        api_evidence_context, rag_confidence, rag_retrieval_count,
        rag_grading_pass_count and rag_query_rewrites.
    """
    entities: Dict[str, Any] = state.get("extracted_entities") or {}
    clinical_text: str = state.get("clinical_text") or ""

    query = _build_initial_query(entities, clinical_text)

    rewrite_count = 0
    relevant_docs: List[Dict[str, Any]] = []
    try:
        retriever = _get_retriever()
        # Grading makes one model call per document, so each attempt's
        # candidates are graded concurrently. One pool serves all attempts.
        with ThreadPoolExecutor(max_workers=8) as executor:
            # --- Retrieve + Grade loop ---
            for attempt in range(1 + _MAX_REWRITES):
                if attempt > 0:
                    query = _rewrite_query(query, entities, attempt)
                    rewrite_count += 1

                raw_results = retriever.query(query, n_results=8) or []
                verdicts = [
                    executor.submit(_grade_document, doc.get("text", ""), query, tier=1)
                    for doc in raw_results
                ]
                # Preserve retrieval order; keep only documents graded RELEVANT.
                graded = [
                    doc
                    for doc, verdict in zip(raw_results, verdicts)
                    if verdict.result()
                ]
                logger.info(
                    "CRAG attempt %d: %d/%d documents graded RELEVANT (Parallel).",
                    attempt + 1, len(graded), len(raw_results),
                )
                if len(graded) >= _MIN_RELEVANT_DOCS:
                    relevant_docs = graded
                    break

        # --- Format results ---
        context_strings, source_strings = _format_relevant_docs(relevant_docs)

        # --- Confidence metrics: mean cross-encoder score of kept docs ---
        ce_scores = [
            doc["cross_encoder_score"]
            for doc in relevant_docs
            if "cross_encoder_score" in doc
        ]
        mean_confidence = sum(ce_scores) / len(ce_scores) if ce_scores else 0.0
    except Exception as exc:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception("RAG retrieval failed: %s", exc)
        context_strings = []
        source_strings = []
        relevant_docs = []
        mean_confidence = 0.0
        rewrite_count = 0

    # --- Parallel API evidence ---
    api_results = _fetch_api_evidence(entities)

    return {
        "rag_context": context_strings,
        "rag_sources": source_strings,
        "graph_rag_context": [],  # Future: knowledge graph integration
        "api_evidence_context": (
            api_results.get("genomic_evidence", [])
            + api_results.get("clinical_trials", [])
        ),
        "rag_confidence": round(mean_confidence, 4),
        "rag_retrieval_count": len(context_strings),
        "rag_grading_pass_count": len(relevant_docs),
        "rag_query_rewrites": rewrite_count,
    }
|