BaseChange / models /__init__.py
Vedant Jigarbhai Mehta
Initial scaffolding for military base change detection project
b25c087
"""Model factory for change detection models.

Provides a unified interface, :func:`get_model`, to instantiate any
supported model by name.
"""
from typing import Any, Dict
import torch.nn as nn
from .changeformer import ChangeFormer
from .siamese_cnn import SiameseCNN
from .unet_pp import UNetPPChangeDetection
# Lookup table used by ``get_model``: maps each accepted model name to the
# class (not an instance) that implements it. Insertion order is also the
# order in which names are listed in ``get_model``'s "unknown model" error.
_MODEL_REGISTRY: Dict[str, type] = dict(
    siamese_cnn=SiameseCNN,
    unet_pp=UNetPPChangeDetection,
    changeformer=ChangeFormer,
)
def get_model(model_name: str, config: Dict[str, Any]) -> nn.Module:
    """Instantiate a change detection model by name.

    Looks ``model_name`` up in ``_MODEL_REGISTRY`` and constructs the
    registered class with the keyword arguments stored under
    ``config[model_name]``. If that section is missing, or present but
    empty (``None`` -- e.g. a YAML key declared with no body), the model
    is built with its constructor defaults.

    Args:
        model_name: One of 'siamese_cnn', 'unet_pp', 'changeformer'.
        config: Full config dict; the model-specific section
            ``config[model_name]`` is extracted internally and passed to the
            model constructor as keyword arguments.

    Returns:
        Initialized model (nn.Module).

    Raises:
        ValueError: If model_name is not recognized.
        TypeError: If the model-specific config section is neither ``None``
            nor a mapping.
    """
    from collections.abc import Mapping

    try:
        model_cls = _MODEL_REGISTRY[model_name]
    except KeyError:
        # Suppress the KeyError context: the ValueError already says
        # everything the caller needs, including the valid choices.
        raise ValueError(
            f"Unknown model '{model_name}'. Choose from: {list(_MODEL_REGISTRY.keys())}"
        ) from None

    model_config = config.get(model_name)
    if model_config is None:
        # Absent section, or a section declared with no body; previously a
        # present-but-None section reached ``**None`` and crashed.
        model_config = {}
    elif not isinstance(model_config, Mapping):
        raise TypeError(
            f"Config section '{model_name}' must be a mapping of constructor "
            f"keyword arguments, got {type(model_config).__name__}"
        )
    return model_cls(**model_config)