"""Tests for config loading, including GlobalHydra conflict regression."""

from omegaconf import OmegaConf

from colipri.checkpoint import _load_config
from colipri.checkpoint import load_model_config
from colipri.checkpoint import load_processor_config


def test_load_config():
    """The full default config exposes global sizes and both component targets."""
    config = _load_config()
    assert config.input_size == 192
    assert config.spacing == 2
    assert config.model._target_ == "colipri.model.multimodal.Model"
    assert config.processor._target_ == "colipri.processor.Processor"


def test_load_model_config():
    """The model sub-config targets the multimodal model and has both encoders."""
    config = load_model_config()
    assert config._target_ == "colipri.model.multimodal.Model"
    assert "image_encoder" in config
    assert "text_encoder" in config


def test_load_processor_config():
    """The processor sub-config targets the processor and has transform + tokenizer."""
    config = load_processor_config()
    assert config._target_ == "colipri.processor.Processor"
    assert "image_transform" in config
    assert "tokenizer" in config


def test_overrides():
    """``key=value`` overrides passed to ``_load_config`` replace the defaults."""
    config = _load_config(overrides=["input_size=256", "spacing=3"])
    assert config.input_size == 256
    assert config.spacing == 3


def test_interpolation():
    """``${...}`` interpolations resolve to the referenced top-level values."""
    config = _load_config()
    resolved = OmegaConf.to_container(config, resolve=True)
    assert isinstance(resolved, dict)
    backbone = resolved["model"]["image_encoder"]["backbone"]
    assert backbone["embed_dim"] == 864  # ${image_embed_dim}
    assert backbone["input_shape"] == [192, 192, 192]  # ${input_size}


def test_config_loading_with_hydra_preinitialized():
    """Regression test: COLIPRI must work when GlobalHydra is already initialized.

    See https://huggingface.co/microsoft/colipri/discussions/3
    """
    from hydra import initialize
    from hydra.core.global_hydra import GlobalHydra

    # Simulate a caller (e.g. a Hydra-based training script) that already
    # owns the global Hydra instance when COLIPRI configs are loaded.
    with initialize(config_path=None, version_base=None):
        assert GlobalHydra.instance().is_initialized()

        model_cfg = load_model_config()
        proc_cfg = load_processor_config()
        assert model_cfg._target_ == "colipri.model.multimodal.Model"
        assert proc_cfg._target_ == "colipri.processor.Processor"

        # Overrides must also work
        config = _load_config(overrides=["input_size=256"])
        assert config.input_size == 256

        # The caller's Hydra instance must survive COLIPRI's config loading.
        # Checking only after the ``with`` block is not enough: hydra's
        # ``initialize.__exit__`` restores its own backup, which would mask
        # COLIPRI clearing the instance mid-context.
        assert GlobalHydra.instance().is_initialized()

    # GlobalHydra state must not be corrupted
    assert not GlobalHydra.instance().is_initialized()