Upload aco/execution_feedback.py with huggingface_hub
Browse files- aco/execution_feedback.py +147 -0
aco/execution_feedback.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Execution-Feedback Router: Use cheap model output to decide escalation.
|
| 2 |
+
|
| 3 |
+
Implements the RouteNLP / CP-Router pattern:
|
| 4 |
+
1. Send request to cheap model first
|
| 5 |
+
2. Compute token-level uncertainty from output
|
| 6 |
+
3. If uncertainty > calibrated threshold, escalate to stronger model
|
| 7 |
+
4. Conformal calibration for distribution-free quality guarantees
|
| 8 |
+
|
| 9 |
+
This is the highest-impact routing improvement per the literature:
|
| 10 |
+
post-hoc quality estimates dramatically outperform ex-ante routing.
|
| 11 |
+
"""
|
| 12 |
+
import math, json
|
| 13 |
+
from typing import Dict, List, Optional, Tuple
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
@dataclass
class FeedbackSignal:
    """Per-response uncertainty summary derived from token logprobs.

    Produced by ExecutionFeedbackRouter.analyze_output; carries both the raw
    statistics and the router's escalation verdict.
    """
    # Mean uncertainty proxy (see ExecutionFeedbackRouter.compute_entropy),
    # or |mean logprob| when full token distributions are unavailable.
    entropy: float
    # Shannon entropy (bits) of Bernoulli(exp(mean_logprob)).
    binary_entropy: float
    # Average of the per-token logprobs.
    mean_logprob: float
    # Worst (lowest) per-token logprob; -99.0 is the no-data sentinel.
    min_logprob: float
    # Number of tokens analyzed.
    token_count: int
    # Tokens whose logprob fell below the router's min_logprob_threshold.
    low_confidence_tokens: int
    # low_confidence_tokens / token_count.
    low_confidence_ratio: float
    # True when any escalation trigger fired.
    should_escalate: bool
    # "; "-joined trigger descriptions, "ok" when none fired, "no_data" when
    # no logprobs were supplied.
    escalation_reason: str
|
| 27 |
+
|
| 28 |
+
@dataclass
class CascadeResult:
    """Outcome of one ExecutionFeedbackRouter.cascade decision."""
    # Tier of the cheap model that produced the first response.
    initial_tier: int
    # Tier actually used after any escalation (equals initial_tier if none).
    final_tier: int
    # Text returned by the cheap model.
    initial_response: str
    # Response delivered to the caller (stronger model's on escalation,
    # otherwise the cheap response).
    final_response: str
    # Uncertainty analysis that drove the decision.
    feedback_signal: FeedbackSignal
    # Whether the request was escalated to a stronger tier.
    escalated: bool
    # Sum of tier costs incurred (cheap tier, plus final tier if escalated).
    total_cost: float
    # exp(mean_logprob) — geometric-mean per-token probability.
    confidence: float
|
| 38 |
+
|
| 39 |
+
class ExecutionFeedbackRouter:
    """Cascade router: run a cheap model first, escalate on output uncertainty.

    Token-level logprobs from the cheap model's response are summarized into a
    FeedbackSignal; if any trigger fires (entropy proxy, low-confidence token
    ratio, or a conformally calibrated threshold), the request is escalated one
    tier, bounded above by tier 5 and below by any task floor.

    NOTE(review): ``safety_threshold`` and ``binary_entropy_threshold`` are
    stored but never consulted by the current escalation logic — confirm
    whether they are dead config or pending features.
    """

    def __init__(self, safety_threshold: float = 0.30,
                 entropy_threshold: float = 2.5,
                 binary_entropy_threshold: float = 0.85,
                 min_logprob_threshold: float = -5.0,
                 low_conf_ratio_threshold: float = 0.15,
                 conformal_alpha: float = 0.05,
                 tier_costs: Optional[Dict[int, float]] = None,
                 task_floors: Optional[Dict[str, int]] = None):
        """Store thresholds and per-tier cost/floor tables.

        Args:
            safety_threshold: currently unused; kept for interface stability.
            entropy_threshold: escalate when the entropy proxy exceeds this.
            binary_entropy_threshold: currently unused; kept for interface
                stability.
            min_logprob_threshold: tokens with logprob below this count as
                "low confidence".
            low_conf_ratio_threshold: escalate when the fraction of
                low-confidence tokens exceeds this.
            conformal_alpha: miscoverage level for conformal calibration.
            tier_costs: mapping tier -> relative cost; defaults to a 5-tier
                table.
            task_floors: mapping task_type -> minimum tier allowed when
                escalating.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        # Running tallies for monitoring; updated by `cascade`.
        self.escalation_stats = {"escalated": 0, "stayed": 0,
                                 "saved_cost": 0.0, "wasted_cost": 0.0}

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Return a mean per-token uncertainty proxy from logprobs.

        Each logprob lp contributes -lp * exp(lp) = -p * log(p).
        NOTE(review): this treats individual token logprobs as if they formed
        a distribution; it is a heuristic proxy, not a true entropy — confirm
        intent.
        """
        if not logprobs:
            return 0.0
        # Fixed: dropped the original trailing `if n > 0 else 0.0`, which was
        # unreachable — the empty case already returned above.
        return -sum(lp * math.exp(lp) for lp in logprobs) / len(logprobs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Shannon entropy (bits) of Bernoulli(prob), clamped away from {0, 1}."""
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(self, token_logprobs: List[float],
                       token_probs: Optional[List[List[float]]] = None,
                       response_text: str = "",
                       task_type: str = "unknown",
                       current_tier: int = 2) -> FeedbackSignal:
        """Summarize token logprobs into a FeedbackSignal with a verdict.

        `response_text`, `task_type` and `current_tier` are accepted for
        interface stability but unused by the current heuristics.
        """
        if not token_logprobs:
            # Sentinel signal: no data, do not escalate.
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")
        mean_lp = sum(token_logprobs) / len(token_logprobs)
        min_lp = min(token_logprobs)
        low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold)
        low_conf_ratio = low_conf_count / len(token_logprobs)
        # Use the entropy proxy only when full per-token distributions are
        # supplied; otherwise fall back to |mean logprob| as a crude stand-in.
        entropy = self.compute_entropy(token_logprobs) if token_probs else abs(mean_lp)
        binary_entropy = self.compute_binary_entropy(math.exp(mean_lp))
        # NOTE(review): binary_entropy is computed but never compared against
        # self.binary_entropy_threshold — confirm whether a trigger is missing.
        should_escalate = False
        reasons = []
        if entropy > self.entropy_threshold:
            should_escalate = True
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            should_escalate = True
            reasons.append(f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})")
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            should_escalate = True
            reasons.append(f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})")
        return FeedbackSignal(
            entropy=entropy, binary_entropy=binary_entropy,
            mean_logprob=mean_lp, min_logprob=min_lp,
            token_count=len(token_logprobs),
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=should_escalate,
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]) -> None:
        """Fit a split-conformal threshold on held-out uncertainty scores.

        Uses the standard ceil((1 - alpha) * (n + 1)) quantile rank, clamped
        to n. With fewer than 10 samples calibration is skipped and the
        conformal trigger stays disabled.
        """
        self._calibration_scores = scores
        self._calibration_labels = labels
        n = len(scores)
        if n < 10:
            return
        alpha = self.conformal_alpha
        # k <= n by construction, so k - 1 is always a valid index (the
        # original's extra min(k-1, n-1) guard was redundant).
        k = min(int(math.ceil((1 - alpha) * (n + 1))), n)
        self._conformal_threshold = sorted(scores)[k - 1]

    def cascade(self, request: str, initial_tier: int,
                cheap_logprobs: List[float],
                cheap_response: str,
                strong_response: str = "",
                task_type: str = "unknown",
                task_floor: int = 1) -> CascadeResult:
        """Run the escalation decision for one request.

        Args:
            request: original request text (currently unused; kept for API).
            initial_tier: tier of the cheap model that produced the response.
            cheap_logprobs: per-token logprobs of the cheap response.
            cheap_response: text produced by the cheap model.
            strong_response: pre-computed stronger-model response used when
                escalating; if empty, the cheap response is kept.
            task_type: key into the per-task tier-floor table.
            task_floor: explicit minimum tier to respect when escalating.

        Returns:
            CascadeResult describing the routing decision and its cost.
        """
        signal = self.analyze_output(cheap_logprobs, response_text=cheap_response,
                                     task_type=task_type, current_tier=initial_tier)
        # Bug fix: only count the request as escalated when we can actually
        # move up a tier. The original set `escalated` from the raw signal,
        # so an uncertain tier-5 request was double-charged (total_cost added
        # the "final" tier again) and reported escalated=True while the stats
        # branch counted it as "stayed".
        escalated = signal.should_escalate and initial_tier < 5
        final_tier = initial_tier
        final_response = cheap_response
        if escalated:
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = min(initial_tier + 1, 5)
            final_tier = max(final_tier, floor)
            # Without a supplied stronger response, keep the cheap one.
            final_response = strong_response if strong_response else cheap_response
            self.escalation_stats["escalated"] += 1
            # NOTE(review): every escalation is tallied as "wasted" cost with
            # no ground-truth quality check — confirm intended semantics.
            self.escalation_stats["wasted_cost"] += self.tier_costs.get(final_tier, 1.0)
        else:
            self.escalation_stats["stayed"] += 1
            self.escalation_stats["saved_cost"] += (
                self.tier_costs.get(initial_tier + 1, 1.0)
                - self.tier_costs.get(initial_tier, 0.15))
        total_cost = self.tier_costs.get(initial_tier, 0.15)
        if escalated:
            total_cost += self.tier_costs.get(final_tier, 1.0)
        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            # NOTE(review): a no-data signal has mean_logprob=0.0, which
            # yields confidence=1.0 here — confirm that is intended.
            confidence=math.exp(signal.mean_logprob) if signal.mean_logprob > -99 else 0.0,
        )
|