"""
Evaluation entry point for extended metrics analysis.
Computes extended evaluation metrics, ROC curves, and statistical tests.
Usage:
    python -m brain_gcn.eval_cli --eval_checkpoint <path> [--eval_plot_roc] [--eval_plot_pr] [--eval_bootstrap_ci]
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from brain_gcn.main import build_datamodule
from brain_gcn.tasks import ClassificationTask
from brain_gcn.utils.evaluation import (
compute_metrics,
compute_roc_curve,
compute_pr_curve,
compute_confusion_matrix,
StatisticalTester,
)
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
def add_eval_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Add evaluation-specific arguments."""
parser.add_argument(
"--eval_checkpoint",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--eval_output_dir",
type=str,
default="results/evaluation",
help="Output directory for evaluation results.",
)
parser.add_argument(
"--eval_plot_roc",
action="store_true",
help="Save ROC curve plot.",
)
parser.add_argument(
"--eval_plot_pr",
action="store_true",
help="Save Precision-Recall curve plot.",
)
parser.add_argument(
"--eval_bootstrap_ci",
action="store_true",
help="Compute bootstrap confidence intervals.",
)
parser.add_argument(
"--eval_ci_n_bootstrap",
type=int,
default=1000,
help="Number of bootstrap samples.",
)
return parser
def load_checkpoint(
ckpt_path: str | Path,
device: str = "cpu",
) -> ClassificationTask:
"""Load trained model from checkpoint."""
return ClassificationTask.load_from_checkpoint(ckpt_path, map_location=device)
def get_predictions(
model: ClassificationTask,
dm,
device: str = "cpu",
) -> tuple[np.ndarray, np.ndarray]:
"""Get predictions on test set."""
model.eval()
model.to(device)
all_probs = []
all_labels = []
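    # Inference only: disable autograd to cut memory use and speed up the pass.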
with torch.no_grad():
for bold_windows, adj, labels in dm.test_dataloader():
logits = model(bold_windows.to(device), adj.to(device))
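            # Softmax over class logits; column 1 is the positive-class
            # probability (a binary task is assumed here).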
probs = torch.softmax(logits, dim=-1)[:, 1]
all_probs.append(probs.cpu().numpy())
all_labels.append(labels.numpy())
return np.concatenate(all_probs), np.concatenate(all_labels)
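# Illustrative sketch (not called by this CLI): if hard labels were needed,
# the probabilities returned above could be thresholded like this. The metric
# helpers imported from brain_gcn.utils.evaluation are assumed to handle
# thresholding internally, so this stays a standalone example.
def _binarize(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Map positive-class probabilities to {0, 1} predictions at `threshold`."""
    return (probs >= threshold).astype(np.int64)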
def plot_roc(
probs: np.ndarray,
labels: np.ndarray,
output_path: str | Path,
) -> None:
"""Plot and save ROC curve."""
roc_data = compute_roc_curve(probs, labels)
fpr = roc_data["fpr"]
tpr = roc_data["tpr"]
auc_score = roc_data["auc"]
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f"ROC (AUC={auc_score:.4f})", linewidth=2)
plt.plot([0, 1], [0, 1], "k--", label="Random", linewidth=1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=150)
plt.close()
log.info(f"ROC curve saved to {output_path}")
def plot_pr(
probs: np.ndarray,
labels: np.ndarray,
output_path: str | Path,
) -> None:
"""Plot and save Precision-Recall curve."""
pr_data = compute_pr_curve(probs, labels)
precision = pr_data["precision"]
recall = pr_data["recall"]
ap = pr_data["ap"]
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, label=f"PR (AP={ap:.4f})", linewidth=2)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=150)
plt.close()
log.info(f"PR curve saved to {output_path}")
def main():
from brain_gcn.main import build_parser
parser = build_parser()
parser = add_eval_arguments(parser)
args = parser.parse_args()
output_dir = Path(args.eval_output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Load model and data
log.info(f"Loading checkpoint: {args.eval_checkpoint}")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_checkpoint(args.eval_checkpoint, device=device)
log.info("Building datamodule")
dm = build_datamodule(args)
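    # No Trainer drives the datamodule here, so prepare/setup are called manually.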
dm.prepare_data()
dm.setup()
# Get predictions
log.info("Generating predictions on test set")
probs, labels = get_predictions(model, dm, device=device)
# Compute metrics
log.info("Computing metrics")
metrics = compute_metrics(probs, labels)
cm = compute_confusion_matrix(probs, labels)
# Print metrics
log.info("\n" + "=" * 70)
log.info("CLASSIFICATION METRICS")
log.info("=" * 70)
for key, value in metrics.to_dict().items():
log.info(f"{key:20s}: {value:.4f}")
log.info("\nConfusion Matrix:")
log.info(f" TP={cm.true_positives}, FP={cm.false_positives}")
log.info(f" FN={cm.false_negatives}, TN={cm.true_negatives}")
# Compute confidence intervals if requested
if args.eval_bootstrap_ci:
log.info(f"\nComputing {args.eval_ci_n_bootstrap} bootstrap samples")
ci_auc = StatisticalTester.bootstrap_ci(
            lambda p, y: compute_metrics(p, y).auc,
probs,
labels,
n_bootstrap=args.eval_ci_n_bootstrap,
)
log.info(f"AUC 95% CI: [{ci_auc[0]:.4f}, {ci_auc[2]:.4f}]")
# Save results
results = {
"metrics": metrics.to_dict(),
"confusion_matrix": cm.to_dict(),
}
results_file = output_dir / "metrics.json"
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
log.info(f"\nResults saved to {results_file}")
# Plot ROC and PR curves if requested
if args.eval_plot_roc:
roc_path = output_dir / "roc_curve.png"
plot_roc(probs, labels, roc_path)
if args.eval_plot_pr:
pr_path = output_dir / "pr_curve.png"
plot_pr(probs, labels, pr_path)
if __name__ == "__main__":
main()
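# Example invocation (paths are placeholders):
#   python -m brain_gcn.eval_cli \
#       --eval_checkpoint checkpoints/best.ckpt \
#       --eval_output_dir results/evaluation \
#       --eval_plot_roc --eval_plot_pr \
#       --eval_bootstrap_ci --eval_ci_n_bootstrap 1000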