# aco/execution_feedback.py
"""Execution-Feedback Router: Use cheap model output to decide escalation.
Implements the RouteNLP / CP-Router pattern:
1. Send request to cheap model first
2. Compute token-level uncertainty from output
3. If uncertainty > calibrated threshold, escalate to stronger model
4. Conformal calibration for distribution-free quality guarantees
This is the highest-impact routing improvement per the literature:
post-hoc quality estimates dramatically outperform ex-ante routing.
"""
import math, json
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class FeedbackSignal:
    """Uncertainty statistics computed from a cheap model's token logprobs,
    together with the escalation decision derived from them."""
    entropy: float                # approximate mean token entropy (or |mean logprob| fallback)
    binary_entropy: float         # binary entropy of exp(mean logprob) treated as a confidence
    mean_logprob: float           # mean of the token logprobs
    min_logprob: float            # worst (lowest) token logprob seen
    token_count: int              # number of tokens analyzed
    low_confidence_tokens: int    # count of tokens below the logprob threshold
    low_confidence_ratio: float   # low_confidence_tokens / token_count
    should_escalate: bool         # True if any uncertainty signal tripped its threshold
    escalation_reason: str        # "; "-joined trigger descriptions, or "ok" / "no_data"
@dataclass
class CascadeResult:
    """Outcome of one cheap-first cascade attempt, including cost accounting."""
    initial_tier: int             # tier the request was first sent to
    final_tier: int               # tier that produced the final response
    initial_response: str         # cheap model's output
    final_response: str           # output actually returned to the caller
    feedback_signal: FeedbackSignal  # uncertainty analysis of the cheap output
    escalated: bool               # True if a stronger tier was actually invoked
    total_cost: float             # summed per-tier costs incurred
    confidence: float             # exp(mean logprob) of the cheap output, 0.0 if no data
class ExecutionFeedbackRouter:
    """Cheap-first cascade router with post-hoc uncertainty escalation.

    A request is first answered by a cheap tier; token-level uncertainty of
    that output (entropy, low-confidence token ratio, and an optional
    conformally calibrated threshold) decides whether to escalate to a
    stronger tier.

    Fixes vs. the original implementation:
    - ``cascade`` no longer reports ``escalated=True`` (and no longer
      double-charges the cost) when the signal fires but the request is
      already at the top tier (5) and cannot escalate.
    - ``calibrate_conformal`` now actually uses ``labels``: split-conformal
      calibration is done on the scores of acceptable outputs only.
    - ``saved_cost`` bookkeeping is clamped at zero (it went negative when
      staying at the top tier).
    """

    def __init__(self, safety_threshold: float = 0.30,
                 entropy_threshold: float = 2.5,
                 binary_entropy_threshold: float = 0.85,
                 min_logprob_threshold: float = -5.0,
                 low_conf_ratio_threshold: float = 0.15,
                 conformal_alpha: float = 0.05,
                 tier_costs: Optional[Dict[int, float]] = None,
                 task_floors: Optional[Dict[str, int]] = None):
        """Store thresholds and cost tables.

        tier_costs maps tier number -> relative cost; task_floors maps
        task_type -> minimum tier that task may be served from.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        self.escalation_stats = {"escalated": 0, "stayed": 0,
                                 "saved_cost": 0.0, "wasted_cost": 0.0}

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Approximate mean token entropy from per-token logprobs.

        Uses -sum(p * log p) / n with p = exp(logprob); this treats each
        token's chosen-token probability as a one-point distribution, so it
        is an approximation, not the full vocabulary entropy.
        """
        if not logprobs:
            return 0.0
        # Removed dead `if n > 0 else 0.0`: the guard above already
        # ensures a non-empty list.
        return -sum(lp * math.exp(lp) for lp in logprobs) / len(logprobs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Binary entropy H(p) in bits, with p clamped away from {0, 1}
        to avoid log(0)."""
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(self, token_logprobs: List[float],
                       token_probs: Optional[List[List[float]]] = None,
                       response_text: str = "",
                       task_type: str = "unknown",
                       current_tier: int = 2) -> FeedbackSignal:
        """Derive an escalation decision from the cheap model's logprobs.

        Escalates when any of: approximate entropy exceeds
        entropy_threshold, the fraction of low-confidence tokens exceeds
        low_conf_ratio_threshold, or the entropy exceeds the conformally
        calibrated threshold (if one has been calibrated).

        NOTE(review): when token_probs is absent, "entropy" falls back to
        |mean logprob|, which is a heuristic proxy rather than an entropy.
        response_text / task_type / current_tier are accepted for interface
        stability but unused here.
        """
        if not token_logprobs:
            # No signal available: never escalate on absence of data.
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")
        mean_lp = sum(token_logprobs) / len(token_logprobs)
        min_lp = min(token_logprobs)
        low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold)
        low_conf_ratio = low_conf_count / len(token_logprobs)
        entropy = self.compute_entropy(token_logprobs) if token_probs else abs(mean_lp)
        binary_entropy = self.compute_binary_entropy(math.exp(mean_lp))
        should_escalate = False
        reasons = []
        if entropy > self.entropy_threshold:
            should_escalate = True
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            should_escalate = True
            reasons.append(f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})")
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            should_escalate = True
            reasons.append(f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})")
        return FeedbackSignal(
            entropy=entropy, binary_entropy=binary_entropy,
            mean_logprob=mean_lp, min_logprob=min_lp,
            token_count=len(token_logprobs),
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=should_escalate,
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]):
        """Calibrate the conformal escalation threshold from a held-out set.

        scores: nonconformity scores (same scale as the entropy used in
            analyze_output) for past cheap-model outputs.
        labels: 1 if the corresponding output was acceptable, 0 otherwise.

        Fix: the original stored ``labels`` but never used them.
        Split-conformal calibration takes the ceil((1-alpha)(n+1))-th order
        statistic of the scores of *acceptable* outputs; we filter by
        label == 1 when labels are usable and fall back to all scores
        otherwise. Requires at least 10 calibration points.
        """
        self._calibration_scores = scores
        self._calibration_labels = labels
        if labels and len(labels) == len(scores):
            pool = [s for s, ok in zip(scores, labels) if ok == 1]
            if not pool:
                pool = scores  # degenerate: no acceptable outputs recorded
        else:
            pool = scores
        n = len(pool)
        if n < 10:
            return  # too few points for a meaningful quantile
        alpha = self.conformal_alpha
        # Conformal quantile index ceil((1-alpha)(n+1)), clamped to n.
        k = min(int(math.ceil((1 - alpha) * (n + 1))), n)
        self._conformal_threshold = sorted(pool)[k - 1]

    def cascade(self, request: str, initial_tier: int,
                cheap_logprobs: List[float],
                cheap_response: str,
                strong_response: str = "",
                task_type: str = "unknown",
                task_floor: int = 1) -> CascadeResult:
        """Run one cheap-first cascade step and account for its cost.

        Analyzes the cheap output; if the uncertainty signal fires AND a
        higher tier exists, escalates one tier (respecting task floors) and
        uses strong_response. Otherwise the cheap response stands.

        Fix: ``escalated`` is now False when the signal fires at tier 5 —
        the original reported a phantom escalation and charged the tier-5
        cost twice in total_cost.
        """
        signal = self.analyze_output(cheap_logprobs, response_text=cheap_response,
                                     task_type=task_type, current_tier=initial_tier)
        # An escalation only happens if there is a stronger tier to go to.
        escalated = signal.should_escalate and initial_tier < 5
        final_tier = initial_tier
        final_response = cheap_response
        if escalated:
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = max(min(initial_tier + 1, 5), floor)
            final_response = strong_response if strong_response else cheap_response
            self.escalation_stats["escalated"] += 1
            # Cost of the stronger call; "wasted" if the cheap answer was fine.
            self.escalation_stats["wasted_cost"] += self.tier_costs.get(final_tier, 1.0)
        else:
            self.escalation_stats["stayed"] += 1
            saved = (self.tier_costs.get(initial_tier + 1, 1.0)
                     - self.tier_costs.get(initial_tier, 0.15))
            # Clamp: at the top tier there is no pricier tier, so nothing is "saved".
            self.escalation_stats["saved_cost"] += max(saved, 0.0)
        total_cost = self.tier_costs.get(initial_tier, 0.15)
        if escalated:
            total_cost += self.tier_costs.get(final_tier, 1.0)
        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            confidence=math.exp(signal.mean_logprob) if signal.mean_logprob > -99 else 0.0,
        )