# BrainConnect-ASD / brain_gcn /ablation.py
# Yatsuiii's picture
# Upload folder using huggingface_hub
# 16d6869 verified
"""
Ablation study framework.
Systematically removes or disables components to measure their contribution.
Examples:
- Disable DropEdge (set drop_edge_p=0)
- Disable BOLD augmentation (set bold_noise_std=0)
- Use GCN baseline vs full graph-temporal
- Population adj vs per-subject adjacency
"""
from __future__ import annotations
import argparse
import json
import logging
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import pytorch_lightning as pl
import torch
from brain_gcn.main import train_from_args, validate_args
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
@dataclass
class AblationComponent:
    """Describe one part of the setup that an ablation run can switch off.

    Attributes are documented inline below. ``modify_fn`` is the only
    behavioural piece: it receives the training arguments and returns the
    arguments that should be used for the ablated run.
    """

    # Short identifier of the component.
    name: str
    # Human-readable explanation of what the component is.
    description: str
    # Maps training arguments to the arguments with this component disabled.
    modify_fn: Callable[[argparse.Namespace], argparse.Namespace]
    # Defaults to True.
    enabled: bool = True
def _override(attr: str, value: object) -> Callable[[argparse.Namespace], argparse.Namespace]:
    """Build a modifier that sets ``args.<attr> = value`` in place and returns ``args``."""

    def apply(args: argparse.Namespace) -> argparse.Namespace:
        setattr(args, attr, value)
        return args

    return apply


class AblationStudy:
    """Framework for systematic ablation studies.

    ``run`` trains the unmodified configuration once (stored under the
    ``"baseline"`` key) and then once per selected component with that
    component's ``modify_fn`` applied to a copy of the base arguments.
    Metrics whose name starts with ``test_`` are collected from each run and
    the per-metric difference to the baseline is stored in ``self.deltas``.
    """

    # Predefined components, keyed by the name accepted in ``components``.
    COMPONENTS = {
        "drop_edge": AblationComponent(
            name="drop_edge",
            description="DropEdge regularization in graph convolution",
            modify_fn=_override("drop_edge_p", 0.0),
        ),
        "bold_noise": AblationComponent(
            name="bold_noise",
            description="BOLD signal augmentation during training",
            modify_fn=_override("bold_noise_std", 0.0),
        ),
        "graph": AblationComponent(
            name="graph",
            description="Graph structure (use GRU-only baseline)",
            modify_fn=_override("model_name", "gru"),
        ),
        "population_adj": AblationComponent(
            name="population_adj",
            description="Population adjacency matrix",
            modify_fn=_override("use_population_adj", False),
        ),
        "layer_norm": AblationComponent(
            name="layer_norm",
            description="Layer normalization in graph convolutions",
            modify_fn=_override("use_layer_norm", False),
        ),
    }

    def __init__(
        self,
        base_args: argparse.Namespace,
        components: list[str] | None = None,
        output_dir: str | Path | None = None,
    ):
        """Initialize ablation study.

        Parameters
        ----------
        base_args : argparse.Namespace
            Base training arguments (full model). A deep copy is stored.
        components : list[str], optional
            List of component names to ablate. If None, ablates all.
            Names not present in ``COMPONENTS`` are ignored with a warning.
        output_dir : str or Path, optional
            Directory to save results. Defaults to ``ablations``; it is
            created if it does not exist.
        """
        self.base_args = deepcopy(base_args)
        self.output_dir = Path(output_dir) if output_dir else Path("ablations")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if components is None:
            self.component_names = list(self.COMPONENTS)
        else:
            self.component_names = list(components)

        unknown = [name for name in self.component_names if name not in self.COMPONENTS]
        if unknown:
            # Previously these were dropped silently, which hides typos.
            log.warning("Ignoring unknown ablation components: %s", unknown)

        self.components = [
            self.COMPONENTS[name]
            for name in self.component_names
            if name in self.COMPONENTS
        ]
        self.results: dict[str, dict] = {}
        # Defined up front so save_results()/summary() work before run().
        self.deltas: dict[str, dict] = {}

    def run(self) -> dict[str, dict]:
        """Run full ablation study.

        Returns
        -------
        dict[str, dict]
            Results keyed by component name, plus ``"baseline"``. A run that
            raised during training maps to an empty dict; a component whose
            modified arguments fail ``validate_args`` has no entry.
        """
        log.info("Training full model (baseline)")
        # Train on a copy so the stored base arguments stay pristine for the
        # ablation runs even if training mutates its arguments.
        baseline_metrics = self._train_and_collect(
            deepcopy(self.base_args), "Baseline training"
        )
        self.results["baseline"] = baseline_metrics

        for component in self.components:
            log.info("Ablating: %s (%s)", component.name, component.description)
            ablated_args = component.modify_fn(deepcopy(self.base_args))
            try:
                validate_args(ablated_args)
            except ValueError as exc:
                log.warning("Ablation %s skipped: %s", component.name, exc)
                continue
            self.results[component.name] = self._train_and_collect(
                ablated_args, f"Ablation {component.name}"
            )

        self._compute_deltas(baseline_metrics)
        return self.results

    def _train_and_collect(self, args: argparse.Namespace, label: str) -> dict:
        """Seed, train with ``args`` and return the ``test_*`` metrics.

        Training failures are logged (with traceback) and yield ``{}`` so
        that one failing run does not abort the whole study.
        """
        # Same seed before every run so runs differ only by the ablated setting.
        pl.seed_everything(self.base_args.seed, workers=True)
        try:
            trainer, _, _ = train_from_args(args)
            return self._extract_test_metrics(trainer.callback_metrics)
        except Exception as exc:  # deliberate best-effort boundary
            log.exception("%s failed: %s", label, exc)
            return {}

    @staticmethod
    def _extract_test_metrics(callback_metrics) -> dict:
        """Keep metrics named ``test_*``, unwrapping tensors via ``.item()``."""
        return {
            key: value.item() if isinstance(value, torch.Tensor) else value
            for key, value in callback_metrics.items()
            if key.startswith("test_")
        }

    def _compute_deltas(self, baseline: dict) -> None:
        """Compute metric changes from baseline.

        For every non-baseline entry in ``self.results`` each metric maps to
        ``ablated - baseline``, or to ``None`` when either side is missing or
        not numeric. The ``"baseline"`` entry maps to an empty dict.
        """
        numeric = (int, float)
        deltas: dict[str, dict] = {}
        for component_name, ablated_metrics in self.results.items():
            if component_name == "baseline":
                deltas[component_name] = {}
                continue
            delta = {}
            for key, ablated_val in ablated_metrics.items():
                baseline_val = baseline.get(key)
                if isinstance(ablated_val, numeric) and isinstance(baseline_val, numeric):
                    delta[key] = ablated_val - baseline_val
                else:
                    delta[key] = None
            deltas[component_name] = delta
        self.deltas = deltas

    @staticmethod
    def _jsonable(value):
        """Map a metric value to a JSON-friendly one (None / float / str)."""
        if value is None:
            # Serialized as JSON null; float(None) would raise TypeError.
            return None
        if isinstance(value, (int, float)):
            return float(value)
        return str(value)

    def save_results(self) -> None:
        """Save results to JSON.

        Writes ``ablation_results.json`` into ``self.output_dir`` with the
        keys ``results``, ``deltas`` and ``components``.
        """
        results_file = self.output_dir / "ablation_results.json"

        serializable = {
            name: {k: self._jsonable(v) for k, v in metrics.items()}
            for name, metrics in self.results.items()
        }
        deltas_serializable = {
            name: {k: self._jsonable(v) for k, v in delta.items()}
            for name, delta in self.deltas.items()
        }
        output = {
            "results": serializable,
            "deltas": deltas_serializable,
            "components": [c.name for c in self.components],
        }

        # The directory may have been removed since __init__ created it.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)
        log.info("Ablation results saved to %s", results_file)

    def summary(self) -> str:
        """Pretty-print summary.

        Returns
        -------
        str
            Multi-line text with the baseline metrics followed by the float
            deltas of every requested component that has an entry in
            ``self.deltas``.
        """
        rule = "=" * 70
        lines = [rule, "ABLATION STUDY SUMMARY", rule]

        # Baseline
        if "baseline" in self.results:
            lines.append("\nBaseline (Full Model):")
            for key, val in sorted(self.results["baseline"].items()):
                if isinstance(val, float):
                    lines.append(f" {key}: {val:.4f}")
                else:
                    lines.append(f" {key}: {val}")

        # Ablations
        lines.append("\nAblation Impact (Δ from Baseline):")
        lines.append("-" * 70)
        for component_name in self.component_names:
            if component_name not in self.deltas:
                continue
            lines.append(f"\n{component_name}:")
            for key, val in sorted(self.deltas[component_name].items()):
                # Only float deltas are printed; None (no baseline) is omitted.
                if isinstance(val, float):
                    sign = "+" if val >= 0 else "-"
                    lines.append(f" {key}: {sign}{abs(val):.4f}")

        lines.append("\n" + rule)
        return "\n".join(lines)
def add_ablation_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the ablation command-line options on ``parser`` and return it.

    Two options are added: ``--ablation_components`` (one or more names
    restricted to the keys of ``AblationStudy.COMPONENTS``) and
    ``--ablation_output_dir`` (string, default ``results/ablations``).
    """
    component_choices = list(AblationStudy.COMPONENTS)
    option_table = (
        (
            "--ablation_components",
            {
                "nargs": "+",
                "choices": component_choices,
                "help": "Components to ablate. If not specified, ablates all.",
            },
        ),
        (
            "--ablation_output_dir",
            {
                "type": str,
                "default": "results/ablations",
                "help": "Output directory for ablation results.",
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser