| """ |
| Ablation study framework. |
| |
| Systematically removes or disables components to measure their contribution. |
| |
| Examples: |
| - Disable DropEdge (set drop_edge_p=0) |
| - Disable BOLD augmentation (set bold_noise_std=0) |
| - Use GCN baseline vs full graph-temporal |
| - Population adj vs per-subject adjacency |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Callable |
|
|
| import pytorch_lightning as pl |
| import torch |
|
|
| from brain_gcn.main import train_from_args, validate_args |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class AblationComponent:
    """Single component to ablate.

    Pairs a human-readable description with a function that mutates a
    training-args namespace to disable the component.
    """

    # Registry key / identifier for this component.
    name: str
    # Human-readable explanation (used in log messages).
    description: str
    # Takes a (deep-copied) args namespace, disables exactly one feature on
    # it (e.g. sets a probability to 0 or a flag to False), and returns it.
    modify_fn: Callable[[argparse.Namespace], argparse.Namespace]
    # NOTE(review): not consulted anywhere in this file — presumably a
    # toggle for callers to pre-filter components; confirm before removing.
    enabled: bool = True
|
|
|
|
class AblationStudy:
    """Framework for systematic ablation studies.

    Trains the full model once as a baseline, then retrains once per
    registered component with that component disabled, and records the
    per-metric change (delta) relative to the baseline.
    """

    # Registry of ablatable components. Each ``modify_fn`` receives a copy
    # of the base args, disables exactly one feature, and returns the args.
    COMPONENTS = {
        "drop_edge": AblationComponent(
            name="drop_edge",
            description="DropEdge regularization in graph convolution",
            modify_fn=lambda args: (setattr(args, "drop_edge_p", 0.0), args)[1],
        ),
        "bold_noise": AblationComponent(
            name="bold_noise",
            description="BOLD signal augmentation during training",
            modify_fn=lambda args: (setattr(args, "bold_noise_std", 0.0), args)[1],
        ),
        "graph": AblationComponent(
            name="graph",
            description="Graph structure (use GRU-only baseline)",
            modify_fn=lambda args: (setattr(args, "model_name", "gru"), args)[1],
        ),
        "population_adj": AblationComponent(
            name="population_adj",
            description="Population adjacency matrix",
            modify_fn=lambda args: (setattr(args, "use_population_adj", False), args)[1],
        ),
        "layer_norm": AblationComponent(
            name="layer_norm",
            description="Layer normalization in graph convolutions",
            modify_fn=lambda args: (setattr(args, "use_layer_norm", False), args)[1],
        ),
    }

    def __init__(
        self,
        base_args: argparse.Namespace,
        components: list[str] | None = None,
        output_dir: str | Path | None = None,
    ):
        """Initialize ablation study.

        Parameters
        ----------
        base_args : argparse.Namespace
            Base training arguments (full model).
        components : list[str], optional
            List of component names to ablate. If None, ablates all.
        output_dir : str or Path, optional
            Directory to save results.
        """
        self.base_args = deepcopy(base_args)
        self.output_dir = Path(output_dir) if output_dir else Path("ablations")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if components is None:
            self.component_names = list(self.COMPONENTS.keys())
        else:
            self.component_names = components

        # Unknown component names are silently skipped rather than raising;
        # summary() later tolerates names with no recorded deltas.
        self.components = [
            self.COMPONENTS[name] for name in self.component_names
            if name in self.COMPONENTS
        ]

        self.results: dict[str, dict] = {}
        # Initialized empty so summary() / save_results() cannot raise
        # AttributeError when called before run().
        self.deltas: dict[str, dict] = {}

    @staticmethod
    def _test_metrics(trainer) -> dict:
        """Extract scalar ``test_*`` metrics from a fitted trainer.

        Tensor values are converted to plain Python scalars via ``.item()``
        so the result is JSON-serializable downstream.
        """
        return {
            key: value.item() if isinstance(value, torch.Tensor) else value
            for key, value in trainer.callback_metrics.items()
            if key.startswith("test_")
        }

    def run(self) -> dict[str, dict]:
        """Run full ablation study.

        Returns
        -------
        dict[str, dict]
            Results keyed by component name (plus a "baseline" entry).
        """
        # 1) Baseline: train the unmodified full model.
        log.info("Training full model (baseline)")
        pl.seed_everything(self.base_args.seed, workers=True)
        try:
            trainer, _, _ = train_from_args(self.base_args)
            baseline_metrics = self._test_metrics(trainer)
        except Exception as e:
            # Best-effort: a failed run is recorded as empty metrics so the
            # remaining ablations still execute.
            log.error(f"Baseline training failed: {e}")
            baseline_metrics = {}

        self.results["baseline"] = baseline_metrics

        # 2) One retraining per component, each from a fresh copy of the
        # base args and re-seeded identically for comparability.
        for component in self.components:
            log.info(f"Ablating: {component.name} ({component.description})")

            ablated_args = deepcopy(self.base_args)
            ablated_args = component.modify_fn(ablated_args)

            # Skip configurations the training entry point would reject.
            try:
                validate_args(ablated_args)
            except ValueError as e:
                log.warning(f"Ablation {component.name} skipped: {e}")
                continue

            pl.seed_everything(self.base_args.seed, workers=True)
            try:
                trainer, _, _ = train_from_args(ablated_args)
                ablated_metrics = self._test_metrics(trainer)
            except Exception as e:
                log.error(f"Ablation {component.name} failed: {e}")
                ablated_metrics = {}

            self.results[component.name] = ablated_metrics

        # 3) Per-metric deltas vs. baseline.
        self._compute_deltas(baseline_metrics)

        return self.results

    def _compute_deltas(self, baseline: dict) -> None:
        """Compute metric changes from baseline into ``self.deltas``.

        A delta is ``ablated - baseline``. Entries whose ablated value is
        non-numeric, or whose metric is missing from the baseline, are
        recorded as ``None``.
        """
        deltas = {}

        for component_name, ablated_metrics in self.results.items():
            if component_name == "baseline":
                # Baseline has no delta against itself.
                deltas[component_name] = {}
                continue

            delta = {}
            for key, ablated_val in ablated_metrics.items():
                baseline_val = baseline.get(key, None)
                if baseline_val is not None and isinstance(ablated_val, (int, float)):
                    delta[key] = ablated_val - baseline_val
                else:
                    delta[key] = None

            deltas[component_name] = delta

        self.deltas = deltas

    @staticmethod
    def _serialize_metrics(metrics: dict) -> dict:
        """Coerce metric values to JSON-safe types.

        Numbers become ``float``, ``None`` is preserved (serialized as JSON
        null), and anything else is stringified.
        """
        return {
            k: float(v) if isinstance(v, (int, float))
            else (None if v is None else str(v))
            for k, v in metrics.items()
        }

    def save_results(self) -> None:
        """Save results and deltas to ``ablation_results.json``."""
        results_file = self.output_dir / "ablation_results.json"

        output = {
            "results": {
                key: self._serialize_metrics(metrics)
                for key, metrics in self.results.items()
            },
            # Bug fix: the previous implementation called float(None) for
            # None deltas (TypeError); None now passes through unchanged.
            "deltas": {
                key: self._serialize_metrics(delta)
                for key, delta in self.deltas.items()
            },
            "components": [c.name for c in self.components],
        }

        with open(results_file, "w") as f:
            json.dump(output, f, indent=2)

        log.info(f"Ablation results saved to {results_file}")

    def summary(self) -> str:
        """Pretty-print summary."""
        lines = ["=" * 70]
        lines.append("ABLATION STUDY SUMMARY")
        lines.append("=" * 70)

        if "baseline" in self.results:
            lines.append("\nBaseline (Full Model):")
            for key, val in sorted(self.results["baseline"].items()):
                if isinstance(val, float):
                    lines.append(f" {key}: {val:.4f}")
                else:
                    lines.append(f" {key}: {val}")

        lines.append("\nAblation Impact (Δ from Baseline):")
        lines.append("-" * 70)

        for component_name in self.component_names:
            if component_name in self.deltas:
                delta = self.deltas[component_name]
                lines.append(f"\n{component_name}:")
                for key, val in sorted(delta.items()):
                    if isinstance(val, float):
                        # Explicit sign makes the direction of change obvious.
                        sign = "+" if val >= 0 else "-"
                        lines.append(f" {key}: {sign}{abs(val):.4f}")

        lines.append("\n" + "=" * 70)
        return "\n".join(lines)
|
|
|
|
def add_ablation_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register ablation-study CLI options on *parser* and return it."""
    # Declarative flag table: (option string, add_argument keyword options).
    option_specs = [
        (
            "--ablation_components",
            {
                "nargs": "+",
                "choices": list(AblationStudy.COMPONENTS.keys()),
                "help": "Components to ablate. If not specified, ablates all.",
            },
        ),
        (
            "--ablation_output_dir",
            {
                "type": str,
                "default": "results/ablations",
                "help": "Output directory for ablation results.",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|
|