""" Model registry for centralized model access and configuration. Simplifies model loading, configuration, and comparison. """ from __future__ import annotations import argparse import logging from typing import Any, Callable import torch from torch import nn log = logging.getLogger(__name__) # Import all models def _lazy_import_models(): """Lazy import to avoid circular dependencies.""" from brain_gcn.models.brain_gcn import BrainGCNClassifier, GraphOnlyClassifier, TemporalGRUClassifier, ConnectivityMLPClassifier from brain_gcn.models.advanced_models import ( GATClassifier, TransformerClassifier, CNN3DClassifier, GraphSAGEClassifier, ) return { # Original models 'graph_temporal': BrainGCNClassifier, 'gcn': GraphOnlyClassifier, 'gru': TemporalGRUClassifier, 'fc_mlp': ConnectivityMLPClassifier, # New models 'gat': GATClassifier, 'transformer': TransformerClassifier, 'cnn3d': CNN3DClassifier, 'graphsage': GraphSAGEClassifier, } class ModelConfig: """Configuration for model instantiation.""" def __init__( self, model_name: str, hidden_dim: int = 64, dropout: float = 0.5, num_heads: int = 4, num_layers: int = 2, readout: str = "attention", drop_edge_p: float = 0.1, **kwargs ): """ Parameters ---------- model_name : str Model identifier (must be in registry) hidden_dim : int Hidden dimension size dropout : float Dropout probability num_heads : int Number of attention heads (for GAT, Transformer) num_layers : int Number of layers (for Transformer) readout : str Readout method ("attention" or "mean") drop_edge_p : float Edge dropout probability (for GCN-based models) **kwargs Additional arguments """ self.model_name = model_name self.hidden_dim = hidden_dim self.dropout = dropout self.num_heads = num_heads self.num_layers = num_layers self.readout = readout self.drop_edge_p = drop_edge_p self.kwargs = kwargs def to_dict(self) -> dict[str, Any]: """Export configuration as dictionary.""" return { 'model_name': self.model_name, 'hidden_dim': self.hidden_dim, 'dropout': self.dropout, 'num_heads': self.num_heads, 'num_layers': self.num_layers, 'readout': self.readout, 'drop_edge_p': self.drop_edge_p, **self.kwargs } @classmethod def from_dict(cls, config_dict: dict) -> ModelConfig: """Load configuration from dictionary.""" config_dict = dict(config_dict) # don't mutate caller's dict model_name = config_dict.pop('model_name') return cls(model_name, **config_dict) class ModelRegistry: """Central registry for all available models.""" _models = None _configs = { 'graph_temporal': { 'display_name': 'Graph-Temporal GCN', 'description': 'Graph projection per window + GRU temporal encoder', 'requires': ['bold_windows', 'adj'], 'parameters': ['hidden_dim', 'dropout', 'readout', 'drop_edge_p'], }, 'gcn': { 'display_name': 'Graph-Only (GCN)', 'description': 'GCN baseline over ROI average signals', 'requires': ['bold_windows', 'adj'], 'parameters': ['hidden_dim', 'dropout', 'drop_edge_p'], }, 'gru': { 'display_name': 'Temporal-Only (GRU)', 'description': 'GRU baseline without graph structure', 'requires': ['bold_windows'], 'parameters': ['hidden_dim', 'dropout'], }, 'fc_mlp': { 'display_name': 'Connectivity MLP', 'description': 'Static FC adjacency MLP (requires --no-use_population_adj)', 'requires': ['adj'], 'parameters': ['hidden_dim', 'dropout'], }, 'gat': { 'display_name': 'Graph Attention Network', 'description': 'Multi-head graph attention mechanism', 'requires': ['bold_windows', 'adj'], 'parameters': ['hidden_dim', 'dropout', 'num_heads'], }, 'transformer': { 'display_name': 'Transformer Encoder', 'description': 'Transformer-based temporal encoder', 'requires': ['bold_windows'], 'parameters': ['hidden_dim', 'dropout', 'num_heads', 'num_layers'], }, 'cnn3d': { 'display_name': '3D-CNN', 'description': '3D convolution for spatiotemporal features', 'requires': ['bold_windows', 'fc_windows'], 'parameters': ['hidden_dim', 'dropout'], }, 'graphsage': { 'display_name': 'GraphSAGE', 'description': 'Sampling and aggregating graph convolution', 'requires': ['bold_windows', 'adj'], 'parameters': ['hidden_dim', 'dropout'], }, } @classmethod def get_models(cls) -> dict[str, type]: """Get all available models.""" if cls._models is None: cls._models = _lazy_import_models() return cls._models @classmethod def get_model_class(cls, model_name: str) -> type: """Get model class by name.""" models = cls.get_models() if model_name not in models: available = ', '.join(models.keys()) raise ValueError( f"Unknown model: {model_name}\nAvailable: {available}" ) return models[model_name] @classmethod def build_model( cls, config: ModelConfig, **override_kwargs ) -> nn.Module: """Build model instance from config. Parameters ---------- config : ModelConfig Model configuration **override_kwargs Override config parameters Returns ------- nn.Module Instantiated model """ model_class = cls.get_model_class(config.model_name) # Prepare arguments kwargs = { 'hidden_dim': config.hidden_dim, 'dropout': config.dropout, } # Add model-specific parameters if config.model_name in ['graph_temporal', 'gcn', 'graphsage']: kwargs['drop_edge_p'] = config.drop_edge_p if config.model_name == 'graph_temporal': kwargs['readout'] = config.readout if config.model_name in ['gat', 'transformer']: kwargs['num_heads'] = config.num_heads if config.model_name == 'transformer': kwargs['num_layers'] = config.num_layers # Apply overrides kwargs.update(override_kwargs) # Remove unsupported kwargs model_class_init = model_class.__init__ import inspect sig = inspect.signature(model_class_init) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} log.info(f"Building {config.model_name} with {valid_kwargs}") return model_class(**valid_kwargs) @classmethod def list_models(cls) -> list[str]: """List all available models.""" return list(cls._configs.keys()) @classmethod def get_model_info(cls, model_name: str) -> dict: """Get information about a model. Parameters ---------- model_name : str Model name Returns ------- dict Model metadata """ if model_name not in cls._configs: raise ValueError(f"Unknown model: {model_name}") return cls._configs[model_name] @classmethod def print_registry(cls) -> None: """Print all models and their descriptions.""" print("\n" + "=" * 80) print("AVAILABLE MODELS") print("=" * 80) for model_name in cls.list_models(): info = cls.get_model_info(model_name) print(f"\n{model_name:15} | {info['display_name']}") print(f"{'':15} | {info['description']}") print(f"{'':15} | Requires: {', '.join(info['requires'])}") def add_model_choice_argument(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Add model choice argument to parser. Parameters ---------- parser : argparse.ArgumentParser Argument parser Returns ------- argparse.ArgumentParser Updated parser """ available_models = ModelRegistry.list_models() parser.add_argument( '--model_name', type=str, choices=available_models, default='graph_temporal', help=f"Model architecture. Available: {', '.join(available_models)}", ) parser.add_argument( '--num_heads', type=int, default=4, help="Number of attention heads (for GAT, Transformer)", ) parser.add_argument( '--num_layers', type=int, default=2, help="Number of layers (for Transformer)", ) return parser if __name__ == "__main__": # Print all available models ModelRegistry.print_registry()