narcolepticchicken committed on
Commit
04b4ddf
·
verified ·
1 Parent(s): 2e8ac72

Upload aco/execution_feedback.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/execution_feedback.py +147 -0
aco/execution_feedback.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Execution-Feedback Router: Use cheap model output to decide escalation.
2
+
3
+ Implements the RouteNLP / CP-Router pattern:
4
+ 1. Send request to cheap model first
5
+ 2. Compute token-level uncertainty from output
6
+ 3. If uncertainty > calibrated threshold, escalate to stronger model
7
+ 4. Conformal calibration for distribution-free quality guarantees
8
+
9
+ This is the highest-impact routing improvement per the literature:
10
+ post-hoc quality estimates dramatically outperform ex-ante routing.
11
+ """
12
+ import math, json
13
+ from typing import Dict, List, Optional, Tuple
14
+ from dataclasses import dataclass
15
+
16
@dataclass
class FeedbackSignal:
    """Uncertainty summary derived from one model response's token log-probs.

    Field order is part of the interface: instances are also built
    positionally, so new fields must only ever be appended.
    """

    # Uncertainty score that is compared against the escalation thresholds.
    entropy: float
    # Binary entropy (in bits) of the mean token probability.
    binary_entropy: float
    # Arithmetic mean of the per-token log-probabilities.
    mean_logprob: float
    # Smallest per-token log-probability (-99.0 when there were no tokens).
    min_logprob: float
    # Number of tokens the statistics were computed over.
    token_count: int
    # Count of tokens whose log-probability fell below the low-confidence cut.
    low_confidence_tokens: int
    # low_confidence_tokens / token_count.
    low_confidence_ratio: float
    # True when at least one escalation rule fired.
    should_escalate: bool
    # "; "-joined rule descriptions, or "ok" / "no_data".
    escalation_reason: str
27
+
28
@dataclass
class CascadeResult:
    """Outcome of one cheap-first cascade decision."""

    # Tier that produced the first (cheap) response.
    initial_tier: int
    # Tier the request ended on (equals initial_tier when not escalated).
    final_tier: int
    # Text returned by the initial tier.
    initial_response: str
    # Text that is ultimately handed back to the caller.
    final_response: str
    # Uncertainty analysis of the initial response.
    feedback_signal: FeedbackSignal
    # Whether the request was escalated to a higher tier.
    escalated: bool
    # Accumulated cost of every tier that was invoked.
    total_cost: float
    # Confidence estimate for the initial response.
    confidence: float
38
+
39
class ExecutionFeedbackRouter:
    """Decide from a cheap model's token log-probs whether to escalate.

    Typical flow: run the cheap tier, pass its per-token log-probabilities to
    :meth:`analyze_output` (or :meth:`cascade`), and escalate one tier when
    any uncertainty rule fires. :meth:`calibrate_conformal` can add a
    data-derived threshold on top of the fixed ones.
    """

    # Highest tier a request can be escalated to.
    MAX_TIER = 5

    def __init__(self, safety_threshold: float = 0.30,
                 entropy_threshold: float = 2.5,
                 binary_entropy_threshold: float = 0.85,
                 min_logprob_threshold: float = -5.0,
                 low_conf_ratio_threshold: float = 0.15,
                 conformal_alpha: float = 0.05,
                 tier_costs: Optional[Dict[int, float]] = None,
                 task_floors: Optional[Dict[str, int]] = None):
        """Store thresholds, cost table and per-task tier floors.

        Args:
            safety_threshold: Stored only; no method of this class reads it.
            entropy_threshold: Escalate when the entropy score exceeds this.
            binary_entropy_threshold: Stored only; no method of this class
                reads it.
            min_logprob_threshold: A token whose log-prob is below this value
                counts as low-confidence.
            low_conf_ratio_threshold: Escalate when the share of
                low-confidence tokens exceeds this.
            conformal_alpha: Miscoverage level used by
                :meth:`calibrate_conformal`.
            tier_costs: Cost per tier; a falsy value selects the built-in
                five-tier table.
            task_floors: Minimum tier per task type, applied on escalation;
                a falsy value means no floors.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        self.escalation_stats = {
            "escalated": 0,
            "stayed": 0,
            "saved_cost": 0.0,
            "wasted_cost": 0.0,
        }

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Return the mean of ``-p * log(p)`` over the given token log-probs.

        Each token contributes ``-exp(lp) * lp``. A token whose probability
        underflows to zero (including ``lp == -inf``) contributes 0, its
        limiting value, instead of turning the whole result into NaN.
        Returns 0.0 for an empty list.
        """
        if not logprobs:
            return 0.0
        total = 0.0
        for lp in logprobs:
            p = math.exp(lp)
            if p > 0.0:  # skip p == 0: -inf * 0.0 would be NaN
                total += lp * p
        return -total / len(logprobs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Return the binary entropy of ``prob`` in bits.

        ``prob`` is clamped to [0.0001, 0.9999] so both logarithms are finite.
        """
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(self, token_logprobs: List[float],
                       token_probs: Optional[List[List[float]]] = None,
                       response_text: str = "",
                       task_type: str = "unknown",
                       current_tier: int = 2) -> FeedbackSignal:
        """Summarise token log-probs and decide whether to escalate.

        Args:
            token_logprobs: Log-probability of each generated token.
            token_probs: Only its truthiness is used: when truthy the entropy
                score comes from :meth:`compute_entropy`, otherwise it is
                ``abs(mean log-prob)``.
            response_text: Accepted for interface compatibility; unused.
            task_type: Accepted for interface compatibility; unused.
            current_tier: Accepted for interface compatibility; unused.

        Returns:
            A ``FeedbackSignal``. With no log-probs it is the "no_data"
            signal (``min_logprob == -99.0``, ``token_count == 0``) and never
            asks for escalation.
        """
        if not token_logprobs:
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")

        token_count = len(token_logprobs)
        mean_lp = sum(token_logprobs) / token_count
        min_lp = min(token_logprobs)
        low_conf_count = sum(
            1 for lp in token_logprobs if lp < self.min_logprob_threshold
        )
        low_conf_ratio = low_conf_count / token_count

        # NOTE(review): token_probs itself is never read. The per-token term
        # -p*log(p) is at most 1/e, so on this path the score cannot exceed
        # the default entropy_threshold of 2.5 -- confirm whether the entropy
        # was meant to be computed from the token_probs distributions.
        if token_probs:
            entropy = self.compute_entropy(token_logprobs)
        else:
            entropy = abs(mean_lp)

        # Capping at 0 keeps exp() from overflowing on invalid positive
        # inputs; any probability >= 1 is clamped to 0.9999 downstream anyway.
        binary_entropy = self.compute_binary_entropy(math.exp(min(mean_lp, 0.0)))

        reasons = []
        if entropy > self.entropy_threshold:
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            reasons.append(
                f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})"
            )
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            reasons.append(
                f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})"
            )

        return FeedbackSignal(
            entropy=entropy,
            binary_entropy=binary_entropy,
            mean_logprob=mean_lp,
            min_logprob=min_lp,
            token_count=token_count,
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=bool(reasons),
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]) -> None:
        """Derive the conformal entropy threshold from calibration scores.

        The threshold is the k-th smallest score with
        ``k = min(ceil((1 - alpha) * (n + 1)), n)``. ``labels`` are stored
        but do not influence the threshold. With fewer than 10 scores the
        data is stored and any previously computed threshold is left as is.
        """
        # Copy so later mutation of the caller's lists cannot change our state.
        self._calibration_scores = list(scores)
        self._calibration_labels = list(labels)
        n = len(self._calibration_scores)
        if n < 10:
            return
        k = min(int(math.ceil((1 - self.conformal_alpha) * (n + 1))), n)
        self._conformal_threshold = sorted(self._calibration_scores)[k - 1]

    def cascade(self, request: str, initial_tier: int,
                cheap_logprobs: List[float],
                cheap_response: str,
                strong_response: str = "",
                task_type: str = "unknown",
                task_floor: int = 1) -> CascadeResult:
        """Analyse the cheap response and escalate one tier if warranted.

        Args:
            request: The original request; not used by the decision logic.
            initial_tier: Tier that produced ``cheap_response``.
            cheap_logprobs: Token log-probs of the cheap response.
            cheap_response: Text produced by the initial tier.
            strong_response: Text from the stronger tier; when empty the
                cheap response is returned even if the request escalated.
            task_type: Key into ``task_floors``.
            task_floor: Explicit minimum tier, combined with the per-task
                floor; floors are applied only when escalating.

        Returns:
            A ``CascadeResult``. ``escalated`` is True only when a higher
            tier was actually selected, so a request already at ``MAX_TIER``
            is reported as not escalated and is charged once.
            ``feedback_signal.should_escalate`` still carries the raw
            decision.

        Side effects:
            Updates ``escalation_stats``. On escalation ``wasted_cost`` grows
            by the final tier's cost; otherwise ``saved_cost`` grows by the
            cost gap to the next tier (nothing when already at ``MAX_TIER``).
        """
        signal = self.analyze_output(cheap_logprobs, response_text=cheap_response,
                                     task_type=task_type, current_tier=initial_tier)

        can_escalate = initial_tier < self.MAX_TIER
        escalated = signal.should_escalate and can_escalate

        initial_cost = self.tier_costs.get(initial_tier, 0.15)
        total_cost = initial_cost
        final_tier = initial_tier
        final_response = cheap_response

        if escalated:
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = max(min(initial_tier + 1, self.MAX_TIER), floor)
            final_response = strong_response if strong_response else cheap_response
            final_cost = self.tier_costs.get(final_tier, 1.0)
            total_cost += final_cost
            self.escalation_stats["escalated"] += 1
            self.escalation_stats["wasted_cost"] += final_cost
        else:
            self.escalation_stats["stayed"] += 1
            if can_escalate:
                # There is no higher tier to avoid at MAX_TIER, so no saving.
                next_cost = self.tier_costs.get(initial_tier + 1, 1.0)
                self.escalation_stats["saved_cost"] += next_cost - initial_cost

        # The no-data signal has mean_logprob == 0.0, so detect it through
        # token_count; exp(0.0) would otherwise report full confidence.
        if signal.token_count > 0 and signal.mean_logprob > -99:
            confidence = math.exp(signal.mean_logprob)
        else:
            confidence = 0.0

        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            confidence=confidence,
        )