| """Tests for custom config support in get_model and get_processor.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import pytest |
| import torchio as tio |
| from omegaconf import DictConfig |
| from omegaconf import OmegaConf |
|
|
| from colipri.checkpoint import load_model_config |
| from colipri.checkpoint import load_processor_config |
| from colipri.processor import Processor |
| from colipri.processor import get_processor |
|
|
# Minimal, self-contained torchio pipeline: a Compose wrapping one Clamp.
# The trailing backslash after the opening quotes suppresses a leading
# newline. The tests assert out_min == -500 and out_max == 500 on the first
# transform, so those keys must stay nested under the list item (i.e. be
# Clamp arguments, not top-level keys of the Compose mapping).
TRANSFORM_YAML_CONTENT = """\
_target_: torchio.transforms.augmentation.composition.Compose
transforms:
  - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp
    out_min: -500
    out_max: 500
"""
|
|
|
|
@pytest.fixture
def transform_yaml(tmp_path: Path) -> Path:
    """Write ``TRANSFORM_YAML_CONTENT`` to ``transform.yaml`` in a temp dir.

    The YAML contains no interpolation variables, so it can be loaded on
    its own. Returns the path of the written file.
    """
    yaml_file = tmp_path.joinpath("transform.yaml")
    yaml_file.write_text(TRANSFORM_YAML_CONTENT)
    return yaml_file
|
|
|
|
def _fully_resolved(config: DictConfig) -> DictConfig:
    """Return a new DictConfig equal to *config* with interpolations resolved.

    The round trip through a plain Python container builds a fresh config
    object. Tests may therefore mutate the result without touching the
    object the loader returned.

    Raises:
        AssertionError: if *config* does not convert to a mapping.
    """
    container = OmegaConf.to_container(config, resolve=True)
    # to_container yields a list for list-shaped configs; both default
    # configs used below are expected to be mappings.
    assert isinstance(container, dict), (
        f"expected a mapping config, got {type(container).__name__}"
    )
    return OmegaConf.create(container)


@pytest.fixture
def resolved_processor_config() -> DictConfig:
    """Default processor config with all interpolations resolved."""
    return _fully_resolved(load_processor_config())


@pytest.fixture
def resolved_model_config() -> DictConfig:
    """Default model config with all interpolations resolved."""
    return _fully_resolved(load_model_config())
|
|
|
|
class TestGetProcessorCustomConfig:
    """Tests for ``get_processor(config=...)`` with custom configs."""

    @staticmethod
    def _assert_clamp_only_processor(processor: Processor) -> None:
        """Assert *processor* applies exactly the Clamp(-500, 500) pipeline.

        That pipeline is the one described by ``TRANSFORM_YAML_CONTENT``.
        Sharing this check keeps every test that loads that YAML (from a
        path, from a DictConfig, or wrapped with a tokenizer) equally strict.
        """
        assert isinstance(processor, Processor)
        transform = processor._image_transform
        assert isinstance(transform, tio.Compose)
        assert len(transform.transforms) == 1
        clamp = transform.transforms[0]
        assert isinstance(clamp, tio.Clamp)
        assert clamp.out_min == -500
        assert clamp.out_max == 500

    def test_with_transform_yaml_path(self, transform_yaml: Path) -> None:
        """Transform YAML path → Processor with custom transform."""
        processor = get_processor(config=transform_yaml, image_only=True)
        self._assert_clamp_only_processor(processor)

    def test_with_transform_dictconfig(self) -> None:
        """Transform DictConfig object → Processor with custom transform."""
        config = OmegaConf.create(TRANSFORM_YAML_CONTENT)
        processor = get_processor(config=config, image_only=True)
        self._assert_clamp_only_processor(processor)

    def test_with_full_processor_config(
        self,
        resolved_processor_config: DictConfig,
    ) -> None:
        """Full processor DictConfig → Processor matching that config."""
        # Keep only the first image transform so the length check below
        # shows the passed config (not the defaults) was used.
        # NOTE(review): this only discriminates if the default pipeline has
        # more than one transform — confirm.
        resolved_processor_config.image_transform.transforms = (
            resolved_processor_config.image_transform.transforms[:1]
        )
        processor = get_processor(
            config=resolved_processor_config,
            image_only=True,
        )
        assert isinstance(processor, Processor)
        assert isinstance(processor._image_transform, tio.Compose)
        assert len(processor._image_transform.transforms) == 1

    def test_transform_yaml_wraps_with_default_tokenizer(
        self,
        transform_yaml: Path,
    ) -> None:
        """Transform-only config is wrapped with default tokenizer config."""
        processor = get_processor(config=transform_yaml)
        # Wrapping must keep the custom image transform, not only add a
        # tokenizer.
        self._assert_clamp_only_processor(processor)
        assert processor._text_tokenizer is not None

    def test_default_unchanged(self) -> None:
        """get_processor() without config still works (backward compat)."""
        processor = get_processor(image_only=True)
        assert isinstance(processor, Processor)
        assert isinstance(processor._image_transform, tio.Compose)
|
|
|
|
class TestGetModelCustomConfig:
    """``get_model`` builds an untrained ``Model`` with or without a config."""

    @staticmethod
    def _build_untrained(**kwargs: object) -> None:
        """Call ``get_model(pretrained=False, **kwargs)`` and check the type."""
        # Imported inside the test rather than at module level; presumably
        # so that collecting this module does not import the model stack.
        from colipri.model.multimodal import Model
        from colipri.model.multimodal import get_model

        built = get_model(pretrained=False, **kwargs)
        assert isinstance(built, Model)

    def test_with_config(self, resolved_model_config: DictConfig) -> None:
        """Pass a model DictConfig → Model with that config."""
        self._build_untrained(config=resolved_model_config)

    def test_default_unchanged(self) -> None:
        """get_model(pretrained=False) without config still works."""
        self._build_untrained()
|
|