File size: 6,607 Bytes
04b4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Execution-Feedback Router: Use cheap model output to decide escalation.

Implements the RouteNLP / CP-Router pattern:
1. Send request to cheap model first
2. Compute token-level uncertainty from output
3. If uncertainty > calibrated threshold, escalate to stronger model
4. Conformal calibration for distribution-free quality guarantees

This is the highest-impact routing improvement per the literature:
post-hoc quality estimates dramatically outperform ex-ante routing.
"""
import math, json
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

@dataclass
class FeedbackSignal:
    """Uncertainty summary of one model output.

    Instances are built by ``ExecutionFeedbackRouter.analyze_output``. That
    method also constructs the "no data" signal positionally, so the field
    order below is part of the interface and must not be rearranged.
    """
    # Uncertainty score that is compared against the entropy threshold and,
    # when calibrated, the conformal threshold.
    entropy: float
    # Binary entropy, in bits, of exp(mean_logprob).
    binary_entropy: float
    # Arithmetic mean of the token log-probabilities (0.0 when none given).
    mean_logprob: float
    # Smallest token log-probability (-99.0 when none were given).
    min_logprob: float
    # Number of token log-probabilities that were analysed.
    token_count: int
    # How many tokens fell below the router's per-token log-prob cutoff.
    low_confidence_tokens: int
    # low_confidence_tokens divided by token_count.
    low_confidence_ratio: float
    # True when at least one escalation check fired.
    should_escalate: bool
    # "; "-joined descriptions of the checks that fired, "ok" when none
    # fired, or "no_data" when no log-probabilities were supplied.
    escalation_reason: str

@dataclass
class CascadeResult:
    """Outcome of one ``ExecutionFeedbackRouter.cascade`` call."""
    # Tier whose (cheap) output was analysed.
    initial_tier: int
    # Tier whose response was finally selected.
    final_tier: int
    # Response text produced at the initial tier.
    initial_response: str
    # Response text returned to the caller.
    final_response: str
    # Uncertainty analysis of the initial response.
    feedback_signal: FeedbackSignal
    # Escalation flag reported by cascade().
    escalated: bool
    # Sum of the tier costs charged for this request.
    total_cost: float
    # Confidence estimate that cascade() derives from the mean token
    # log-probability of the initial response.
    confidence: float

class ExecutionFeedbackRouter:
    """Decide from a cheap model's own output whether to escalate a request.

    The router inspects the token log-probabilities of an already generated
    response, derives a few uncertainty measures from them and escalates to
    the next tier when one of the configured thresholds is exceeded.  An
    optional conformal threshold can be fitted with ``calibrate_conformal``.

    Running totals are kept in ``escalation_stats`` under the keys
    ``"escalated"``, ``"stayed"``, ``"saved_cost"`` and ``"wasted_cost"``.
    """

    # Highest tier a request can be escalated to.
    _MAX_TIER = 5

    def __init__(self, safety_threshold: float = 0.30,
                 entropy_threshold: float = 2.5,
                 binary_entropy_threshold: float = 0.85,
                 min_logprob_threshold: float = -5.0,
                 low_conf_ratio_threshold: float = 0.15,
                 conformal_alpha: float = 0.05,
                 tier_costs: Optional[Dict[int, float]] = None,
                 task_floors: Optional[Dict[str, int]] = None):
        """Store thresholds, cost table and per-task tier floors.

        Args:
            safety_threshold: Stored on the instance; not consulted by any
                method of this class.
            entropy_threshold: Escalate when the entropy score exceeds this.
            binary_entropy_threshold: Stored on the instance; not consulted
                by any method of this class.
            min_logprob_threshold: A token whose log-probability is below
                this value counts as low-confidence.
            low_conf_ratio_threshold: Escalate when the share of
                low-confidence tokens exceeds this.
            conformal_alpha: Miscoverage level used by
                ``calibrate_conformal``.
            tier_costs: Cost per tier; a falsy value (``None`` or an empty
                dict) selects the built-in five-tier table.
            task_floors: Minimum tier per task type, applied when a request
                is escalated; a falsy value means no floors.
        """
        self.safety_threshold = safety_threshold
        self.entropy_threshold = entropy_threshold
        self.binary_entropy_threshold = binary_entropy_threshold
        self.min_logprob_threshold = min_logprob_threshold
        self.low_conf_ratio_threshold = low_conf_ratio_threshold
        self.conformal_alpha = conformal_alpha
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.task_floors = task_floors or {}
        self._calibration_scores: List[float] = []
        self._calibration_labels: List[int] = []
        self._conformal_threshold: Optional[float] = None
        self.escalation_stats = {
            "escalated": 0,
            "stayed": 0,
            "saved_cost": 0.0,
            "wasted_cost": 0.0,
        }

    def compute_entropy(self, logprobs: List[float]) -> float:
        """Return the mean of ``-p * log(p)`` over the tokens, p = exp(logprob).

        Returns 0.0 for an empty list.
        """
        if not logprobs:
            return 0.0
        return -sum(lp * math.exp(lp) for lp in logprobs) / len(logprobs)

    def compute_binary_entropy(self, prob: float) -> float:
        """Return the binary entropy of *prob* in bits.

        *prob* is clamped to [0.0001, 0.9999] so both logarithms are defined.
        """
        p = max(min(prob, 0.9999), 0.0001)
        return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

    def analyze_output(self, token_logprobs: List[float],
                       token_probs: Optional[List[List[float]]] = None,
                       response_text: str = "",
                       task_type: str = "unknown",
                       current_tier: int = 2) -> FeedbackSignal:
        """Summarise the uncertainty of one response and decide on escalation.

        Args:
            token_logprobs: Log-probability of each generated token.
            token_probs: When truthy, the entropy score comes from
                ``compute_entropy(token_logprobs)``; otherwise it is
                ``abs(mean_logprob)``.
            response_text: Accepted for interface compatibility; unused.
            task_type: Accepted for interface compatibility; unused.
            current_tier: Accepted for interface compatibility; unused.

        Returns:
            A ``FeedbackSignal``.  With no log-probabilities the signal has
            ``token_count == 0``, ``min_logprob == -99.0``, reason
            ``"no_data"`` and does not request escalation.
        """
        if not token_logprobs:
            # Positional order: entropy, binary_entropy, mean_logprob,
            # min_logprob, token_count, low_confidence_tokens,
            # low_confidence_ratio, should_escalate, escalation_reason.
            return FeedbackSignal(0.0, 0.0, 0.0, -99.0, 0, 0, 0.0, False, "no_data")

        token_count = len(token_logprobs)
        mean_lp = sum(token_logprobs) / token_count
        min_lp = min(token_logprobs)
        low_conf_count = sum(1 for lp in token_logprobs if lp < self.min_logprob_threshold)
        low_conf_ratio = low_conf_count / token_count

        # NOTE(review): token_probs only selects which score is used; the
        # distributions themselves are never read.  Confirm this is intended.
        entropy = self.compute_entropy(token_logprobs) if token_probs else abs(mean_lp)
        binary_entropy = self.compute_binary_entropy(math.exp(mean_lp))

        reasons = []
        if entropy > self.entropy_threshold:
            reasons.append(f"high_entropy({entropy:.2f}>{self.entropy_threshold})")
        if low_conf_ratio > self.low_conf_ratio_threshold:
            reasons.append(f"low_conf_ratio({low_conf_ratio:.2f}>{self.low_conf_ratio_threshold})")
        if self._conformal_threshold is not None and entropy > self._conformal_threshold:
            reasons.append(f"conformal_violation({entropy:.2f}>{self._conformal_threshold:.2f})")

        return FeedbackSignal(
            entropy=entropy,
            binary_entropy=binary_entropy,
            mean_logprob=mean_lp,
            min_logprob=min_lp,
            token_count=token_count,
            low_confidence_tokens=low_conf_count,
            low_confidence_ratio=low_conf_ratio,
            should_escalate=bool(reasons),
            escalation_reason="; ".join(reasons) if reasons else "ok",
        )

    def calibrate_conformal(self, scores: List[float], labels: List[int]):
        """Fit the conformal threshold from calibration scores.

        The threshold is the score of rank ``ceil((1 - alpha) * (n + 1))``
        (clamped to ``1..n``) among the sorted scores.  With fewer than ten
        scores the threshold is left as it was.

        Args:
            scores: Calibration scores; copied, so later changes to the
                caller's list do not affect the router.
            labels: Stored alongside the scores; not used when computing
                the threshold.
        """
        self._calibration_scores = list(scores)
        self._calibration_labels = list(labels)
        n = len(self._calibration_scores)
        if n < 10:
            return
        rank = math.ceil((1 - self.conformal_alpha) * (n + 1))
        # Clamp to a valid 1-based rank: without the lower bound an alpha of
        # 1 (or more) yields rank 0 and index -1, i.e. the *largest* score.
        rank = max(1, min(rank, n))
        self._conformal_threshold = sorted(self._calibration_scores)[rank - 1]

    def cascade(self, request: str, initial_tier: int,
                cheap_logprobs: List[float],
                cheap_response: str,
                strong_response: str = "",
                task_type: str = "unknown",
                task_floor: int = 1) -> CascadeResult:
        """Analyse a cheap response and pick the final tier and response.

        Args:
            request: The original request; not inspected here.
            initial_tier: Tier that produced ``cheap_response``.
            cheap_logprobs: Token log-probabilities of ``cheap_response``.
            cheap_response: Output of the initial tier.
            strong_response: Output of the stronger tier, used when the
                request is escalated; if empty the cheap response is kept.
            task_type: Key into ``task_floors``.
            task_floor: Explicit minimum tier; the larger of this and the
                configured floor for ``task_type`` applies on escalation.

        Returns:
            A ``CascadeResult``.  ``escalated`` is True only when a higher
            tier was actually selected; whether the analysis *wanted* to
            escalate is available as ``feedback_signal.should_escalate``.
            ``confidence`` is ``exp(mean_logprob)``, or 0.0 when no
            log-probabilities were supplied.
        """
        signal = self.analyze_output(cheap_logprobs, response_text=cheap_response,
                                     task_type=task_type, current_tier=initial_tier)
        initial_cost = self.tier_costs.get(initial_tier, 0.15)
        has_higher_tier = initial_tier < self._MAX_TIER
        # A request already at the top tier cannot escalate, so it must not
        # be reported (or billed) as if it had.
        escalated = bool(signal.should_escalate and has_higher_tier)

        final_tier = initial_tier
        final_response = cheap_response
        total_cost = initial_cost

        if escalated:
            floor = max(task_floor, self.task_floors.get(task_type, 1))
            final_tier = max(initial_tier + 1, floor)
            final_response = strong_response or cheap_response
            escalation_cost = self.tier_costs.get(final_tier, 1.0)
            total_cost += escalation_cost
            self.escalation_stats["escalated"] += 1
            self.escalation_stats["wasted_cost"] += escalation_cost
        else:
            self.escalation_stats["stayed"] += 1
            if has_higher_tier:
                # Credit the price difference to the next tier; at the top
                # tier there is no next tier, hence nothing saved.
                next_cost = self.tier_costs.get(initial_tier + 1, 1.0)
                self.escalation_stats["saved_cost"] += next_cost - initial_cost

        # The no-data signal carries mean_logprob == 0.0, which would read as
        # full confidence; key the fallback on token_count instead.
        if signal.token_count > 0 and signal.mean_logprob > -99:
            confidence = math.exp(signal.mean_logprob)
        else:
            confidence = 0.0

        return CascadeResult(
            initial_tier=initial_tier,
            final_tier=final_tier,
            initial_response=cheap_response,
            final_response=final_response,
            feedback_signal=signal,
            escalated=escalated,
            total_cost=total_cost,
            confidence=confidence,
        )