"""Tests for custom config support in get_model and get_processor.""" from __future__ import annotations from pathlib import Path import pytest import torchio as tio from omegaconf import DictConfig from omegaconf import OmegaConf from colipri.checkpoint import load_model_config from colipri.checkpoint import load_processor_config from colipri.processor import Processor from colipri.processor import get_processor TRANSFORM_YAML_CONTENT = """\ _target_: torchio.transforms.augmentation.composition.Compose transforms: - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp out_min: -500 out_max: 500 """ @pytest.fixture def transform_yaml(tmp_path: Path) -> Path: """A self-contained transform YAML with no interpolation variables.""" path = tmp_path / "transform.yaml" path.write_text(TRANSFORM_YAML_CONTENT) return path @pytest.fixture def resolved_processor_config() -> DictConfig: """Default processor config with all interpolations resolved.""" config = load_processor_config() resolved = OmegaConf.to_container(config, resolve=True) assert isinstance(resolved, dict) return OmegaConf.create(resolved) @pytest.fixture def resolved_model_config() -> DictConfig: """Default model config with all interpolations resolved.""" config = load_model_config() resolved = OmegaConf.to_container(config, resolve=True) assert isinstance(resolved, dict) return OmegaConf.create(resolved) class TestGetProcessorCustomConfig: def test_with_transform_yaml_path(self, transform_yaml: Path) -> None: """Transform YAML path → Processor with custom transform.""" processor = get_processor(config=transform_yaml, image_only=True) assert isinstance(processor, Processor) transform = processor._image_transform assert isinstance(transform, tio.Compose) assert len(transform.transforms) == 1 clamp = transform.transforms[0] assert clamp.out_min == -500 assert clamp.out_max == 500 def test_with_transform_dictconfig(self) -> None: """Transform DictConfig object → Processor with custom transform.""" config = OmegaConf.create(TRANSFORM_YAML_CONTENT) processor = get_processor(config=config, image_only=True) assert isinstance(processor, Processor) transform = processor._image_transform assert isinstance(transform, tio.Compose) assert transform.transforms[0].out_min == -500 def test_with_full_processor_config( self, resolved_processor_config: DictConfig, ) -> None: """Full processor DictConfig → Processor matching that config.""" # Remove all but the first transform to distinguish from default (5 transforms) resolved_processor_config.image_transform.transforms = ( resolved_processor_config.image_transform.transforms[:1] ) processor = get_processor( config=resolved_processor_config, image_only=True, ) assert isinstance(processor, Processor) assert isinstance(processor._image_transform, tio.Compose) assert len(processor._image_transform.transforms) == 1 def test_transform_yaml_wraps_with_default_tokenizer( self, transform_yaml: Path, ) -> None: """Transform-only config is wrapped with default tokenizer config.""" # Without image_only, the transform YAML should be wrapped into a full # processor config that includes the default tokenizer. processor = get_processor(config=transform_yaml) assert isinstance(processor, Processor) # Should have both custom transform and default tokenizer assert isinstance(processor._image_transform, tio.Compose) assert processor._text_tokenizer is not None def test_default_unchanged(self) -> None: """get_processor() without config still works (backward compat).""" processor = get_processor(image_only=True) assert isinstance(processor, Processor) assert isinstance(processor._image_transform, tio.Compose) class TestGetModelCustomConfig: def test_with_config(self, resolved_model_config: DictConfig) -> None: """Pass a model DictConfig → Model with that config.""" from colipri.model.multimodal import Model from colipri.model.multimodal import get_model model = get_model(pretrained=False, config=resolved_model_config) assert isinstance(model, Model) def test_default_unchanged(self) -> None: """get_model(pretrained=False) without config still works.""" from colipri.model.multimodal import Model from colipri.model.multimodal import get_model model = get_model(pretrained=False) assert isinstance(model, Model)