# Uploaded by Yatsuiii via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 16d6869 (verified).
"""
Hyperparameter optimization using Optuna.
Provides automated search over model, training, and data hyperparameters.
Integrates with the existing training pipeline via argparse.
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from typing import Any
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.trial import Trial
from brain_gcn.main import train_from_args, validate_args
# Module-level logger used by the HPO helpers below (trial failures, missing
# metrics, and where the summary JSON was written).
log = logging.getLogger(__name__)
class HPOConfig:
    """Plain settings holder describing one hyperparameter search.

    Attributes
    ----------
    study_name : str
        Optuna study name; a persisted study with this name is reloaded.
    n_trials : int
        Number of trials handed to ``study.optimize``.
    timeout : int or None
        Wall-clock budget in seconds for ``study.optimize``; ``None`` means
        no time limit.
    direction : str
        Optuna study direction, ``"maximize"`` or ``"minimize"``.
    objective_metric : str
        Key looked up in the trainer's ``callback_metrics`` after training.
        NOTE(review): the default is a *test* metric; selecting
        hyperparameters on the test split leaks it into model selection.
        Consider a validation metric.
    storage : str or None
        Where the study is persisted; ``None`` keeps it in memory.
    seed : int
        Seed for the TPE sampler, for reproducible suggestions.
    """

    def __init__(
        self,
        study_name: str = "brain_gcn_hpo",
        n_trials: int = 20,
        timeout: int | None = None,
        direction: str = "maximize",
        objective_metric: str = "test_auc",
        storage: str | None = None,
        seed: int = 42,
    ):
        # Study identity and search budget.
        self.study_name, self.n_trials, self.timeout = study_name, n_trials, timeout
        # What is optimized, and in which direction.
        self.direction, self.objective_metric = direction, objective_metric
        # Persistence and reproducibility.
        self.storage, self.seed = storage, seed
class HPOSearchSpace:
    """Hyperparameter search space for Optuna trials.

    All sampling happens in :meth:`suggest_params`. The parameter names used
    there are the keys that later appear in ``study.best_params``.
    """

    @staticmethod
    def suggest_params(trial: Trial, base_args: argparse.Namespace) -> argparse.Namespace:
        """Suggest hyperparameters for a single trial.

        Parameters
        ----------
        trial : optuna.trial.Trial
            Current trial object.
        base_args : argparse.Namespace
            Base arguments; suggested values override these. It is not
            mutated.

        Returns
        -------
        argparse.Namespace
            A shallow copy of ``base_args`` with the sampled hyperparameters
            set on it.
        """
        # Shallow copy so the shared template is never modified between trials.
        args = argparse.Namespace(**vars(base_args))

        # Model architecture
        args.hidden_dim = trial.suggest_categorical("hidden_dim", [32, 64, 128, 256])
        args.dropout = trial.suggest_float("dropout", 0.0, 0.5, step=0.1)

        # Training. suggest_float(..., log=True) replaces the deprecated
        # suggest_loguniform and samples from the same log-uniform range.
        args.lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        args.weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        args.batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])

        # DropEdge regularization
        args.drop_edge_p = trial.suggest_float("drop_edge_p", 0.0, 0.3, step=0.1)

        # BOLD noise augmentation
        args.bold_noise_std = trial.suggest_float("bold_noise_std", 0.0, 0.05, step=0.01)

        # Cosine annealing schedule
        args.cosine_t0 = trial.suggest_categorical("cosine_t0", [30, 50, 100])
        args.cosine_t_mult = trial.suggest_categorical("cosine_t_mult", [1, 2, 3])
        args.cosine_eta_min = trial.suggest_float("cosine_eta_min", 1e-6, 1e-4, log=True)

        return args
def objective(
    trial: Trial,
    base_args: argparse.Namespace,
    hpo_config: HPOConfig,
) -> float:
    """Objective function for Optuna optimization.

    Samples hyperparameters for ``trial``, trains one model, and returns the
    value of ``hpo_config.objective_metric`` read from the trainer's
    ``callback_metrics``.

    Failures are not propagated to Optuna. This covers any exception raised
    while suggesting, validating or training, and a missing metric. In those
    cases the worst possible value for the study direction is returned:
    ``-inf`` when maximizing, ``+inf`` when minimizing. A failed trial can
    therefore never become the best trial.

    NOTE(review): with the default ``objective_metric="test_auc"`` the search
    selects hyperparameters on the test split. Consider a validation metric.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Current trial.
    base_args : argparse.Namespace
        Base arguments template.
    hpo_config : HPOConfig
        HPO configuration; ``direction`` and ``objective_metric`` are read.

    Returns
    -------
    float
        Objective value, or the direction-dependent worst value on failure.
    """
    # A failure must score worse than any real trial. An unconditional -inf
    # would be the *best* possible result of a "minimize" study.
    worst = float("inf") if hpo_config.direction == "minimize" else float("-inf")
    try:
        args = HPOSearchSpace.suggest_params(trial, base_args)
        validate_args(args)

        trainer, _, _ = train_from_args(args)

        metrics = trainer.callback_metrics
        metric_value = metrics.get(hpo_config.objective_metric)
        if metric_value is None:
            log.warning(
                "Metric %s not found. Available: %s",
                hpo_config.objective_metric,
                list(metrics.keys()),
            )
            return worst
        # float() accepts both single-element tensors (on any device) and
        # plain Python numbers.
        return float(metric_value)
    except Exception as e:
        # Deliberate best-effort: one bad configuration must not abort the
        # whole study. log.exception keeps the traceback for debugging.
        log.exception("Trial failed: %s", e)
        return worst
class HPOStudy:
    """Wrapper for an Optuna study with convenience methods.

    The study is created lazily. :meth:`optimize` calls :meth:`create_study`
    on first use, and the accessors raise ``RuntimeError`` until a study
    exists.
    """

    def __init__(self, config: HPOConfig):
        self.config = config
        self.study: optuna.Study | None = None

    @staticmethod
    def _storage_url(storage: str | Path | None) -> str | None:
        """Map the configured storage to an Optuna storage URL.

        A falsy value means an in-memory study, so ``None`` is returned.
        A value that already carries a URL scheme (contains ``"://"``, e.g.
        ``"sqlite:///a.db"`` or ``"postgresql://..."``) is passed through
        unchanged. Anything else is treated as a SQLite database file path.
        """
        if not storage:
            return None
        storage = str(storage)
        if "://" in storage:
            return storage
        return f"sqlite:///{storage}"

    def _require_study(self) -> optuna.Study:
        """Return the current study, raising ``RuntimeError`` if none exists."""
        if self.study is None:
            raise RuntimeError("Study not created. Call optimize() first.")
        return self.study

    def create_study(self) -> optuna.Study:
        """Create the Optuna study, or load it if it already exists in storage.

        Returns
        -------
        optuna.Study
            The study, which is also stored on ``self.study``.
        """
        sampler = TPESampler(seed=self.config.seed)
        # NOTE: objective() as written in this module never calls
        # trial.report(). The pruner therefore has no intermediate values to
        # act on and does not prune anything.
        pruner = MedianPruner()
        self.study = optuna.create_study(
            study_name=self.config.study_name,
            direction=self.config.direction,
            sampler=sampler,
            pruner=pruner,
            storage=self._storage_url(self.config.storage),
            # Resume a persisted study of the same name instead of erroring.
            load_if_exists=True,
        )
        return self.study

    def optimize(
        self,
        base_args: argparse.Namespace,
    ) -> optuna.Study:
        """Run hyperparameter optimization.

        Parameters
        ----------
        base_args : argparse.Namespace
            Base arguments template.

        Returns
        -------
        optuna.Study
            Completed study object.
        """
        study = self.study if self.study is not None else self.create_study()
        study.optimize(
            lambda trial: objective(trial, base_args, self.config),
            n_trials=self.config.n_trials,
            timeout=self.config.timeout,
            show_progress_bar=True,
        )
        return study

    def best_params(self) -> dict[str, Any]:
        """Get the best hyperparameters found.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        return self._require_study().best_params

    def best_value(self) -> float:
        """Get the best objective value.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        return self._require_study().best_value

    def save_summary(self, output_path: str | Path) -> None:
        """Save an HPO summary to ``output_path`` as JSON.

        Parent directories are created as needed. ``n_trials`` counts every
        trial recorded in the study, whatever its final state.

        Raises
        ------
        RuntimeError
            If no study has been created yet.
        """
        study = self._require_study()
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        summary = {
            "study_name": self.config.study_name,
            "n_trials": len(study.trials),
            "best_value": study.best_value,
            "best_params": study.best_params,
            "direction": self.config.direction,
            "objective_metric": self.config.objective_metric,
        }
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)
        log.info("HPO summary saved to %s", output_path)
def hpo_from_args(args: argparse.Namespace) -> HPOStudy:
    """Create an HPO study wrapper from parsed command-line arguments.

    Reads the ``--hpo_*`` options registered by ``add_hpo_arguments`` plus
    ``args.seed``. ``seed`` is presumably registered by the main training
    parser (not shown here) — TODO confirm.

    ``hpo_direction`` is optional on ``args`` and defaults to
    ``"maximize"``, so namespaces built without it keep working.
    """
    hpo_config = HPOConfig(
        study_name=args.hpo_study_name,
        n_trials=args.hpo_n_trials,
        timeout=args.hpo_timeout,
        direction=getattr(args, "hpo_direction", "maximize"),
        objective_metric=args.hpo_objective,
        storage=args.hpo_storage,
        seed=args.seed,
    )
    return HPOStudy(hpo_config)


def add_hpo_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the HPO command-line options on ``parser``.

    Returns
    -------
    argparse.ArgumentParser
        The same ``parser`` instance, for chaining.
    """
    parser.add_argument(
        "--hpo_study_name",
        type=str,
        default="brain_gcn_hpo",
        help="Optuna study name.",
    )
    parser.add_argument(
        "--hpo_n_trials",
        type=int,
        default=20,
        help="Number of trials.",
    )
    parser.add_argument(
        "--hpo_timeout",
        type=int,
        default=None,
        help="Timeout in seconds.",
    )
    # NOTE(review): the default optimizes a test-split metric, which leaks the
    # test set into hyperparameter selection. Consider a validation metric.
    parser.add_argument(
        "--hpo_objective",
        type=str,
        default="test_auc",
        help="Metric to optimize (e.g., test_auc, test_acc).",
    )
    # Without this option every objective was maximized, which is wrong for
    # loss-like metrics.
    parser.add_argument(
        "--hpo_direction",
        type=str,
        choices=["maximize", "minimize"],
        default="maximize",
        help="Optimization direction for --hpo_objective (use 'minimize' for losses).",
    )
    parser.add_argument(
        "--hpo_storage",
        type=str,
        default="hpo_studies.db",
        help="SQLite storage path for persistent studies.",
    )
    return parser