"""Execution-Feedback Router: use cheap-model output to decide escalation.

Implements the RouteNLP / CP-Router pattern:

1. Send the request to a cheap model first.
2. Compute token-level uncertainty from its output.
3. If uncertainty exceeds a calibrated threshold, escalate to a stronger model.
4. Use conformal calibration for distribution-free quality guarantees.

This is the highest-impact routing improvement per the literature: post-hoc
quality estimates dramatically outperform ex-ante routing.
"""

import math
import json
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class FeedbackSignal:
    """Uncertainty summary of one cheap-model response plus the escalation verdict."""

    entropy: float  # uncertainty score (nats) compared against entropy_threshold
    binary_entropy: float  # binary entropy (bits) of exp(mean_logprob)
    mean_logprob: float
    min_logprob: float  # -99.0 sentinel when no tokens were supplied
    token_count: int
    low_confidence_tokens: int  # tokens whose logprob < min_logprob_threshold
    low_confidence_ratio: float
    should_escalate: bool
    escalation_reason: str  # "; "-joined triggers, "ok", or "no_data"


@dataclass
class CascadeResult:
    """Outcome of one cheap-first cascade step."""

    initial_tier: int
    final_tier: int
    initial_response: str
    final_response: str
    feedback_signal: FeedbackSignal
    escalated: bool
    total_cost: float  # cheap-call cost plus the escalated call's cost, if any
    confidence: float  # exp(mean_logprob); 0.0 when no tokens were scored


class ExecutionFeedbackRouter:
    """Decide whether to escalate a cheap-model answer from its own logprobs.

    Tiers are integers from 1 up to ``MAX_TIER``; escalation moves a request
    at least one tier up (further if a task floor demands it), never past
    ``MAX_TIER``. With the default cost table, cost increases with tier.
    """

    MAX_TIER = 5  # highest tier a request can be escalated to

    def __init__(
        self,
        safety_threshold: float = 0.30,
        entropy_threshold: float = 2.5,
        binary_entropy_threshold: float = 0.85,
        min_logprob_threshold: float = -5.0,
        low_conf_ratio_threshold: float = 0.15,
        conformal_alpha: float = 0.05,
        tier_costs: Optional[Dict[int, float]] = None,
        task_floors: Optional[Dict[str, int]] = None,
    ):
        """Configure uncertainty thresholds, per-tier costs and per-task floors.

        Args:
            safety_threshold: Stored only; not read by any method shown here.
            entropy_threshold: Escalate when the entropy score exceeds this.
            binary_entropy_threshold: Stored only; binary entropy is reported
                but not used as an escalation trigger (see analyze_output).
            min_logprob_threshold: Tokens with a logprob below this count as
                low-confidence.
            low_conf_ratio_threshold: Escalate when the fraction of
                low-confidence tokens exceeds this.
            conformal_alpha: Miscoverage level for calibrate_conformal.
            tier_costs: Cost per tier; a falsy value selects the default table.
            task_floors: Minimum tier per task type, applied on escalation.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        self.escalation_stats = {
            "escalated": 0,
            "stayed": 0,
            "saved_cost": 0.0,
            "wasted_cost": 0.0,
        }

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Return the mean of ``-lp * exp(lp)`` over sampled-token logprobs.

        This is only a per-token proxy: each term is the ``-p ln p``
        contribution of the sampled token alone (at most 1/e), not the Shannon
        entropy of the full next-token distribution. Returns 0.0 for empty
        input.
        """
        if not logprobs:
            return 0.0
        return -sum(lp * math.exp(lp) for lp in logprobs) / len(logprobs)

    def compute_distribution_entropy(self, token_probs: List[List[float]]) -> float:
        """Return the mean Shannon entropy (nats) over per-position distributions.

        Each inner list is assumed to hold candidate-token probabilities for
        one output position (e.g. a top-k slice). Non-positive entries are
        skipped. For a truncated distribution the result is a lower bound on
        the full entropy, because every omitted ``-p ln p`` term is
        non-negative. Returns 0.0 for empty input.
        """
        if not token_probs:
            return 0.0
        total = 0.0
        for dist in token_probs:
            total -= sum(p * math.log(p) for p in dist if p > 0.0)
        return total / len(token_probs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Return the binary entropy in bits of ``prob``, clamped to [1e-4, 0.9999]."""
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(
        self,
        token_logprobs: List[float],
        token_probs: Optional[List[List[float]]] = None,
        response_text: str = "",
        task_type: str = "unknown",
        current_tier: int = 2,
    ) -> FeedbackSignal:
        """Summarize the uncertainty of one response and decide on escalation.

        Args:
            token_logprobs: Natural-log probabilities of the sampled tokens.
            token_probs: Optional per-position candidate probability lists. When
                provided, the entropy score is their mean Shannon entropy;
                otherwise it falls back to the mean surprisal ``abs(mean_lp)``.
            response_text, task_type, current_tier: Accepted for interface
                compatibility; not used by this implementation.

        Returns:
            A FeedbackSignal. Escalation triggers are: entropy above
            ``entropy_threshold``, low-confidence ratio above
            ``low_conf_ratio_threshold``, and entropy above the conformal
            threshold once calibrated. Empty input yields a "no_data" signal
            that does not escalate.
        """
        if not token_logprobs:
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")

        n = len(token_logprobs)
        mean_lp = sum(token_logprobs) / n
        min_lp = min(token_logprobs)
        low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold)
        low_conf_ratio = low_conf_count / n

        if token_probs:
            entropy = self.compute_distribution_entropy(token_probs)
        else:
            entropy = abs(mean_lp)
        # NOTE(review): binary_entropy is reported but never compared against
        # binary_entropy_threshold — confirm whether it should trigger escalation.
        binary_entropy = self.compute_binary_entropy(math.exp(mean_lp))

        reasons = []
        if entropy > self.entropy_threshold:
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            reasons.append(
                f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})"
            )
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            reasons.append(
                f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})"
            )

        return FeedbackSignal(
            entropy=entropy,
            binary_entropy=binary_entropy,
            mean_logprob=mean_lp,
            min_logprob=min_lp,
            token_count=n,
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=bool(reasons),
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]):
        """Set the conformal escalation threshold from calibration scores.

        Uses the split-conformal quantile: the k-th smallest score with
        ``k = ceil((1 - alpha) * (n + 1))`` clamped to ``[1, n]``. With fewer
        than 10 scores the data is stored but the current threshold (possibly
        None) is left unchanged.

        NOTE(review): ``labels`` are stored but not used to filter scores
        (e.g. to calibrate only on correct answers) — confirm intent.
        """
        # Copy so later mutation of the caller's lists cannot desync stored data.
        self._calibration_scores = list(scores)
        self._calibration_labels = list(labels)
        n = len(self._calibration_scores)
        if n < 10:
            return
        k = math.ceil((1 - self.conformal_alpha) * (n + 1))
        k = min(max(k, 1), n)
        self._conformal_threshold = sorted(self._calibration_scores)[k - 1]

    def cascade(
        self,
        request: str,
        initial_tier: int,
        cheap_logprobs: List[float],
        cheap_response: str,
        strong_response: str = "",
        task_type: str = "unknown",
        task_floor: int = 1,
    ) -> CascadeResult:
        """Run one cheap-first cascade step and update escalation statistics.

        Escalation happens only when the feedback signal requests it AND a
        higher tier exists. The escalated tier is ``initial_tier + 1`` raised to
        the task floor (max of ``task_floor`` and ``task_floors[task_type]``),
        capped at ``MAX_TIER``. ``request`` is not used by this implementation.
        """
        signal = self.analyze_output(
            cheap_logprobs,
            response_text=cheap_response,
            task_type=task_type,
            current_tier=initial_tier,
        )
        # A top-tier request cannot escalate, whatever the signal says.
        escalated = signal.should_escalate and initial_tier < self.MAX_TIER
        final_tier = initial_tier
        final_response = cheap_response
        initial_cost = self.tier_costs.get(initial_tier, 0.15)
        total_cost = initial_cost

        if escalated:
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = min(max(initial_tier + 1, floor), self.MAX_TIER)
            final_response = strong_response if strong_response else cheap_response
            final_cost = self.tier_costs.get(final_tier, 1.0)
            total_cost += final_cost
            self.escalation_stats["escalated"] += 1
            # NOTE(review): this records the escalated call's cost; the discarded
            # cheap call (initial_cost) may be what "wasted" was meant to capture.
            self.escalation_stats["wasted_cost"] += final_cost
        else:
            self.escalation_stats["stayed"] += 1
            if initial_tier < self.MAX_TIER:
                # Saving relative to routing straight to the next tier up.
                self.escalation_stats["saved_cost"] += (
                    self.tier_costs.get(initial_tier + 1, 1.0) - initial_cost
                )

        # No scored tokens means no evidence of confidence.
        confidence = math.exp(signal.mean_logprob) if signal.token_count > 0 else 0.0
        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            confidence=confidence,
        )