| """ |
| Extended evaluation metrics and analysis tools. |
| |
| Provides: |
| - Per-class metrics (sensitivity, specificity, precision, F1) |
| - ROC/AUC analysis |
| - Confusion matrix |
| - Calibration curves |
| - Statistical significance testing |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import numpy as np |
| import torch |
| from sklearn.metrics import ( |
| auc, |
| confusion_matrix, |
| roc_curve, |
| precision_recall_curve, |
| roc_auc_score, |
| matthews_corrcoef, |
| cohen_kappa_score, |
| ) |
| from scipy import stats |
|
|
|
|
@dataclass
class BinaryClassificationMetrics:
    """Container for binary classification metrics."""

    accuracy: float
    sensitivity: float
    specificity: float
    precision: float
    f1: float
    auc: float
    mcc: float
    kappa: float

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "accuracy": self.accuracy,
            "sensitivity": self.sensitivity,
            "specificity": self.specificity,
            "precision": self.precision,
            "f1": self.f1,
            "auc": self.auc,
            "mcc": self.mcc,
            "kappa": self.kappa,
        }


@dataclass
class ConfusionMatrixMetrics:
    """Container for confusion matrix analysis."""

    true_negatives: int
    false_positives: int
    false_negatives: int
    true_positives: int

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "tn": self.true_negatives,
            "fp": self.false_positives,
            "fn": self.false_negatives,
            "tp": self.true_positives,
        }


def _prepare_inputs(
    probs: torch.Tensor | np.ndarray,
    labels: torch.Tensor | np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Convert inputs to numpy and reduce probs to the positive class."""
    if isinstance(probs, torch.Tensor):
        probs = probs.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    # (N, 2) per-class probabilities -> probability of class 1
    probs_pos = probs[:, 1] if probs.ndim == 2 else probs
    return np.asarray(probs_pos), np.asarray(labels)


def compute_metrics(
    probs: torch.Tensor | np.ndarray,
    labels: torch.Tensor | np.ndarray,
    threshold: float = 0.5,
) -> BinaryClassificationMetrics:
    """Compute comprehensive binary classification metrics.

    Parameters
    ----------
    probs : (N,) or (N, 2) tensor/array
        Predicted probabilities. If (N, 2), uses class 1; if (N,), assumes
        probability of the positive class.
    labels : (N,) tensor/array
        Ground truth binary labels (0 or 1).
    threshold : float
        Decision threshold for classification.

    Returns
    -------
    BinaryClassificationMetrics
        Computed metrics.
    """
    probs_pos, labels = _prepare_inputs(probs, labels)
    preds = (probs_pos >= threshold).astype(int)

    accuracy = float(np.mean(preds == labels))
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    f1 = (
        2 * precision * sensitivity / (precision + sensitivity)
        if (precision + sensitivity) > 0
        else 0.0
    )

    try:
        auc_score = roc_auc_score(labels, probs_pos)
    except ValueError:
        # AUC is undefined when only one class is present; NaN is less
        # misleading than 0.0 (0.5 would mean chance performance).
        auc_score = float("nan")

    mcc = matthews_corrcoef(labels, preds)
    kappa = cohen_kappa_score(labels, preds)

    return BinaryClassificationMetrics(
        accuracy=accuracy,
        sensitivity=sensitivity,
        specificity=specificity,
        precision=precision,
        f1=f1,
        auc=auc_score,
        mcc=mcc,
        kappa=kappa,
    )


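# Illustrative usage (values are made up, not taken from a real model):
#
#     probs = torch.tensor([0.9, 0.2, 0.7, 0.4])
#     labels = torch.tensor([1, 0, 1, 1])
#     metrics = compute_metrics(probs, labels, threshold=0.5)
#     print(metrics.to_dict())

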
def compute_confusion_matrix(
    probs: torch.Tensor | np.ndarray,
    labels: torch.Tensor | np.ndarray,
    threshold: float = 0.5,
) -> ConfusionMatrixMetrics:
    """Compute confusion matrix components.

    Parameters
    ----------
    probs : (N,) or (N, 2) tensor/array
        Predicted probabilities.
    labels : (N,) tensor/array
        Ground truth labels.
    threshold : float
        Decision threshold.

    Returns
    -------
    ConfusionMatrixMetrics
        Confusion matrix components.
    """
    probs_pos, labels = _prepare_inputs(probs, labels)

    preds = (probs_pos >= threshold).astype(int)
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    return ConfusionMatrixMetrics(
        true_negatives=int(tn),
        false_positives=int(fp),
        false_negatives=int(fn),
        true_positives=int(tp),
    )


def compute_roc_curve(
    probs: torch.Tensor | np.ndarray,
    labels: torch.Tensor | np.ndarray,
) -> dict:
    """Compute ROC curve.

    Parameters
    ----------
    probs : (N,) or (N, 2) tensor/array
        Predicted probabilities.
    labels : (N,) tensor/array
        Ground truth labels.

    Returns
    -------
    dict
        FPR, TPR, thresholds, and AUC.
    """
    probs_pos, labels = _prepare_inputs(probs, labels)

    fpr, tpr, thresholds = roc_curve(labels, probs_pos)
    auc_score = auc(fpr, tpr)

    return {
        "fpr": fpr,
        "tpr": tpr,
        "thresholds": thresholds,
        "auc": auc_score,
    }


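# A common way to pick an operating threshold from the ROC output (a sketch,
# not part of this module's API): maximize Youden's J = TPR - FPR.
#
#     roc = compute_roc_curve(probs, labels)
#     j = roc["tpr"] - roc["fpr"]
#     best_threshold = roc["thresholds"][int(np.argmax(j))]

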
def compute_pr_curve(
    probs: torch.Tensor | np.ndarray,
    labels: torch.Tensor | np.ndarray,
) -> dict:
    """Compute precision-recall curve.

    Parameters
    ----------
    probs : (N,) or (N, 2) tensor/array
        Predicted probabilities.
    labels : (N,) tensor/array
        Ground truth labels.

    Returns
    -------
    dict
        Precision, recall, thresholds, and average precision (AP).
    """
    probs_pos, labels = _prepare_inputs(probs, labels)

    precision, recall, thresholds = precision_recall_curve(labels, probs_pos)
    # Use the step-wise average precision rather than the trapezoidal
    # auc(recall, precision), which can be optimistically biased.
    ap = average_precision_score(labels, probs_pos)

    return {
        "precision": precision,
        "recall": recall,
        "thresholds": thresholds,
        "ap": ap,
    }


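# Plotting sketch (assumes matplotlib is available; it is not otherwise a
# dependency of this module):
#
#     import matplotlib.pyplot as plt
#
#     roc = compute_roc_curve(probs, labels)
#     pr = compute_pr_curve(probs, labels)
#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
#     ax1.plot(roc["fpr"], roc["tpr"], label=f"AUC = {roc['auc']:.3f}")
#     ax2.plot(pr["recall"], pr["precision"], label=f"AP = {pr['ap']:.3f}")
#     ax1.legend()
#     ax2.legend()
#     plt.show()

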
class StatisticalTester:
    """Statistical significance testing utilities."""

    @staticmethod
    def bootstrap_ci(
        metric_fn,
        probs: np.ndarray,
        labels: np.ndarray,
        n_bootstrap: int = 1000,
        ci: float = 0.95,
    ) -> tuple[float, float, float]:
        """Compute a confidence interval via the percentile bootstrap.

        Parameters
        ----------
        metric_fn : callable
            Function that computes a metric from (probs, labels).
        probs : (N,) array
            Predicted probabilities.
        labels : (N,) array
            Ground truth labels.
        n_bootstrap : int
            Number of bootstrap resamples.
        ci : float
            Confidence level (0.95 = 95%).

        Returns
        -------
        tuple[float, float, float]
            (lower, estimate, upper), where the estimate is the metric on
            the full sample and the bounds are bootstrap percentiles.
        """
        n = len(labels)
        bootstrap_vals = []

        for _ in range(n_bootstrap):
            idx = np.random.choice(n, size=n, replace=True)
            try:
                bootstrap_vals.append(metric_fn(probs[idx], labels[idx]))
            except ValueError:
                # e.g. AUC on a resample that contains only one class
                continue

        bootstrap_vals = np.asarray(bootstrap_vals)
        lower = np.percentile(bootstrap_vals, (1 - ci) / 2 * 100)
        upper = np.percentile(bootstrap_vals, (1 + ci) / 2 * 100)
        estimate = float(metric_fn(probs, labels))

        return float(lower), estimate, float(upper)

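    # Illustrative usage (a sketch; note that metric_fn receives
    # (probs, labels), so sklearn metrics expecting (y_true, y_score)
    # need their arguments swapped):
    #
    #     lower, point, upper = StatisticalTester.bootstrap_ci(
    #         lambda p, y: roc_auc_score(y, p), probs, labels, n_bootstrap=2000
    #     )
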
    @staticmethod
    def compare_auc(
        probs1: np.ndarray,
        probs2: np.ndarray,
        labels: np.ndarray,
        n_bootstrap: int = 1000,
    ) -> dict:
        """Compare the AUCs of two models via a paired bootstrap test.

        Note: this is a bootstrap approximation rather than the analytic
        DeLong test; for reasonably large N the two agree closely.

        Parameters
        ----------
        probs1, probs2 : (N,) array
            Predicted probabilities from two models on the same samples.
        labels : (N,) array
            Ground truth labels.
        n_bootstrap : int
            Number of bootstrap resamples.

        Returns
        -------
        dict
            AUC1, AUC2, difference, z-statistic, p-value.
        """
        auc1 = roc_auc_score(labels, probs1)
        auc2 = roc_auc_score(labels, probs2)
        diff = auc1 - auc2

        # Resample the same indices for both models so the AUC differences
        # are paired (each is computed on an identical subset).
        n = len(labels)
        diffs = []
        for _ in range(n_bootstrap):
            idx = np.random.choice(n, size=n, replace=True)
            if labels[idx].min() == labels[idx].max():
                continue  # AUC undefined without both classes
            diffs.append(
                roc_auc_score(labels[idx], probs1[idx])
                - roc_auc_score(labels[idx], probs2[idx])
            )

        # Normal approximation: z = observed difference / bootstrap SE.
        se = np.std(diffs, ddof=1)
        z_stat = diff / se if se > 0 else 0.0
        p_val = 2 * stats.norm.sf(abs(z_stat))

        return {
            "auc1": auc1,
            "auc2": auc2,
            "difference": diff,
            "z_statistic": z_stat,
            "p_value": p_val,
            "significant": bool(p_val < 0.05),
        }


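if __name__ == "__main__":
    # Minimal smoke test on synthetic data (illustrative only, not part of
    # the module API): model 1 is constructed to be more informative than
    # model 2, so its AUC should come out higher.
    rng = np.random.default_rng(0)
    y = rng.integers(0, 2, size=500)
    p1 = np.clip(0.30 * y + rng.normal(0.35, 0.25, size=500), 0.0, 1.0)
    p2 = np.clip(0.15 * y + rng.normal(0.42, 0.30, size=500), 0.0, 1.0)

    print(compute_metrics(p1, y).to_dict())
    print(compute_confusion_matrix(p1, y).to_dict())
    print(StatisticalTester.compare_auc(p1, p2, y))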