"""
Router Node — Complexity classification and model tier selection.
Design pattern: Supervisor Routing (LangGraph SOTA)
Inspired by:
- Claude Code: deterministic routing via structured logic, not free text
- Hermes Agent: structured JSON output for decisions
The router classifies each clinical case into one of three categories:
- ``simple``: Well-known cancer + standard staging → Tier 1 (9B)
- ``complex``: Rare cancer / multi-mutation / ambiguous staging → Tier 2 (27B)
- ``insufficient``: Input too short or unintelligible → direct fallback
Supports manual tier override from the UI (user can force Tier 1 or 2).
"""
import json
import logging
import re
from typing import Dict, Any, Optional

from .state import AgentState
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Complexity heuristics
# ---------------------------------------------------------------------------
# Cancer types considered "well-documented" with standard NCCN guidelines
_COMMON_CANCERS = frozenset({
"breast cancer", "lung cancer", "colon cancer", "colorectal cancer",
"prostate cancer", "melanoma", "bladder cancer", "thyroid cancer",
"cervical cancer", "ovarian cancer", "gastric cancer",
})
# Cancer types considered rare or requiring deeper reasoning
_RARE_CANCERS = frozenset({
"pancreatic cancer", "hepatocellular carcinoma", "sarcoma",
"glioma", "glioblastoma", "multiple myeloma", "renal cell carcinoma",
"esophageal cancer", "cholangiocarcinoma", "mesothelioma",
"neuroendocrine tumor", "adrenocortical carcinoma",
})
# Mutations that indicate multi-pathway complexity
_COMPLEX_MUTATIONS = frozenset({
"EGFR", "ALK", "KRAS", "NTRK", "RET", "MET", "ROS1",
"PIK3CA", "MSI-H", "DMMR", "BRAF V600E",
})
# Minimum character count for a clinically meaningful input
_MIN_INPUT_LENGTH = 30
def _classify_complexity(
clinical_text: str,
entities: Dict[str, Any],
) -> tuple[str, float, int]:
"""Classify case complexity using rule-based heuristics.
Args:
clinical_text: Raw clinical text.
entities: Extracted entities from the ingestion node.
Returns:
Tuple of (routing_decision, complexity_score, recommended_tier).
"""
# Gate: insufficient input
if len(clinical_text.strip()) < _MIN_INPUT_LENGTH:
logger.info("Input too short (%d chars) — routing to insufficient.", len(clinical_text))
return "insufficient", 0.0, 1
score = 0.0
cancer_type = entities.get("cancer_type", "Unknown").lower()
stage = entities.get("stage", "Unknown")
mutations = entities.get("mutations", [])
# --- Cancer type scoring ---
if cancer_type in _RARE_CANCERS:
score += 0.4
elif cancer_type == "unknown":
score += 0.3 # Unidentified cancer is inherently complex
# Common cancers add no complexity
# --- Stage scoring ---
if "IV" in stage.upper():
score += 0.25
elif "III" in stage.upper():
score += 0.15
# --- Mutation complexity ---
complex_muts = [m for m in mutations if m.upper() in _COMPLEX_MUTATIONS]
if len(complex_muts) >= 2:
score += 0.3 # Multi-mutation = high complexity
elif len(complex_muts) == 1:
score += 0.15
# --- Prior treatment mentions (heuristic) ---
prior_treatment_keywords = [
"prior treatment", "previously treated", "relapsed",
"refractory", "second-line", "third-line", "progression",
"resistance", "failed", "recurrent",
]
text_lower = clinical_text.lower()
for kw in prior_treatment_keywords:
if kw in text_lower:
score += 0.1
break
# Clamp to [0, 1]
score = min(score, 1.0)
# Decision boundary
if score >= 0.5:
return "complex", score, 2
else:
return "simple", score, 1
# ---------------------------------------------------------------------------
# Router Node
# ---------------------------------------------------------------------------
def _coerce_tier_override(value: Any) -> Optional[int]:
    """Normalise a manual tier override to ``1`` or ``2``.

    Args:
        value: Raw ``user_tier_override`` value taken from the state.

    Returns:
        ``1`` or ``2`` when *value* equals one of those numbers (e.g. an int,
        an integral float, or a numeric string such as ``"2"``), otherwise
        ``None``. Booleans are rejected: ``True == 1`` in Python, so without
        this check a stray flag would silently force Tier 1.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        value = value.strip()
        return int(value) if value in ("1", "2") else None
    try:
        return int(value) if value in (1, 2) else None
    except (TypeError, ValueError):
        # Values whose equality test or int() conversion fails are not valid tiers.
        return None


def router_node(state: AgentState) -> Dict[str, Any]:
    """Classify case complexity and select the appropriate model tier.

    If ``user_tier_override`` in the state resolves to 1 or 2 (see
    ``_coerce_tier_override``), it takes precedence over the automatic
    classification for ``selected_tier``. ``routing_decision`` and
    ``complexity_score`` always come from the automatic classifier. Any other
    non-``None`` override value is logged and ignored.

    Args:
        state: Current LangGraph state. Reads ``clinical_text``,
            ``extracted_entities`` and ``user_tier_override``; missing or
            ``None`` values fall back to empty defaults.

    Returns:
        State update with routing_decision, complexity_score (rounded to 4
        decimal places) and selected_tier.
    """
    # ``or`` rather than a .get default so keys present with None are covered too.
    clinical_text: str = state.get("clinical_text") or ""
    entities: Dict[str, Any] = state.get("extracted_entities") or {}
    raw_override = state.get("user_tier_override")
    user_override: Optional[int] = _coerce_tier_override(raw_override)
    if raw_override is not None and user_override is None:
        logger.info(
            "Ignoring user_tier_override=%r (expected 1 or 2); using automatic routing.",
            raw_override,
        )
    # Run automatic classification. This also runs under an override so the
    # decision and score are still reported.
    decision, score, auto_tier = _classify_complexity(clinical_text, entities)
    # Apply manual override if present
    if user_override is not None:
        selected_tier = user_override
        logger.info(
            "Manual tier override applied: Tier %d (auto would be Tier %d, score=%.2f)",
            user_override, auto_tier, score,
        )
    else:
        selected_tier = auto_tier
        logger.info(
            "Auto-routing: decision=%s, score=%.2f → Tier %d",
            decision, score, selected_tier,
        )
    return {
        "routing_decision": decision,
        "complexity_score": round(score, 4),
        "selected_tier": selected_tier,
    }
|