"""Execution-Feedback Router: Use cheap model output to decide escalation.

Implements the RouteNLP / CP-Router pattern:
1. Send the request to a cheap model first.
2. Compute token-level uncertainty from its output.
3. If uncertainty exceeds a calibrated threshold, escalate to a stronger model.
4. Use conformal calibration for distribution-free quality guarantees.

This is the highest-impact routing improvement per the literature:
post-hoc quality estimates dramatically outperform ex-ante routing.
"""
| import math, json |
| from typing import Dict, List, Optional, Tuple |
| from dataclasses import dataclass |
|
|
| @dataclass |
| class FeedbackSignal: |
| entropy: float |
| binary_entropy: float |
| mean_logprob: float |
| min_logprob: float |
| token_count: int |
| low_confidence_tokens: int |
| low_confidence_ratio: float |
| should_escalate: bool |
| escalation_reason: str |
|
|
| @dataclass |
| class CascadeResult: |
| initial_tier: int |
| final_tier: int |
| initial_response: str |
| final_response: str |
| feedback_signal: FeedbackSignal |
| escalated: bool |
| total_cost: float |
| confidence: float |
|
|
| class ExecutionFeedbackRouter: |
| def __init__(self, safety_threshold: float = 0.30, |
| entropy_threshold: float = 2.5, |
| binary_entropy_threshold: float = 0.85, |
| min_logprob_threshold: float = -5.0, |
| low_conf_ratio_threshold: float = 0.15, |
| conformal_alpha: float = 0.05, |
| tier_costs: Dict[int,float] = None, |
| task_floors: Dict[str,int] = None): |
| self.safety_threshold = safety_threshold |
| self.entropy_threshold = entropy_threshold |
| self.binary_entropy_threshold = binary_entropy_threshold |
| self.min_logprob_threshold = min_logprob_threshold |
| self.low_conf_ratio_threshold = low_conf_ratio_threshold |
| self.conformal_alpha = conformal_alpha |
| self.tier_costs = tier_costs or {1:0.05,2:0.15,3:0.75,4:1.0,5:1.5} |
| self.task_floors = task_floors or {} |
| self._calibration_scores: List[float] = [] |
| self._calibration_labels: List[int] = [] |
| self._conformal_threshold: Optional[float] = None |
| self.escalation_stats = {"escalated":0,"stayed":0,"saved_cost":0.0,"wasted_cost":0.0} |
|
|
| def compute_entropy(self, logprobs: List[float]) -> float: |
| if not logprobs: return 0.0 |
| n = len(logprobs) |
| return -sum(lp * math.exp(lp) for lp in logprobs) / n if n > 0 else 0.0 |
|
|
| def compute_binary_entropy(self, prob: float) -> float: |
| p = max(min(prob, 0.9999), 0.0001) |
| return -(p*math.log2(p) + (1-p)*math.log2(1-p)) |
|
|
| def analyze_output(self, token_logprobs: List[float], |
| token_probs: Optional[List[List[float]]] = None, |
| response_text: str = "", |
| task_type: str = "unknown", |
| current_tier: int = 2) -> FeedbackSignal: |
| if not token_logprobs: |
| return FeedbackSignal(0.0,0.0,0.0,-99.0,0,0,0.0,False,"no_data") |
| mean_lp = sum(token_logprobs)/len(token_logprobs) |
| min_lp = min(token_logprobs) |
| low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold) |
| low_conf_ratio = low_conf_count / len(token_logprobs) |
| entropy = self.compute_entropy(token_logprobs) if token_probs else abs(mean_lp) |
| binary_entropy = self.compute_binary_entropy(math.exp(mean_lp)) |
| should_escalate = False |
| reasons = [] |
| if entropy > self.entropy_threshold: |
| should_escalate = True |
| reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})") |
| if low_conf_ratio > self.low_conf_ratio_threshold: |
| should_escalate = True |
| reasons.append(f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})") |
| if self._conformal_threshold is not None and entropy > self._conformal_threshold: |
| should_escalate = True |
| reasons.append(f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})") |
| return FeedbackSignal( |
| entropy=entropy, binary_entropy=binary_entropy, |
| mean_logprob=mean_lp, min_logprob=min_lp, |
| token_count=len(token_logprobs), |
| low_confidence_tokens=low_conf_count, |
| low_confidence_ratio=low_conf_ratio, |
| should_escalate=should_escalate, |
| escalation_reason="; ".join(reasons) if reasons else "ok", |
| ) |
|
|
| def calibrate_conformal(self, scores: List[float], labels: List[int]): |
| self._calibration_scores = scores |
| self._calibration_labels = labels |
| n = len(scores) |
| if n < 10: return |
| alpha = self.conformal_alpha |
| k = min(int(math.ceil((1-alpha)*(n+1))), n) |
| sorted_scores = sorted(scores) |
| self._conformal_threshold = sorted_scores[min(k-1, n-1)] |
|
|
| def cascade(self, request: str, initial_tier: int, |
| cheap_logprobs: List[float], |
| cheap_response: str, |
| strong_response: str = "", |
| task_type: str = "unknown", |
| task_floor: int = 1) -> CascadeResult: |
| signal = self.analyze_output(cheap_logprobs, response_text=cheap_response, |
| task_type=task_type, current_tier=initial_tier) |
| escalated = signal.should_escalate |
| final_tier = initial_tier |
| final_response = cheap_response |
| if escalated and initial_tier < 5: |
| floor = max(task_floor, self.task_floors.get(task_type, 1)) |
| final_tier = min(initial_tier + 1, 5) |
| final_tier = max(final_tier, floor) |
| final_response = strong_response if strong_response else cheap_response |
| self.escalation_stats["escalated"] += 1 |
| self.escalation_stats["wasted_cost"] += self.tier_costs.get(final_tier, 1.0) |
| else: |
| self.escalation_stats["stayed"] += 1 |
| self.escalation_stats["saved_cost"] += self.tier_costs.get(initial_tier+1, 1.0) - self.tier_costs.get(initial_tier, 0.15) |
| total_cost = self.tier_costs.get(initial_tier, 0.15) |
| if escalated: |
| total_cost += self.tier_costs.get(final_tier, 1.0) |
| return CascadeResult( |
| initial_tier=initial_tier, |
| final_tier=final_tier, |
| initial_response=cheap_response, |
| final_response=final_response, |
| feedback_signal=signal, |
| escalated=escalated, |
| total_cost=total_cost, |
| confidence=math.exp(signal.mean_logprob) if signal.mean_logprob > -99 else 0.0, |
| ) |
|
|