narcolepticchicken committed on
Commit
91c6791
·
verified ·
1 Parent(s): 911104d

Upload aco/conformal.py

Browse files
Files changed (1) hide show
  1. aco/conformal.py +112 -0
aco/conformal.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conformal calibration for escalation thresholds.
2
+
3
+ Implements RouteNLP-style conformal risk control:
4
+ P(failure AND no escalation) <= alpha
5
+
6
+ Method:
7
+ 1. Compute nonconformity scores from calibrated P(success)
8
+ 2. Find conformal quantile threshold
9
+ 3. Guarantee coverage under exchangeability
10
+ """
11
+ import numpy as np
12
+ from typing import Dict, List, Optional, Tuple
13
+
14
+
15
class ConformalEscalationCalibrator:
    """Calibrate escalation thresholds with distribution-free coverage guarantees.

    Based on RouteNLP (arxiv 2604.23577) and Conformal Risk Control
    (Angelopoulos et al., arxiv 2208.02814).

    Guarantee: P(y=fail AND no_escalation) <= alpha
    """

    def __init__(self, alpha: float = 0.05):
        """Args:
            alpha: target bound on the rate of un-escalated failures.
        """
        self.alpha = alpha
        # Per-tier escalation thresholds on calibrated P(success);
        # populated by calibrate().
        self.thresholds: Dict[int, float] = {}
        self.calibrated = False

    def calibrate(
        self,
        psuccess: Dict[int, np.ndarray],
        outcomes: Dict[int, np.ndarray],
    ) -> Dict[int, float]:
        """Calibrate per-tier escalation thresholds.

        Args:
            psuccess: {tier: array of calibrated P(success)}
            outcomes: {tier: array of binary outcomes (1=success, 0=fail)}

        Returns:
            {tier: conformal_threshold}
        """
        for tier in sorted(psuccess.keys()):
            p = psuccess[tier]
            y = outcomes[tier]

            # Nonconformity score of a failed example is its (over-confident)
            # P(success); we bound these with a conformal quantile.
            failed_mask = y == 0
            n_fail = int(failed_mask.sum())
            if n_fail == 0:
                # No observed failures at this tier: stay maximally
                # conservative (threshold 1.0 => always escalate), matching
                # the original behavior; we cannot certify coverage from
                # zero failure examples.
                self.thresholds[tier] = 1.0
                continue

            sorted_p = np.sort(p[failed_mask])

            # Split-conformal quantile with finite-sample correction:
            # take the ceil((1 - alpha) * (n_fail + 1))-th smallest failure
            # score (1-indexed), clamped to [1, n_fail].
            # BUGFIX: the previous formula ceil((1-alpha)*(n+1)/n) is ~1 for
            # any nontrivial n, so the threshold always collapsed to the
            # minimum failure score and alpha had no effect; it also mixed
            # the total count n with the failure-score array being indexed.
            rank = int(np.ceil((1 - self.alpha) * (n_fail + 1)))
            rank = max(1, min(rank, n_fail))

            # Escalate whenever P(success) falls below this threshold.
            self.thresholds[tier] = float(sorted_p[rank - 1])

        self.calibrated = True
        return self.thresholds

    def should_escalate(self, tier: int, psuccess: float) -> bool:
        """Decide whether to escalate from this tier.

        Returns True if P(success) is below the conformal threshold,
        meaning we can't guarantee success at this tier with 1-alpha coverage.
        """
        if not self.calibrated:
            return psuccess < 0.65  # fallback to heuristic
        # Unknown tiers fall back to the same heuristic cutoff.
        threshold = self.thresholds.get(tier, 0.65)
        return psuccess < threshold

    def coverage_check(
        self,
        psuccess: Dict[int, np.ndarray],
        outcomes: Dict[int, np.ndarray],
    ) -> Dict[int, Dict[str, float]]:
        """Verify conformal coverage on test data.

        Args:
            psuccess: {tier: array of calibrated P(success)}
            outcomes: {tier: array of binary outcomes (1=success, 0=fail)}

        Returns:
            {tier: {"violation_rate", "escalation_rate", "threshold",
                    "n_no_escalate", "n_violations", "covered"}}
        """
        results = {}
        for tier in sorted(psuccess.keys()):
            p = psuccess[tier]
            y = outcomes[tier]
            threshold = self.thresholds.get(tier, 0.65)

            # A violation is a failed example that was NOT escalated.
            no_escalate = p >= threshold
            failed_no_escalate = (y == 0) & no_escalate
            n_no_escalate = no_escalate.sum()

            # max(..., 1) guards the division when everything escalates.
            violation_rate = failed_no_escalate.sum() / max(n_no_escalate, 1)
            escalation_rate = 1 - no_escalate.mean()

            results[tier] = {
                "violation_rate": float(violation_rate),
                "escalation_rate": float(escalation_rate),
                "threshold": float(threshold),
                "n_no_escalate": int(n_no_escalate),
                "n_violations": int(failed_no_escalate.sum()),
                "covered": violation_rate <= self.alpha,
            }
        return results