File size: 5,369 Bytes
e1624f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Router Node — Complexity classification and model tier selection.

Design pattern: Supervisor Routing (LangGraph SOTA)
Inspired by:
  - Claude Code: deterministic routing via structured logic, not free text
  - Hermes Agent: structured JSON output for decisions

The router classifies each clinical case into one of three categories:
  - ``simple``:  Well-known cancer + standard staging → Tier 1 (9B)
  - ``complex``: Rare cancer / multi-mutation / ambiguous staging → Tier 2 (27B)
  - ``insufficient``: Input too short or unintelligible → direct fallback

Supports manual tier override from the UI (user can force Tier 1 or 2).
"""

import logging
import json
from typing import Dict, Any, Optional

from .state import AgentState

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Complexity heuristics
# ---------------------------------------------------------------------------

# Cancer types considered "well-documented" with standard NCCN guidelines
_COMMON_CANCERS = frozenset({
    "breast cancer", "lung cancer", "colon cancer", "colorectal cancer",
    "prostate cancer", "melanoma", "bladder cancer", "thyroid cancer",
    "cervical cancer", "ovarian cancer", "gastric cancer",
})

# Cancer types considered rare or requiring deeper reasoning
_RARE_CANCERS = frozenset({
    "pancreatic cancer", "hepatocellular carcinoma", "sarcoma",
    "glioma", "glioblastoma", "multiple myeloma", "renal cell carcinoma",
    "esophageal cancer", "cholangiocarcinoma", "mesothelioma",
    "neuroendocrine tumor", "adrenocortical carcinoma",
})

# Mutations that indicate multi-pathway complexity
_COMPLEX_MUTATIONS = frozenset({
    "EGFR", "ALK", "KRAS", "NTRK", "RET", "MET", "ROS1",
    "PIK3CA", "MSI-H", "DMMR", "BRAF V600E",
})

# Minimum character count for a clinically meaningful input
_MIN_INPUT_LENGTH = 30


def _classify_complexity(
    clinical_text: str,
    entities: Dict[str, Any],
) -> tuple[str, float, int]:
    """Classify case complexity using rule-based heuristics.

    Args:
        clinical_text: Raw clinical text.
        entities: Extracted entities from the ingestion node.

    Returns:
        Tuple of (routing_decision, complexity_score, recommended_tier).
    """
    # Gate: insufficient input
    if len(clinical_text.strip()) < _MIN_INPUT_LENGTH:
        logger.info("Input too short (%d chars) — routing to insufficient.", len(clinical_text))
        return "insufficient", 0.0, 1

    score = 0.0
    cancer_type = entities.get("cancer_type", "Unknown").lower()
    stage = entities.get("stage", "Unknown")
    mutations = entities.get("mutations", [])

    # --- Cancer type scoring ---
    if cancer_type in _RARE_CANCERS:
        score += 0.4
    elif cancer_type == "unknown":
        score += 0.3  # Unidentified cancer is inherently complex
    # Common cancers add no complexity

    # --- Stage scoring ---
    if "IV" in stage.upper():
        score += 0.25
    elif "III" in stage.upper():
        score += 0.15

    # --- Mutation complexity ---
    complex_muts = [m for m in mutations if m.upper() in _COMPLEX_MUTATIONS]
    if len(complex_muts) >= 2:
        score += 0.3  # Multi-mutation = high complexity
    elif len(complex_muts) == 1:
        score += 0.15

    # --- Prior treatment mentions (heuristic) ---
    prior_treatment_keywords = [
        "prior treatment", "previously treated", "relapsed",
        "refractory", "second-line", "third-line", "progression",
        "resistance", "failed", "recurrent",
    ]
    text_lower = clinical_text.lower()
    for kw in prior_treatment_keywords:
        if kw in text_lower:
            score += 0.1
            break

    # Clamp to [0, 1]
    score = min(score, 1.0)

    # Decision boundary
    if score >= 0.5:
        return "complex", score, 2
    else:
        return "simple", score, 1


# ---------------------------------------------------------------------------
# Router Node
# ---------------------------------------------------------------------------

def router_node(state: AgentState) -> Dict[str, Any]:
    """Classify case complexity and select the appropriate model tier.

    If the user has set ``user_tier_override`` in the state, that
    takes precedence over the automatic classification.

    Args:
        state: Current LangGraph state.

    Returns:
        State update with routing_decision, complexity_score, selected_tier.
    """
    clinical_text: str = state.get("clinical_text", "")
    entities: Dict[str, Any] = state.get("extracted_entities", {})
    user_override: Optional[int] = state.get("user_tier_override")

    # Run automatic classification
    decision, score, auto_tier = _classify_complexity(clinical_text, entities)

    # Apply manual override if present
    if user_override in (1, 2):
        selected_tier = user_override
        logger.info(
            "Manual tier override applied: Tier %d (auto would be Tier %d, score=%.2f)",
            user_override, auto_tier, score,
        )
    else:
        selected_tier = auto_tier
        logger.info(
            "Auto-routing: decision=%s, score=%.2f → Tier %d",
            decision, score, selected_tier,
        )

    return {
        "routing_decision": decision,
        "complexity_score": round(score, 4),
        "selected_tier": selected_tier,
    }