File size: 6,997 Bytes
dc71cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
uncertainty/temperature_scaling.py
────────────────────────────────────
Temperature scaling for DeBERTa classifier logits.

After fine-tuning, DeBERTa's raw logits are often overconfident.
Temperature scaling is the simplest, most effective calibration method
(Guo et al., 2017 — "On Calibration of Modern Neural Networks").

Method:
    calibrated_prob = softmax(logits / T)
    T is learned by minimising NLL on a held-out calibration set.

For our use case, T is fit on the SWE-bench validation split:
    - Positive examples: (issue, gold_file) pairs → label=1
    - Negative examples: (issue, non-gold_file) pairs → label=0
    - T is a single scalar, so there is only one parameter to fit (minimal overfitting risk)

After calibration:
    - ECE (Expected Calibration Error) < 0.05 target
    - Reliability diagram should be close to diagonal

Integration:
    DeBERTa ranker outputs raw logits → TemperatureScaler.scale_score() → calibrated prob
    Calibrated prob replaces raw relevance_score in RankedFile
"""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Optional

import numpy as np

logger = logging.getLogger(__name__)


class TemperatureScaler:
    """
    Learns a single temperature parameter T by minimising NLL on validation data.

    T > 1: softer probabilities (reduces overconfidence)
    T < 1: harder probabilities (makes model more confident)
    T = 1: uncalibrated (no change)

    ``fit`` searches T inside [_T_MIN, _T_MAX]. The NLL is convex in the
    inverse temperature 1/T, so its derivative w.r.t. 1/T is monotone and the
    minimiser can be located reliably by bisection.
    """

    # Search bracket for T. T must stay positive; 0.01 keeps logits / T finite
    # for ordinary logit magnitudes, and 100 is a generous upper cap.
    _T_MIN = 0.01
    _T_MAX = 100.0

    def __init__(self, T: float = 1.0):
        self.T = T
        self._fitted = False

    def scale(self, logits: np.ndarray) -> np.ndarray:
        """
        Apply temperature scaling and return calibrated probabilities.

        Args:
            logits: shape (n, k) logits — k=2 for binary classification
        Returns:
            probs: shape (n, k) — calibrated probabilities (rows sum to 1)
        """
        return self._softmax(np.asarray(logits, dtype=float) / self.T)

    def scale_score(self, logit_positive: float) -> float:
        """Scale a single logit for the positive class → calibrated probability."""
        # Pair the positive logit with a 0.0 negative logit, so the result is
        # sigmoid(logit_positive / T).
        logits = np.array([[0.0, logit_positive]])
        probs = self.scale(logits)
        return float(probs[0, 1])

    def fit(
        self,
        logits: np.ndarray,   # shape (n, 2)
        labels: np.ndarray,   # shape (n,) — 0 or 1
        n_iter: int = 100,
        lr: float = 0.01,
        tol: float = 1e-6,
    ) -> dict:
        """
        Fit the temperature by minimising NLL on held-out data.

        With beta = 1/T, the NLL is convex in beta: its second derivative is the
        variance of the logits under the scaled softmax. Its first derivative is
        therefore monotone, and the minimiser is found by bisection on log(T)
        inside [_T_MIN, _T_MAX]. If the optimum lies outside that bracket (for
        example, perfectly separable data), T ends at the nearest bound.

        Args:
            logits: shape (n, k) logits; k=2 for the binary ranker.
            labels: shape (n,) integer class indices (0 or 1 for binary).
            n_iter: maximum number of bisection steps.
            lr: not used by the bisection search; accepted so existing callers
                that pass it keep working.
            tol: stop once the bracket on T is narrower than this.

        Returns:
            stats dict: {T_before, T_after, nll_before, nll_after,
                         ece_before, ece_after, fitted}

        Raises:
            ValueError: if ``logits`` is not a non-empty 2-D array or ``labels``
                does not have one entry per row of ``logits``.
        """
        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels).astype(int)
        if logits.ndim != 2 or logits.shape[0] == 0:
            raise ValueError(
                f"logits must be a non-empty (n, k) array, got shape {logits.shape}"
            )
        if labels.shape != (logits.shape[0],):
            raise ValueError(
                f"labels must have shape ({logits.shape[0]},), got {labels.shape}"
            )

        T_init = self.T
        nll_before = self._nll(logits, labels, T_init)
        ece_before = self._ece(logits, labels, T_init)

        # Bisect on u = log(T): this gives equal relative precision across the
        # whole bracket, from T=0.01 up to T=100.
        lo, hi = float(np.log(self._T_MIN)), float(np.log(self._T_MAX))
        for i in range(n_iter):
            mid = 0.5 * (lo + hi)
            grad = self._nll_grad_inv_temperature(logits, labels, float(np.exp(-mid)))
            if grad < 0:
                # NLL still falls as 1/T grows, so the optimum is at a smaller T.
                hi = mid
            else:
                lo = mid
            if np.exp(hi) - np.exp(lo) < tol:
                logger.debug("Temperature scaling converged at iteration %d", i)
                break
        T = float(np.exp(0.5 * (lo + hi)))

        self.T = T
        self._fitted = True

        nll_after = self._nll(logits, labels, T)
        ece_after = self._ece(logits, labels, T)

        logger.info(
            "Temperature scaling: T=%.3f→%.3f | NLL: %.4f→%.4f | ECE: %.4f→%.4f",
            T_init, T, nll_before, nll_after, ece_before, ece_after
        )
        return {
            "T_before": T_init, "T_after": T,
            "nll_before": nll_before, "nll_after": nll_after,
            "ece_before": ece_before, "ece_after": ece_after,
            "fitted": True,
        }

    def _nll(self, logits: np.ndarray, labels: np.ndarray, T: float) -> float:
        """Negative log-likelihood at temperature T, computed via a stable log-softmax."""
        scaled = logits / T
        shifted = scaled - scaled.max(axis=1, keepdims=True)
        # Exact log-probabilities: avoids log(p + eps), which caps the loss of
        # confidently-wrong samples and biases the objective.
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        correct_log_probs = log_probs[np.arange(len(labels)), labels.astype(int)]
        return float(-np.mean(correct_log_probs))

    def _nll_grad_inv_temperature(
        self, logits: np.ndarray, labels: np.ndarray, beta: float
    ) -> float:
        """
        Analytic derivative of the NLL w.r.t. the inverse temperature beta = 1/T.

        Per sample this is sum_k p_k * (z_k - z_y), with p = softmax(beta * z).
        Writing it relative to the true-class logit z_y avoids cancellation when
        p is nearly one-hot.
        """
        probs = self._softmax(logits * beta)
        true_logit = logits[np.arange(len(labels)), labels.astype(int)]
        per_sample = np.sum(probs * (logits - true_logit[:, None]), axis=1)
        return float(np.mean(per_sample))

    def _nll_gradient(self, logits: np.ndarray, labels: np.ndarray, T: float) -> float:
        """Central-difference numerical gradient of NLL w.r.t. T (diagnostic helper)."""
        eps = 1e-4
        return (self._nll(logits, labels, T + eps) - self._nll(logits, labels, T - eps)) / (2 * eps)

    def _ece(self, logits: np.ndarray, labels: np.ndarray, T: float, n_bins: int = 10) -> float:
        """
        Expected Calibration Error (ECE) at temperature T.

        Samples are binned by their top-class probability into ``n_bins``
        equal-width bins (lo, hi]. ECE is the sample-weighted mean of
        |accuracy - mean confidence| over the non-empty bins.
        """
        probs = self._softmax(logits / T)
        max_probs = probs.max(axis=1)
        predictions = probs.argmax(axis=1)
        correct = (predictions == labels.astype(int))

        bins = np.linspace(0, 1, n_bins + 1)
        ece = 0.0
        for i in range(n_bins):
            mask = (max_probs > bins[i]) & (max_probs <= bins[i + 1])
            if mask.sum() == 0:
                continue
            acc = correct[mask].mean()
            conf = max_probs[mask].mean()
            # mask.mean() is the fraction of all samples falling in this bin.
            ece += mask.mean() * abs(acc - conf)
        return float(ece)

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Row-wise softmax; subtracting the row max prevents overflow in exp."""
        shifted = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=1, keepdims=True)

    def save(self, path: Path) -> None:
        """Write T and the fitted flag to ``path`` as JSON, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps({"T": self.T, "fitted": self._fitted}))
        logger.info("Temperature scaler saved: T=%.4f → %s", self.T, path)

    @classmethod
    def load(cls, path: Path) -> "TemperatureScaler":
        """Rebuild a scaler from a JSON file written by ``save``."""
        data = json.loads(Path(path).read_text())
        ts = cls(T=data["T"])
        # Files without a "fitted" key are treated as not fitted.
        ts._fitted = data.get("fitted", False)
        logger.info("Temperature scaler loaded: T=%.4f from %s", ts.T, path)
        return ts


# ── ECE visualisation helper ──────────────────────────────────────────────────

def reliability_diagram_data(
    probs: np.ndarray,        # shape (n,) — predicted positive probabilities
    labels: np.ndarray,       # shape (n,) — true binary labels
    n_bins: int = 10,
) -> list[dict]:
    """
    Compute data for a reliability diagram.

    Probabilities are grouped into ``n_bins`` equal-width bins over [0, 1].
    Every bin is half-open [lo, hi) except the last, which is closed [lo, 1]
    so that a probability of exactly 1.0 is still counted. Empty bins are
    omitted from the result.

    Args:
        probs: predicted positive-class probabilities, shape (n,).
        labels: true binary labels, shape (n,).
        n_bins: number of equal-width bins.

    Returns list of bins. "confidence" is the bin midpoint, "accuracy" is the
    mean label of the samples in the bin, and "count" is the number of samples:
        [{"confidence": 0.15, "accuracy": 0.12, "count": 45}, ...]
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    edges = np.linspace(0, 1, n_bins + 1)
    result = []
    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        # Close the final bin on the right; otherwise probs == 1.0 fall
        # outside every bin and silently disappear from the diagram.
        below_hi = probs <= hi if i == n_bins - 1 else probs < hi
        mask = (probs >= lo) & below_hi
        count = int(mask.sum())
        if count == 0:
            continue
        result.append({
            "confidence": float((lo + hi) / 2),
            "accuracy": float(labels[mask].mean()),
            "count": count,
        })
    return result