"""Execution-Feedback Router: Use cheap model output to decide escalation.
Implements the RouteNLP / CP-Router pattern:
1. Send request to cheap model first
2. Compute token-level uncertainty from output
3. If uncertainty > calibrated threshold, escalate to stronger model
4. Conformal calibration for distribution-free quality guarantees
This is the highest-impact routing improvement per the literature:
post-hoc quality estimates dramatically outperform ex-ante routing.
"""
import math, json
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class FeedbackSignal:
    """Token-level uncertainty statistics computed from one model response."""
    entropy: float  # entropy proxy over the response tokens
    binary_entropy: float  # binary entropy of the mean token probability
    mean_logprob: float  # mean per-token logprob (0.0 when no data)
    min_logprob: float  # lowest token logprob; -99.0 sentinel when no data
    token_count: int  # number of tokens analyzed
    low_confidence_tokens: int  # tokens whose logprob fell below the threshold
    low_confidence_ratio: float  # low_confidence_tokens / token_count
    should_escalate: bool  # True when at least one escalation trigger fired
    escalation_reason: str  # "; "-joined trigger descriptions, or "ok"/"no_data"
@dataclass
class CascadeResult:
    """Outcome of one cheap-first cascade, including cost accounting."""
    initial_tier: int  # tier the request was first routed to
    final_tier: int  # tier that produced the final response
    initial_response: str  # cheap model's response text
    final_response: str  # response actually returned to the caller
    feedback_signal: FeedbackSignal  # uncertainty signal from the cheap output
    escalated: bool  # True when a stronger tier was invoked
    total_cost: float  # summed per-tier costs actually incurred
    confidence: float  # exp(mean logprob) of the cheap output, 0.0 on no data
class ExecutionFeedbackRouter:
    """Cascade router: try a cheap model tier first, escalate only on doubt.

    Escalation is decided post-hoc from the cheap model's own token logprobs
    (an entropy proxy and the low-confidence-token ratio), plus an optional
    conformal threshold calibrated from historical scores.
    """

    def __init__(self, safety_threshold: float = 0.30,
                 entropy_threshold: float = 2.5,
                 binary_entropy_threshold: float = 0.85,
                 min_logprob_threshold: float = -5.0,
                 low_conf_ratio_threshold: float = 0.15,
                 conformal_alpha: float = 0.05,
                 tier_costs: Optional[Dict[int, float]] = None,
                 task_floors: Optional[Dict[str, int]] = None):
        """Configure escalation thresholds and cost bookkeeping.

        Args:
            safety_threshold: stored for callers; not read by routing logic here.
            entropy_threshold: escalate when the entropy proxy exceeds this.
            binary_entropy_threshold: stored and reported in the signal, but
                not used as an escalation trigger in this implementation.
            min_logprob_threshold: tokens below this logprob count as
                "low confidence".
            low_conf_ratio_threshold: escalate when the low-confidence ratio
                exceeds this.
            conformal_alpha: target miscoverage level for conformal calibration.
            tier_costs: per-tier unit cost; defaults to a 5-tier schedule.
            task_floors: per-task-type minimum tier applied when escalating.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        # `or` (not `if None`) mirrors the original: an explicitly-passed
        # empty dict also falls back to the defaults.
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        self.escalation_stats = {"escalated": 0, "stayed": 0,
                                 "saved_cost": 0.0, "wasted_cost": 0.0}

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Return the mean of -p*log(p) over tokens, with p = exp(logprob).

        This is an entropy-style proxy built from top-1 logprobs only.
        Returns 0.0 for empty input.
        """
        if not logprobs:
            return 0.0
        # The original's extra `if n > 0` guard was redundant with the
        # early return above and has been removed.
        return -sum(lp * math.exp(lp) for lp in logprobs) / len(logprobs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Return the binary entropy H(p) in bits, with p clamped away from
        0 and 1 to avoid log(0)."""
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(self, token_logprobs: List[float],
                       token_probs: Optional[List[List[float]]] = None,
                       response_text: str = "",
                       task_type: str = "unknown",
                       current_tier: int = 2) -> FeedbackSignal:
        """Turn per-token logprobs into a FeedbackSignal with an
        escalate/stay decision.

        Triggers (any one sets should_escalate): entropy proxy above
        `entropy_threshold`, low-confidence ratio above
        `low_conf_ratio_threshold`, or entropy above the calibrated
        conformal threshold (when calibration has run).

        Args:
            token_logprobs: top-1 logprob per generated token.
            token_probs: optional full distributions; currently only gates
                which entropy formula is used (see NOTE below).
            response_text, task_type, current_tier: accepted for interface
                compatibility; not used by the current analysis.
        """
        if not token_logprobs:
            # Sentinel signal: no data, never escalate on it.
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")
        mean_lp = sum(token_logprobs) / len(token_logprobs)
        min_lp = min(token_logprobs)
        low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold)
        low_conf_ratio = low_conf_count / len(token_logprobs)
        # NOTE(review): when token_probs is given, the entropy proxy is still
        # computed from token_logprobs (token_probs is never read) — looks
        # like full distributions were meant to be used here; confirm intent.
        entropy = self.compute_entropy(token_logprobs) if token_probs else abs(mean_lp)
        binary_entropy = self.compute_binary_entropy(math.exp(mean_lp))
        should_escalate = False
        reasons = []
        if entropy > self.entropy_threshold:
            should_escalate = True
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            should_escalate = True
            reasons.append(f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})")
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            should_escalate = True
            reasons.append(f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})")
        return FeedbackSignal(
            entropy=entropy, binary_entropy=binary_entropy,
            mean_logprob=mean_lp, min_logprob=min_lp,
            token_count=len(token_logprobs),
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=should_escalate,
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]):
        """Set the conformal threshold to the ceil((1-alpha)(n+1))-th order
        statistic of `scores` (split-conformal quantile).

        Needs at least 10 calibration points; otherwise the threshold is
        left unset and the conformal trigger stays inactive.
        """
        self._calibration_scores = scores
        self._calibration_labels = labels
        n = len(scores)
        if n < 10:
            return
        alpha = self.conformal_alpha
        # k is already clamped to n, so k-1 is a valid index; the original's
        # second min(k-1, n-1) was redundant.
        k = min(int(math.ceil((1 - alpha) * (n + 1))), n)
        self._conformal_threshold = sorted(scores)[k - 1]

    def cascade(self, request: str, initial_tier: int,
                cheap_logprobs: List[float],
                cheap_response: str,
                strong_response: str = "",
                task_type: str = "unknown",
                task_floor: int = 1) -> CascadeResult:
        """Run one cheap-first cascade: analyze the cheap output and, if it
        looks uncertain and a stronger tier exists, return the strong
        response instead.

        Returns a CascadeResult with the chosen response, final tier, the
        feedback signal, and total cost (cheap tier, plus the escalation
        tier when escalated).
        """
        signal = self.analyze_output(cheap_logprobs, response_text=cheap_response,
                                     task_type=task_type, current_tier=initial_tier)
        # BUGFIX: escalation is only possible below the top tier. The
        # original left `escalated` True at tier 5 even though no escalation
        # happened, which double-counted the tier cost in total_cost and
        # skipped both escalation_stats branches.
        escalated = signal.should_escalate and initial_tier < 5
        final_tier = initial_tier
        final_response = cheap_response
        if escalated:
            # Respect both the per-call floor and any per-task-type floor.
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = max(min(initial_tier + 1, 5), floor)
            final_response = strong_response if strong_response else cheap_response
            self.escalation_stats["escalated"] += 1
            self.escalation_stats["wasted_cost"] += self.tier_costs.get(final_tier, 1.0)
        else:
            self.escalation_stats["stayed"] += 1
            self.escalation_stats["saved_cost"] += (
                self.tier_costs.get(initial_tier + 1, 1.0)
                - self.tier_costs.get(initial_tier, 0.15))
        total_cost = self.tier_costs.get(initial_tier, 0.15)
        if escalated:
            total_cost += self.tier_costs.get(final_tier, 1.0)
        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            confidence=math.exp(signal.mean_logprob) if signal.mean_logprob > -99 else 0.0,
        )
|