# BrainConnect-ASD/brain_gcn/utils/evaluation.py
"""
Extended evaluation metrics and analysis tools.
Provides:
- Per-class metrics (sensitivity, specificity, precision, F1)
- ROC/AUC analysis
- Confusion matrix
- Calibration curves
- Statistical significance testing
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from sklearn.metrics import (
    auc,
    average_precision_score,
confusion_matrix,
roc_curve,
precision_recall_curve,
roc_auc_score,
matthews_corrcoef,
cohen_kappa_score,
)
from scipy import stats
@dataclass
class BinaryClassificationMetrics:
"""Container for binary classification metrics."""
accuracy: float
sensitivity: float # ASD recall (TP / (TP + FN))
specificity: float # TD recall (TN / (TN + FP))
precision: float # ASD precision (TP / (TP + FP))
f1: float # ASD F1
auc: float # ROC AUC
mcc: float # Matthews Correlation Coefficient
kappa: float # Cohen's Kappa
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"accuracy": self.accuracy,
"sensitivity": self.sensitivity,
"specificity": self.specificity,
"precision": self.precision,
"f1": self.f1,
"auc": self.auc,
"mcc": self.mcc,
"kappa": self.kappa,
}
@dataclass
class ConfusionMatrixMetrics:
"""Container for confusion matrix analysis."""
true_negatives: int
false_positives: int
false_negatives: int
true_positives: int
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"tn": self.true_negatives,
"fp": self.false_positives,
"fn": self.false_negatives,
"tp": self.true_positives,
}
def compute_metrics(
probs: torch.Tensor | np.ndarray,
labels: torch.Tensor | np.ndarray,
threshold: float = 0.5,
) -> BinaryClassificationMetrics:
"""Compute comprehensive binary classification metrics.
Parameters
----------
probs : (N,) or (N, 2) tensor/array
Predicted probabilities. If (N, 2), uses class 1; if (N,), assumes
probability of positive class.
labels : (N,) tensor/array
Ground truth binary labels (0 or 1).
threshold : float
Decision threshold for classification.
Returns
-------
BinaryClassificationMetrics
Computed metrics.
"""
# Convert to numpy
if isinstance(probs, torch.Tensor):
probs = probs.detach().cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.detach().cpu().numpy()
# Extract probability of positive class
if probs.ndim == 2:
probs_pos = probs[:, 1]
else:
probs_pos = probs
# Hard predictions
preds = (probs_pos >= threshold).astype(int)
# Basic metrics
accuracy = np.mean(preds == labels)
cm = confusion_matrix(labels, preds, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
f1 = 2 * precision * sensitivity / (precision + sensitivity) \
if (precision + sensitivity) > 0 else 0.0
    # AUC (roc_auc_score raises ValueError when only one class is present)
try:
auc_score = roc_auc_score(labels, probs_pos)
except ValueError:
auc_score = 0.0
# Matthews correlation coefficient
mcc = matthews_corrcoef(labels, preds)
# Cohen's Kappa
kappa = cohen_kappa_score(labels, preds)
return BinaryClassificationMetrics(
accuracy=accuracy,
sensitivity=sensitivity,
specificity=specificity,
precision=precision,
f1=f1,
auc=auc_score,
mcc=mcc,
kappa=kappa,
)
def compute_confusion_matrix(
probs: torch.Tensor | np.ndarray,
labels: torch.Tensor | np.ndarray,
threshold: float = 0.5,
) -> ConfusionMatrixMetrics:
"""Compute confusion matrix components.
Parameters
----------
probs : (N,) or (N, 2) tensor/array
Predicted probabilities.
labels : (N,) tensor/array
Ground truth labels.
threshold : float
Decision threshold.
Returns
-------
ConfusionMatrixMetrics
Confusion matrix components.
"""
if isinstance(probs, torch.Tensor):
probs = probs.detach().cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.detach().cpu().numpy()
if probs.ndim == 2:
probs_pos = probs[:, 1]
else:
probs_pos = probs
preds = (probs_pos >= threshold).astype(int)
cm = confusion_matrix(labels, preds, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
return ConfusionMatrixMetrics(
true_negatives=int(tn),
false_positives=int(fp),
false_negatives=int(fn),
true_positives=int(tp),
)
def compute_roc_curve(
probs: torch.Tensor | np.ndarray,
labels: torch.Tensor | np.ndarray,
) -> dict:
"""Compute ROC curve.
Parameters
----------
probs : (N,) or (N, 2) tensor/array
Predicted probabilities.
labels : (N,) tensor/array
Ground truth labels.
Returns
-------
dict
FPR, TPR, thresholds, and AUC.
"""
if isinstance(probs, torch.Tensor):
probs = probs.detach().cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.detach().cpu().numpy()
if probs.ndim == 2:
probs_pos = probs[:, 1]
else:
probs_pos = probs
fpr, tpr, thresholds = roc_curve(labels, probs_pos)
auc_score = auc(fpr, tpr)
return {
"fpr": fpr,
"tpr": tpr,
"thresholds": thresholds,
"auc": auc_score,
}
def compute_pr_curve(
probs: torch.Tensor | np.ndarray,
labels: torch.Tensor | np.ndarray,
) -> dict:
"""Compute Precision-Recall curve.
Parameters
----------
probs : (N,) or (N, 2) tensor/array
Predicted probabilities.
labels : (N,) tensor/array
Ground truth labels.
Returns
-------
dict
Precision, recall, thresholds, and AP.
"""
if isinstance(probs, torch.Tensor):
probs = probs.detach().cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.detach().cpu().numpy()
if probs.ndim == 2:
probs_pos = probs[:, 1]
else:
probs_pos = probs
precision, recall, thresholds = precision_recall_curve(labels, probs_pos)
    # Average precision is the standard PR-curve summary; trapezoidal AUC over
    # (recall, precision) is known to be overly optimistic.
    ap = average_precision_score(labels, probs_pos)
return {
"precision": precision,
"recall": recall,
"thresholds": thresholds,
"ap": ap,
}
class StatisticalTester:
"""Statistical significance testing utilities."""
@staticmethod
def bootstrap_ci(
metric_fn,
probs: np.ndarray,
labels: np.ndarray,
n_bootstrap: int = 1000,
ci: float = 0.95,
) -> tuple[float, float, float]:
"""Compute confidence interval via bootstrap.
Parameters
----------
metric_fn : callable
Function that computes metric from (probs, labels).
probs : (N,) array
Predicted probabilities.
labels : (N,) array
Ground truth labels.
n_bootstrap : int
Number of bootstrap samples.
ci : float
Confidence interval (0.95 = 95%).
Returns
-------
tuple[float, float, float]
(lower, estimate, upper) bounds.
"""
n = len(labels)
bootstrap_vals = []
        for _ in range(n_bootstrap):
            idx = np.random.choice(n, size=n, replace=True)
            try:
                bootstrap_vals.append(metric_fn(probs[idx], labels[idx]))
            except ValueError:
                # Skip resamples on which the metric is undefined
                # (e.g. AUC when the resample contains a single class).
                continue
bootstrap_vals = np.array(bootstrap_vals)
lower = np.percentile(bootstrap_vals, (1 - ci) / 2 * 100)
upper = np.percentile(bootstrap_vals, (1 + ci) / 2 * 100)
estimate = np.mean(bootstrap_vals)
return lower, estimate, upper
    @staticmethod
    def compare_auc(
        probs1: np.ndarray,
        probs2: np.ndarray,
        labels: np.ndarray,
        n_bootstrap: int = 1000,
    ) -> dict:
        """Compare the AUC of two models evaluated on the same samples.

        Uses a paired bootstrap test on the AUC difference as a practical
        approximation to DeLong's test.

        Parameters
        ----------
        probs1, probs2 : (N,) array
            Predicted probabilities from two models on the same subjects.
        labels : (N,) array
            Ground truth labels.
        n_bootstrap : int
            Number of bootstrap resamples.

        Returns
        -------
        dict
            AUC1, AUC2, their difference, z-statistic, and p-value.
        """
        auc1 = roc_auc_score(labels, probs1)
        auc2 = roc_auc_score(labels, probs2)
        observed_diff = auc1 - auc2
        # Paired bootstrap: resample subjects and recompute the AUC difference.
        n = len(labels)
        diffs = []
        for _ in range(n_bootstrap):
            idx = np.random.choice(n, size=n, replace=True)
            if len(np.unique(labels[idx])) < 2:
                continue  # AUC is undefined for single-class resamples
            diffs.append(
                roc_auc_score(labels[idx], probs1[idx])
                - roc_auc_score(labels[idx], probs2[idx])
            )
        diffs = np.array(diffs)
        se = diffs.std(ddof=1) if len(diffs) > 1 else 0.0
        z_stat = observed_diff / se if se > 0 else 0.0
        p_val = 2 * (1 - stats.norm.cdf(abs(z_stat)))
        return {
            "auc1": auc1,
            "auc2": auc2,
            "difference": observed_diff,
            "z_statistic": z_stat,
            "p_value": p_val,
            "significant": p_val < 0.05,
        }
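

if __name__ == "__main__":
    # Illustrative smoke test on synthetic data. The random predictions below
    # stand in for real model outputs, and names such as `rng`, `probs_a`, and
    # `probs_b` are local to this example rather than part of the module API.
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=200)
    probs_a = np.clip(labels * 0.6 + rng.normal(0.2, 0.25, size=200), 0.0, 1.0)
    probs_b = np.clip(probs_a + rng.normal(0.0, 0.15, size=200), 0.0, 1.0)

    print("Metrics:", compute_metrics(probs_a, labels).to_dict())
    print("Confusion matrix:", compute_confusion_matrix(probs_a, labels).to_dict())
    print("ROC AUC:", compute_roc_curve(probs_a, labels)["auc"])
    print("Average precision:", compute_pr_curve(probs_a, labels)["ap"])

    # bootstrap_ci passes (probs, labels) to metric_fn, so wrap roc_auc_score
    # to match its (y_true, y_score) argument order.
    lo, est, hi = StatisticalTester.bootstrap_ci(
        lambda p, y: roc_auc_score(y, p), probs_a, labels, n_bootstrap=200
    )
    print(f"Bootstrap AUC CI: {lo:.3f} <= {est:.3f} <= {hi:.3f}")

    print("AUC comparison:", StatisticalTester.compare_auc(probs_a, probs_b, labels))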