| """ |
| Experiment tracking and logging infrastructure. |
| |
| Tracks: |
| - Run metadata (config, environment, hardware) |
| - Training/validation/test metrics |
| - Checkpoint locations |
| - Results summaries |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| from dataclasses import asdict, dataclass, field |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any |
|
|
| import platform |
| import torch |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
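# On-disk layout produced by this module (paths relative to ``output_dir``,
# which defaults to ``experiments/``):
#
#     experiments/
#         summary.json              # written by ExperimentTracker.save_summary()
#         <run_id>/metadata.json    # written per run by ExperimentTracker._save_run()
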
@dataclass
class ExperimentMetadata:
    """Metadata for an experiment run."""

    run_id: str
    timestamp: str
    model_name: str
    dataset: str = "ABIDE"
    split_strategy: str = "site_holdout"
    notes: str = ""

    # Environment
    python_version: str = ""
    pytorch_version: str = ""
    device: str = ""
    num_gpus: int = 0

    # Training configuration
    hyperparameters: dict[str, Any] = field(default_factory=dict)

    # Results
    test_metrics: dict[str, float] = field(default_factory=dict)
    checkpoint_path: str = ""

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)

    def to_json(self) -> str:
        """Convert to JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_args(
        cls,
        run_id: str,
        args,
        notes: str = "",
    ) -> ExperimentMetadata:
        """Create metadata from training arguments.

        Parameters
        ----------
        run_id : str
            Unique run identifier.
        args : argparse.Namespace
            Training arguments.
        notes : str, optional
            Additional notes.

        Returns
        -------
        ExperimentMetadata
            Metadata object.
        """
        hyperparams = {
            "hidden_dim": getattr(args, "hidden_dim", None),
            "dropout": getattr(args, "dropout", None),
            "lr": getattr(args, "lr", None),
            "weight_decay": getattr(args, "weight_decay", None),
            "batch_size": getattr(args, "batch_size", None),
            "max_epochs": getattr(args, "max_epochs", None),
            "drop_edge_p": getattr(args, "drop_edge_p", None),
            "bold_noise_std": getattr(args, "bold_noise_std", None),
        }

        return cls(
            run_id=run_id,
            timestamp=datetime.now().isoformat(),
            model_name=getattr(args, "model_name", "unknown"),
            split_strategy=getattr(args, "split_strategy", "site_holdout"),
            notes=notes,
            python_version=platform.python_version(),
            pytorch_version=torch.__version__,
            device=str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
            num_gpus=torch.cuda.device_count(),
            hyperparameters=hyperparams,
        )

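# Example (illustrative sketch, not part of the training pipeline): ``from_args``
# reads every attribute via ``getattr`` with a fallback, so any object exposing
# the expected attribute names can stand in for an ``argparse.Namespace``. The
# run id and attribute values below are made-up placeholders.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(model_name="gcn", hidden_dim=64, lr=1e-3, dropout=0.2)
#     meta = ExperimentMetadata.from_args(run_id="demo_run", args=args)
#     print(meta.to_json())
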
class ExperimentTracker:
    """Tracks and logs experiment runs."""

    def __init__(self, output_dir: str | Path = "experiments"):
        """Initialize tracker.

        Parameters
        ----------
        output_dir : str or Path
            Directory to save experiment logs.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.metadata_list: list[ExperimentMetadata] = []

    def add_run(
        self,
        metadata: ExperimentMetadata,
    ) -> None:
        """Record a completed run.

        Parameters
        ----------
        metadata : ExperimentMetadata
            Run metadata.
        """
        self.metadata_list.append(metadata)
        self._save_run(metadata)

    def _save_run(self, metadata: ExperimentMetadata) -> None:
        """Save individual run to JSON."""
        run_dir = self.output_dir / metadata.run_id
        run_dir.mkdir(parents=True, exist_ok=True)

        meta_file = run_dir / "metadata.json"
        with open(meta_file, "w") as f:
            f.write(metadata.to_json())

        log.info(f"Experiment metadata saved to {meta_file}")

    def save_summary(self) -> None:
        """Save summary of all runs."""
        summary_file = self.output_dir / "summary.json"

        summary = {
            "total_runs": len(self.metadata_list),
            "runs": [m.to_dict() for m in self.metadata_list],
        }

        with open(summary_file, "w") as f:
            json.dump(summary, f, indent=2)

        log.info(f"Experiment summary saved to {summary_file}")

    def load_summary(self) -> dict:
        """Load summary from disk."""
        summary_file = self.output_dir / "summary.json"
        if not summary_file.exists():
            return {"total_runs": 0, "runs": []}

        with open(summary_file) as f:
            return json.load(f)

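# Example (illustrative sketch): persisting runs and re-reading the summary.
# Assumes a metadata object ``meta`` built as in the sketch above; the
# ``experiments/`` directory is created on first use.
#
#     tracker = ExperimentTracker(output_dir="experiments")
#     tracker.add_run(meta)            # writes experiments/demo_run/metadata.json
#     tracker.save_summary()           # writes experiments/summary.json
#     print(tracker.load_summary()["total_runs"])
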
class RunLogger:
    """Context manager for logging a single run."""

    def __init__(
        self,
        run_id: str,
        args,
        tracker: ExperimentTracker,
        notes: str = "",
    ):
        """Initialize run logger.

        Parameters
        ----------
        run_id : str
            Unique run ID.
        args : argparse.Namespace
            Training arguments.
        tracker : ExperimentTracker
            Parent tracker.
        notes : str, optional
            Notes about the run.
        """
        self.run_id = run_id
        self.args = args
        self.tracker = tracker
        self.notes = notes
        self.metadata = ExperimentMetadata.from_args(run_id, args, notes)

    def __enter__(self):
        """Enter the run context and return the run metadata."""
        log.info(f"Starting run: {self.run_id}")
        return self.metadata

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context and log results."""
        if exc_type is not None:
            log.error(f"Run {self.run_id} failed: {exc_val}")
            # Do not record the failed run; returning None lets the exception
            # propagate to the caller.
            return

        self.tracker.add_run(self.metadata)
        log.info(f"Run {self.run_id} completed and logged")

    def update_metrics(self, metrics: dict) -> None:
        """Update test metrics."""
        self.metadata.test_metrics.update(metrics)

    def set_checkpoint_path(self, path: str | Path) -> None:
        """Record checkpoint location."""
        self.metadata.checkpoint_path = str(path)
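

# Minimal end-to-end sketch guarded by ``__main__``; the model name, metric
# values, and checkpoint path are placeholders chosen purely for illustration.
if __name__ == "__main__":
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    demo_args = SimpleNamespace(model_name="gcn", hidden_dim=64, lr=1e-3)
    demo_tracker = ExperimentTracker(output_dir="experiments")
    demo_run = RunLogger("demo_run", demo_args, demo_tracker, notes="smoke test")

    with demo_run:
        # Metrics and the checkpoint path are attached to the run's metadata
        # before __exit__ hands it to the tracker.
        demo_run.update_metrics({"accuracy": 0.5, "auroc": 0.5})
        demo_run.set_checkpoint_path("checkpoints/demo_run.ckpt")

    demo_tracker.save_summary()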