| """ |
| Hyperparameter optimization using Optuna. |
| |
| Provides automated search over model, training, and data hyperparameters. |
| Integrates with the existing training pipeline via argparse. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
| from typing import Any |
|
|
| import optuna |
| from optuna.pruners import MedianPruner |
| from optuna.samplers import TPESampler |
| from optuna.trial import Trial |
|
|
| from brain_gcn.main import train_from_args, validate_args |
|
|
# Logger named after this module (``__name__``) so records can be filtered
# per module through the standard logging hierarchy.
log = logging.getLogger(name=__name__)
|
|
|
|
class HPOConfig:
    """Settings that control one Optuna hyperparameter search.

    Parameters
    ----------
    study_name : str
        Name of the Optuna study (also the key used when loading from storage).
    n_trials : int
        Number of trials to run.
    timeout : int or None
        Optional wall-clock limit in seconds; ``None`` means no limit.
    direction : str
        Optimization direction passed to Optuna (e.g. ``"maximize"``).
    objective_metric : str
        Name of the metric the objective reads after training.
    storage : str or None
        Optional storage location for a persistent study; ``None`` keeps the
        study in memory.
    seed : int
        Seed for the sampler.
    """

    def __init__(
        self,
        study_name: str = "brain_gcn_hpo",
        n_trials: int = 20,
        timeout: int | None = None,
        direction: str = "maximize",
        objective_metric: str = "test_auc",
        storage: str | None = None,
        seed: int = 42,
    ) -> None:
        # Identity of the study.
        self.study_name = study_name
        # Search budget: trial count and optional time limit.
        self.n_trials = n_trials
        self.timeout = timeout
        # What is optimized and in which direction.
        self.direction = direction
        self.objective_metric = objective_metric
        # Persistence and reproducibility.
        self.storage = storage
        self.seed = seed
|
|
|
|
class HPOSearchSpace:
    """Hyperparameter search space for Optuna trials.

    Suggested values are written onto a copy of the base arguments, so the
    caller's namespace is never modified.
    """

    @staticmethod
    def suggest_params(trial: Trial, base_args: argparse.Namespace) -> argparse.Namespace:
        """Suggest hyperparameters for a single trial.

        Parameters
        ----------
        trial : optuna.trial.Trial
            Current trial object; each ``suggest_*`` call records one choice.
        base_args : argparse.Namespace
            Base arguments; suggested values override these. Not modified.

        Returns
        -------
        argparse.Namespace
            Shallow copy of ``base_args`` with the suggested hyperparameters set.
        """
        # Shallow copy: attributes not searched below keep their base values.
        args = argparse.Namespace(**vars(base_args))

        # Model size and dropout.
        args.hidden_dim = trial.suggest_categorical("hidden_dim", [32, 64, 128, 256])
        args.dropout = trial.suggest_float("dropout", 0.0, 0.5, step=0.1)

        # Optimization. ``suggest_float(..., log=True)`` samples log-uniformly,
        # the same distribution as the deprecated ``suggest_loguniform``.
        args.lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        args.weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        args.batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])

        # Augmentation-style parameters; their exact meaning is defined by the
        # training code that consumes ``args`` (presumably edge-drop probability
        # and BOLD-signal noise level — confirm in train_from_args).
        args.drop_edge_p = trial.suggest_float("drop_edge_p", 0.0, 0.3, step=0.1)
        args.bold_noise_std = trial.suggest_float("bold_noise_std", 0.0, 0.05, step=0.01)

        # Learning-rate schedule settings (names suggest cosine annealing with
        # warm restarts — TODO confirm against the scheduler setup).
        args.cosine_t0 = trial.suggest_categorical("cosine_t0", [30, 50, 100])
        args.cosine_t_mult = trial.suggest_categorical("cosine_t_mult", [1, 2, 3])
        args.cosine_eta_min = trial.suggest_float("cosine_eta_min", 1e-6, 1e-4, log=True)

        return args
|
|
|
|
def _worst_value(direction: Any) -> float:
    """Return the least desirable objective value for an optimization direction.

    Failed trials report this value so they can never become the study's best
    trial: ``-inf`` when maximizing, ``+inf`` when minimizing.

    Parameters
    ----------
    direction : str or enum
        ``"maximize"`` / ``"minimize"`` (case-insensitive), or an enum member
        whose ``name`` is one of those (e.g. Optuna's ``StudyDirection``).
    """
    name = str(getattr(direction, "name", direction)).lower()
    return float("inf") if name == "minimize" else float("-inf")


def _metric_to_float(value: Any) -> float:
    """Convert a logged metric value to a Python ``float``.

    Objects exposing ``detach`` (tensors) are detached and moved to the CPU
    first; plain Python numbers are converted directly.
    """
    if hasattr(value, "detach"):
        value = value.detach().cpu()
    return float(value)


def objective(
    trial: Trial,
    base_args: argparse.Namespace,
    hpo_config: HPOConfig,
) -> float:
    """Objective function for Optuna optimization.

    Suggests hyperparameters, validates them, trains a model via
    ``train_from_args`` and reads ``hpo_config.objective_metric`` from the
    returned trainer's ``callback_metrics``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Current trial.
    base_args : argparse.Namespace
        Base arguments template (not modified).
    hpo_config : HPOConfig
        HPO configuration; ``objective_metric`` and ``direction`` are used.

    Returns
    -------
    float
        The objective metric value. If the trial fails (invalid arguments,
        training error, missing or unconvertible metric), the worst value for
        ``hpo_config.direction`` is returned (``-inf`` when maximizing,
        ``+inf`` when minimizing) so the failure can never be selected as best.

    Raises
    ------
    optuna.TrialPruned
        Re-raised unchanged so Optuna records the trial as pruned, not failed.

    Notes
    -----
    NOTE(review): the default objective metric is ``test_auc``; choosing
    hyperparameters on the test set leaks it into model selection. Consider
    optimizing a validation metric instead.
    """
    failure_value = _worst_value(hpo_config.direction)
    try:
        args = HPOSearchSpace.suggest_params(trial, base_args)
        validate_args(args)

        trainer, _, _ = train_from_args(args)

        metrics = trainer.callback_metrics
        metric_value = metrics.get(hpo_config.objective_metric)
        if metric_value is None:
            log.warning(
                "Metric %s not found. Available: %s",
                hpo_config.objective_metric,
                list(metrics.keys()),
            )
            return failure_value

        return _metric_to_float(metric_value)

    except optuna.TrialPruned:
        # Pruning is control flow, not a failure: let Optuna handle it.
        raise
    except Exception:
        # Per-trial boundary: keep the traceback, keep the study running.
        log.exception("Trial %s failed", getattr(trial, "number", "?"))
        return failure_value
|
|
|
|
class HPOStudy:
    """Wrapper for an Optuna study with convenience methods.

    The underlying study is created lazily — by :meth:`create_study` or on the
    first call to :meth:`optimize` — from the settings in an ``HPOConfig``.
    """

    def __init__(self, config: HPOConfig):
        self.config = config
        # Set by create_study(); None until a study has been created or loaded.
        self.study: optuna.Study | None = None

    @staticmethod
    def _storage_url(storage: str | None) -> str | None:
        """Translate the configured storage into an Optuna storage URL.

        ``None`` or an empty string yields ``None`` (in-memory study). A value
        that already carries a URL scheme (contains ``"://"``, e.g.
        ``postgresql://...`` or ``sqlite:///x.db``) is used as-is; anything
        else is treated as a SQLite file path.
        """
        if not storage:
            return None
        if "://" in storage:
            return storage
        return f"sqlite:///{storage}"

    def _require_study(self) -> optuna.Study:
        """Return the current study, raising ``RuntimeError`` if none exists."""
        if self.study is None:
            raise RuntimeError("Study not created. Call optimize() first.")
        return self.study

    def create_study(self) -> optuna.Study:
        """Create the Optuna study, or load it if it already exists in storage.

        NOTE(review): a MedianPruner is attached, but ``objective`` in this
        module never calls ``trial.report()`` nor passes the trial to training,
        so the pruner presumably never prunes — confirm before relying on it.
        """
        sampler = TPESampler(seed=self.config.seed)
        pruner = MedianPruner()

        self.study = optuna.create_study(
            study_name=self.config.study_name,
            direction=self.config.direction,
            sampler=sampler,
            pruner=pruner,
            storage=self._storage_url(self.config.storage),
            load_if_exists=True,
        )
        return self.study

    def optimize(
        self,
        base_args: argparse.Namespace,
    ) -> optuna.Study:
        """Run hyperparameter optimization.

        Parameters
        ----------
        base_args : argparse.Namespace
            Base arguments template.

        Returns
        -------
        optuna.Study
            Completed study object.
        """
        if self.study is None:
            self.create_study()
        study = self._require_study()

        study.optimize(
            lambda trial: objective(trial, base_args, self.config),
            n_trials=self.config.n_trials,
            timeout=self.config.timeout,
            show_progress_bar=True,
        )
        return study

    def best_params(self) -> dict[str, Any]:
        """Get the best hyperparameters found.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        return self._require_study().best_params

    def best_value(self) -> float:
        """Get the best objective value.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        return self._require_study().best_value

    def save_summary(self, output_path: str | Path) -> None:
        """Save an HPO summary (best value/params plus config) to JSON.

        Parent directories are created as needed. ``n_trials`` counts every
        trial in the study, not only completed ones. Non-finite best values are
        written as ``Infinity``/``-Infinity`` (``json`` default ``allow_nan``),
        which strict JSON parsers may reject.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        study = self._require_study()

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        summary = {
            "study_name": self.config.study_name,
            "n_trials": len(study.trials),
            "best_value": study.best_value,
            "best_params": study.best_params,
            "direction": self.config.direction,
            "objective_metric": self.config.objective_metric,
        }

        output_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        log.info("HPO summary saved to %s", output_path)
|
|
|
|
def hpo_from_args(args: argparse.Namespace) -> HPOStudy:
    """Create an HPO study wrapper from parsed command-line arguments.

    Reads the ``hpo_*`` options and the shared ``seed`` from ``args``.
    ``hpo_direction`` is optional: when the parser did not define it, the
    search maximizes (``HPOConfig``'s default direction).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments containing ``hpo_study_name``, ``hpo_n_trials``,
        ``hpo_timeout``, ``hpo_objective``, ``hpo_storage`` and ``seed``.

    Returns
    -------
    HPOStudy
        Wrapper whose study is created lazily.
    """
    hpo_config = HPOConfig(
        study_name=args.hpo_study_name,
        n_trials=args.hpo_n_trials,
        timeout=args.hpo_timeout,
        direction=getattr(args, "hpo_direction", "maximize"),
        objective_metric=args.hpo_objective,
        storage=args.hpo_storage,
        seed=args.seed,
    )
    return HPOStudy(hpo_config)
|
|
|
|
def add_hpo_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add HPO-specific arguments to ``parser``.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, for chaining.
    """
    parser.add_argument(
        "--hpo_study_name",
        type=str,
        default="brain_gcn_hpo",
        help="Optuna study name.",
    )
    parser.add_argument(
        "--hpo_n_trials",
        type=int,
        default=20,
        help="Number of trials.",
    )
    parser.add_argument(
        "--hpo_timeout",
        type=int,
        default=None,
        help="Timeout in seconds (default: no limit).",
    )
    parser.add_argument(
        "--hpo_objective",
        type=str,
        default="test_auc",
        help="Metric key in trainer.callback_metrics to optimize (e.g., test_auc, test_acc).",
    )
    parser.add_argument(
        "--hpo_direction",
        type=str,
        choices=["maximize", "minimize"],
        default="maximize",
        help="Whether to maximize or minimize --hpo_objective.",
    )
    parser.add_argument(
        "--hpo_storage",
        type=str,
        default="hpo_studies.db",
        help="SQLite storage path for persistent studies (empty string: in-memory study).",
    )
    return parser
|
|