# OncoAgent / agents / corrective_rag.py
# MaximoLopezChenlo's picture
# Upload folder using huggingface_hub
# e1624f5 verified
"""
Corrective RAG Node — Graded retrieval with query rewriting.
Design pattern: Corrective RAG (CRAG) from Yan et al. 2024
1. Retrieve top-K documents from ChromaDB
2. Grade each document for relevance (binary: RELEVANT / IRRELEVANT)
3. If insufficient relevant docs → rewrite query and re-retrieve
4. If still insufficient after max retries → route to fallback
Also implements parallelised evidence gathering:
- ChromaDB (clinical guidelines)
- CIViC API (genomic evidence)
- ClinicalTrials.gov (active trials)
"""
import logging
import re
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from .state import AgentState
from .tools import call_tier_model
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy-loaded retriever singleton
# ---------------------------------------------------------------------------
_retriever_instance = None


def _get_retriever():
    """Hand back the module-wide OncoRAGRetriever, building it on first use.

    The ``rag_engine.retriever`` import is deferred to the first call, and
    the constructed object is stored in the module-level
    ``_retriever_instance`` so later calls reuse it.

    Returns:
        The cached retriever instance.

    Raises:
        Exception: Whatever the import or the constructor raised; the error
            is logged and then re-raised unchanged.
    """
    global _retriever_instance

    # Fast path: an earlier call already built the retriever.
    if _retriever_instance is not None:
        return _retriever_instance

    try:
        from rag_engine.retriever import OncoRAGRetriever

        _retriever_instance = OncoRAGRetriever()
        logger.info("OncoRAGRetriever initialised successfully.")
    except Exception as exc:
        logger.error("Failed to initialise OncoRAGRetriever: %s", exc)
        raise
    return _retriever_instance
# ---------------------------------------------------------------------------
# Document Grading (CRAG core)
# ---------------------------------------------------------------------------
def _grade_document(document_text: str, query: str, tier: int = 1) -> bool:
    """Grade a single retrieved document for relevance.

    The document is sent to the tier model with a binary-classification
    prompt. Two deliberate recall-favouring rules are layered on top of the
    model verdict:

    * Keyword override: if the model answers IRRELEVANT but one of the
      query's non-generic terms (longer than 4 characters) occurs in the
      document, the document is kept anyway.
    * Fail open: if the model call (or parsing its reply) raises, the
      document is kept.

    Args:
        document_text: The document text to evaluate (only the first 2000
            characters are shown to the model).
        query: The clinical query.
        tier: Model tier to use for grading (default: 1 for speed).

    Returns:
        True if the document is considered RELEVANT, False otherwise.
    """
    query_lower = query.lower()
    doc_lower = document_text.lower()

    # Simple lexical overlap check: ignore generic terms that appear in
    # almost every guideline document.
    ignore_terms = {
        "treatment",
        "recommendation",
        "guidelines",
        "triage",
        "management",
        "clinical",
        "oncology",
    }
    core_terms = [
        t for t in query_lower.split() if len(t) > 4 and t not in ignore_terms
    ]
    term_match = any(term in doc_lower for term in core_terms)

    system_prompt = (
        "You are an expert Oncology Clinical Analyst. "
        "Your task is to evaluate if a retrieved medical document is RELEVANT to a patient's clinical query. "
        "RELEVANCE CRITERIA:\n"
        "1. The document discusses the specific cancer type or its precursors.\n"
        "2. The document mentions treatment protocols, staging, or diagnostic criteria relevant to the case.\n"
        "3. Synonyms are allowed (e.g., 'Uterine Cancer' is relevant to 'Endometrial Adenocarcinoma').\n\n"
        "Output ONLY 'RELEVANT' or 'IRRELEVANT'. No explanation."
    )
    user_prompt = (
        f"Patient Query/Context: {query}\n\n"
        f"Document Snippet:\n--- START ---\n{document_text[:2000]}\n--- END ---\n\n"
        "Is this document relevant? (RELEVANT/IRRELEVANT):"
    )

    try:
        response = call_tier_model(
            tier=tier,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=10,
            temperature=0.0,
        )
        # "IRRELEVANT" contains "RELEVANT" as a substring, so a plain
        # substring test accepts both verdicts. Match whole words and let
        # the first verdict in the reply decide; no verdict means reject.
        verdict = re.search(r"\b(IRRELEVANT|RELEVANT)\b", response.upper())
        is_relevant = verdict is not None and verdict.group(1) == "RELEVANT"
        logger.info(
            "Doc Grade: %s (Term Match: %s) -> Query: %s...",
            "RELEVANT" if is_relevant else "IRRELEVANT",
            term_match,
            query[:30],
        )

        # Boost: keep guidelines that explicitly mention a core query term
        # even when the model rejected them, trading precision for recall.
        if not is_relevant and term_match:
            logger.debug("Model rejected doc but keyword match found. Overriding to RELEVANT for recall.")
            return True
        return is_relevant
    except Exception as exc:
        logger.warning("Document grading failed: %s — defaulting to RELEVANT.", exc)
        return True  # Fail open: include document if grading fails
def _rewrite_query(
    original_query: str,
    entities: Dict[str, Any],
    attempt: int,
) -> str:
    """Broaden the query for a retry attempt.

    Uses deterministic broadening rather than LLM-based rewriting
    for speed and predictability.

    * Attempt 1 drops the stage and keeps cancer type + mutations.
    * Any other attempt value keeps only the cancer type plus generic
      guideline terms.

    A missing, ``None`` or ``"Unknown"`` cancer type is left out of the
    rewritten query instead of being inserted as a literal placeholder
    word, which would only add noise to the search.

    Args:
        original_query: The query that yielded insufficient results
            (used for logging only).
        entities: Extracted clinical entities; ``cancer_type`` and
            ``mutations`` are read.
        attempt: The retry attempt number (1-indexed).

    Returns:
        A broadened query string.
    """
    cancer = entities.get("cancer_type")
    cancer_terms = [str(cancer)] if cancer and cancer != "Unknown" else []
    # `or []` tolerates an explicit None; str() tolerates non-string items.
    mutations = [str(m) for m in (entities.get("mutations") or [])]

    if attempt == 1:
        # Broadening strategy: remove stage specificity, keep cancer + mutations
        parts = list(cancer_terms)
        if mutations:
            parts.append(f"mutations {' '.join(mutations)}")
        parts.append("treatment guidelines evidence-based recommendations")
        rewritten = " ".join(parts)
        logger.info("Query rewrite attempt %d: %s → %s", attempt, original_query, rewritten)
        return rewritten

    # Attempt 2+: maximally broad
    rewritten = " ".join(cancer_terms + ["oncology clinical guidelines management"])
    logger.info("Query rewrite attempt %d (maximal broadening): %s", attempt, rewritten)
    return rewritten
# ---------------------------------------------------------------------------
# Parallel Evidence Gathering
# ---------------------------------------------------------------------------
def _fetch_api_evidence(entities: Dict[str, Any]) -> Dict[str, List[str]]:
    """Collect genomic and clinical-trial evidence using two worker threads.

    One worker queries CIViC once per mutation; the other queries
    ClinicalTrials.gov once for the cancer type and mutation list. Each
    worker swallows its own failures (logging a warning) and then yields an
    empty list, so one failing source never blocks the other.

    Args:
        entities: Extracted clinical entities; ``mutations`` and
            ``cancer_type`` are read.

    Returns:
        Dict with "genomic_evidence" and "clinical_trials" lists of
        pre-formatted evidence strings.
    """
    cancer = entities.get("cancer_type", "Unknown")
    mutations = entities.get("mutations", [])

    def _civic_lookup():
        """Return formatted CIViC evidence lines for every mutation."""
        try:
            from rag_engine.api_clients import CivicAPIClient

            civic = CivicAPIClient()
            # Any failure part-way through discards the partial list.
            return [
                f"[CIViC] {mutation}: {hit.get('summary', 'No summary available')}"
                for mutation in mutations
                for hit in civic.search_variant_evidence(mutation, cancer)
            ]
        except Exception as exc:
            logger.warning("CIViC API failed: %s", exc)
            return []

    def _trials_lookup():
        """Return formatted ClinicalTrials.gov lines for the case."""
        try:
            from rag_engine.api_clients import ClinicalTrialsClient

            trials = ClinicalTrialsClient().search_trials(cancer, mutations)
            lines = []
            for trial in trials:
                lines.append(
                    f"[ClinicalTrials.gov] {trial.get('title', 'Unknown')}: {trial.get('status', '?')}"
                )
            return lines
        except Exception as exc:
            logger.warning("ClinicalTrials.gov API failed: %s", exc)
            return []

    tasks = {
        "genomic_evidence": _civic_lookup,
        "clinical_trials": _trials_lookup,
    }
    # Every key starts empty so a failed future still leaves a list behind.
    gathered: Dict[str, List[str]] = {name: [] for name in tasks}

    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {pool.submit(job): name for name, job in tasks.items()}
        for done in as_completed(pending):
            name = pending[done]
            try:
                gathered[name] = done.result()
            except Exception as exc:
                logger.error("Parallel fetch error (%s): %s", name, exc)

    return gathered
# ---------------------------------------------------------------------------
# Corrective RAG Node
# ---------------------------------------------------------------------------
# Minimum relevant documents required to proceed
_MIN_RELEVANT_DOCS = 2
# Maximum query rewrite attempts
_MAX_REWRITES = 1


def corrective_rag_node(state: AgentState) -> Dict[str, Any]:
    """Execute the Corrective RAG pipeline.

    Pipeline:
        1. Build structured query from extracted entities.
        2. Retrieve top-K candidates from the retriever.
        3. Grade each document for relevance (in parallel threads).
        4. If insufficient relevant docs → rewrite query and retry
           (at most ``_MAX_REWRITES`` times).
        5. Fetch API evidence in parallel (CIViC + ClinicalTrials).
        6. Compute confidence metrics.

    If no attempt yields at least ``_MIN_RELEVANT_DOCS`` relevant documents,
    or if retrieval/grading raises, the RAG context is returned empty with a
    confidence of 0.0. API evidence is fetched regardless.

    Args:
        state: Current LangGraph state; ``extracted_entities`` and
            ``clinical_text`` are read.

    Returns:
        State update with rag_context, sources, confidence, and metrics.
    """
    # `or` fallbacks tolerate keys that are present but set to None, which
    # would otherwise crash the query construction outside the try block.
    entities: Dict[str, Any] = state.get("extracted_entities") or {}
    clinical_text: str = state.get("clinical_text") or ""

    # --- Build initial query ---
    cancer = entities.get("cancer_type") or "Unknown"
    stage = entities.get("stage") or "Unknown"
    mutations = ", ".join(str(m) for m in (entities.get("mutations") or []))

    query_parts = []
    if cancer != "Unknown":
        query_parts.append(str(cancer))
    else:
        # Fallback: use first 100 chars of clinical text for vector search
        query_parts.append(clinical_text[:100].replace("\n", " "))
    if stage != "Unknown":
        query_parts.append(str(stage))
    if mutations:
        query_parts.append(f"mutations: {mutations}")
    query_parts.append("treatment recommendation guidelines triage")
    query = " ".join(query_parts)

    rewrite_count = 0
    relevant_docs: List[Dict[str, Any]] = []
    context_strings: List[str] = []
    source_strings: List[str] = []
    mean_confidence = 0.0

    try:
        retriever = _get_retriever()

        # --- Retrieve + Grade loop ---
        for attempt in range(1 + _MAX_REWRITES):
            if attempt > 0:
                query = _rewrite_query(query, entities, attempt)
                rewrite_count += 1

            # Retrieve candidates
            raw_results = retriever.query(query, n_results=8)

            def _is_relevant(doc, _query=query):
                # Bind the current query as a default so the closure does
                # not depend on the loop variable being rebound later.
                return _grade_document(doc.get("text", ""), _query, tier=1)

            # Grading is one model call per document, so run the calls in
            # threads. executor.map preserves the retriever's ordering.
            with ThreadPoolExecutor(max_workers=8) as executor:
                verdicts = list(executor.map(_is_relevant, raw_results))
            graded = [doc for doc, keep in zip(raw_results, verdicts) if keep]

            logger.info(
                "CRAG attempt %d: %d/%d documents graded RELEVANT (Parallel).",
                attempt + 1, len(graded), len(raw_results),
            )
            if len(graded) >= _MIN_RELEVANT_DOCS:
                relevant_docs = graded
                break

        # --- Format results ---
        # Use .get throughout, matching the grading step, so one record
        # with a missing key cannot discard every graded document.
        for r in relevant_docs:
            source = r.get("source", "Unknown")
            page = r.get("page", "?")
            header = r.get("header", "Unknown")
            context_strings.append(
                f"[Source: {source}, Page: {page}, "
                f"Section: {header}]\n{r.get('text', '')}"
            )
            source_strings.append(
                f"- **{source}** (Page {page}): "
                f"{header}"
            )

        # --- Confidence metrics ---
        # Mean of the cross-encoder scores of the records that carry one.
        ce_scores = [
            r["cross_encoder_score"]
            for r in relevant_docs
            if "cross_encoder_score" in r
        ]
        mean_confidence = sum(ce_scores) / len(ce_scores) if ce_scores else 0.0
    except Exception as exc:
        logger.error("RAG retrieval failed: %s", exc)
        context_strings = []
        source_strings = []
        relevant_docs = []
        mean_confidence = 0.0
        rewrite_count = 0

    # --- Parallel API evidence ---
    api_results = _fetch_api_evidence(entities)

    # NOTE(review): rag_retrieval_count and rag_grading_pass_count are
    # always equal here, because both count the graded documents — confirm
    # whether the former was meant to report the raw retrieval size.
    return {
        "rag_context": context_strings,
        "rag_sources": source_strings,
        "graph_rag_context": [],  # Future: knowledge graph integration
        "api_evidence_context": (
            api_results.get("genomic_evidence", [])
            + api_results.get("clinical_trials", [])
        ),
        "rag_confidence": round(mean_confidence, 4),
        "rag_retrieval_count": len(context_strings),
        "rag_grading_pass_count": len(relevant_docs),
        "rag_query_rewrites": rewrite_count,
    }