fepegar commited on
Commit
a345d50
·
1 Parent(s): e55f364

Add support for passing a config to the model and processor

Browse files
src/colipri/model/multimodal.py CHANGED
@@ -8,6 +8,8 @@ from accelerate import init_empty_weights
8
  from accelerate import load_checkpoint_and_dispatch
9
  from einops import rearrange
10
  from hydra.utils import instantiate
 
 
11
  from safetensors.torch import load_model
12
  from safetensors.torch import save_model
13
  from torch import nn
@@ -38,17 +40,42 @@ from .text import TextEncoder
38
  def get_model(
39
  checkpoint_path: TypePath | None = None,
40
  *,
 
41
  pretrained: bool = True,
42
  image_only: bool = False,
43
  **kwargs,
44
  ) -> Model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  if pretrained and checkpoint_path is None:
46
  checkpoint_path = download_weights()
47
 
48
- overrides = []
49
- for key, value in kwargs.items():
50
- overrides.append(f"{key}={value}")
51
- config = load_model_config(overrides=overrides)
 
 
 
 
 
 
 
52
 
53
  if image_only:
54
  config.text_encoder = None
 
8
  from accelerate import load_checkpoint_and_dispatch
9
  from einops import rearrange
10
  from hydra.utils import instantiate
11
+ from omegaconf import DictConfig
12
+ from omegaconf import OmegaConf
13
  from safetensors.torch import load_model
14
  from safetensors.torch import save_model
15
  from torch import nn
 
40
  def get_model(
41
  checkpoint_path: TypePath | None = None,
42
  *,
43
+ config: TypePath | DictConfig | None = None,
44
  pretrained: bool = True,
45
  image_only: bool = False,
46
  **kwargs,
47
  ) -> Model:
48
+ """Create a :class:`Model` instance.
49
+
50
+ Args:
51
+ checkpoint_path: Path to a ``.safetensors`` checkpoint. When ``None``
52
+ and ``pretrained`` is ``True``, the pretrained weights are
53
+ downloaded from the Hugging Face Hub.
54
+ config: Optional custom model config. Can be a path to a YAML file or
55
+ a ``DictConfig``. When ``None``, the built-in default config is
56
+ used.
57
+ pretrained: If ``True`` and ``checkpoint_path`` is ``None``, download
58
+ the pretrained weights.
59
+ image_only: If ``True``, the text encoder is disabled.
60
+ **kwargs: Hydra-style dot-list overrides applied on top of the config.
61
+
62
+ Returns:
63
+ A configured :class:`Model` instance.
64
+ """
65
  if pretrained and checkpoint_path is None:
66
  checkpoint_path = download_weights()
67
 
68
+ overrides = [f"{key}={value}" for key, value in kwargs.items()]
69
+
70
+ if config is None:
71
+ config = load_model_config(overrides=overrides or None)
72
+ else:
73
+ if isinstance(config, (str, Path)):
74
+ config = OmegaConf.load(config)
75
+ assert isinstance(config, DictConfig)
76
+ if overrides:
77
+ config = OmegaConf.merge(config, OmegaConf.from_dotlist(overrides))
78
+ assert isinstance(config, DictConfig)
79
 
80
  if image_only:
81
  config.text_encoder = None
src/colipri/processor.py CHANGED
@@ -1,8 +1,12 @@
1
  from __future__ import annotations
2
 
 
 
3
  import torch
4
  import torchio as tio
5
  from hydra.utils import instantiate
 
 
6
  from transformers import BertTokenizer
7
  from transformers.tokenization_utils_base import BatchEncoding
8
 
@@ -15,12 +19,50 @@ from .types import TypeStringOrStrings
15
  from .types import TypeTextAttentionMask
16
  from .types import TypeTextTokenIds
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def get_processor(*, image_only: bool = False, **kwargs) -> Processor:
20
- overrides = []
21
- for key, value in kwargs.items():
22
- overrides.append(f"{key}={value}")
23
- config = load_processor_config(overrides=overrides)
24
  if image_only:
25
  config.tokenizer = None
26
  return instantiate(config)
 
1
  from __future__ import annotations
2
 
3
+ from pathlib import Path
4
+
5
  import torch
6
  import torchio as tio
7
  from hydra.utils import instantiate
8
+ from omegaconf import DictConfig
9
+ from omegaconf import OmegaConf
10
  from transformers import BertTokenizer
11
  from transformers.tokenization_utils_base import BatchEncoding
12
 
 
19
  from .types import TypeTextAttentionMask
20
  from .types import TypeTextTokenIds
21
 
22
+ PROCESSOR_TARGET = "colipri.processor.Processor"
23
+
24
+
25
+ def get_processor(
26
+ *,
27
+ config: TypePath | DictConfig | None = None,
28
+ image_only: bool = False,
29
+ **kwargs,
30
+ ) -> Processor:
31
+ """Create a :class:`Processor` instance.
32
+
33
+ Args:
34
+ config: Optional custom config. Can be a path to a YAML file or a
35
+ ``DictConfig``. If the config is a transform-only config (e.g., as
36
+ exported by estereo), it is automatically wrapped into a full
37
+ processor config using the default tokenizer. When ``None``, the
38
+ built-in default config is used.
39
+ image_only: If ``True``, the tokenizer is disabled.
40
+ **kwargs: Hydra-style dot-list overrides applied on top of the config.
41
+
42
+ Returns:
43
+ A configured :class:`Processor` instance.
44
+ """
45
+ overrides = [f"{key}={value}" for key, value in kwargs.items()]
46
+
47
+ if config is None:
48
+ config = load_processor_config(overrides=overrides or None)
49
+ else:
50
+ if isinstance(config, (str, Path)):
51
+ config = OmegaConf.load(config)
52
+ assert isinstance(config, DictConfig)
53
+
54
+ # If the config is a transform (not a full Processor config), wrap it
55
+ # into a complete processor config using the default tokenizer.
56
+ is_full_processor_config = config.get("_target_") == PROCESSOR_TARGET
57
+ if not is_full_processor_config:
58
+ default_config = load_processor_config()
59
+ default_config.image_transform = config
60
+ config = default_config
61
+
62
+ if overrides:
63
+ config = OmegaConf.merge(config, OmegaConf.from_dotlist(overrides))
64
+ assert isinstance(config, DictConfig)
65
 
 
 
 
 
 
66
  if image_only:
67
  config.tokenizer = None
68
  return instantiate(config)
tests/test_custom_config.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for custom config support in get_model and get_processor."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+ import torchio as tio
9
+ from omegaconf import DictConfig
10
+ from omegaconf import OmegaConf
11
+
12
+ from colipri.checkpoint import load_model_config
13
+ from colipri.checkpoint import load_processor_config
14
+ from colipri.processor import Processor
15
+ from colipri.processor import get_processor
16
+
17
+ TRANSFORM_YAML_CONTENT = """\
18
+ _target_: torchio.transforms.augmentation.composition.Compose
19
+ transforms:
20
+ - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp
21
+ out_min: -500
22
+ out_max: 500
23
+ """
24
+
25
+
26
+ @pytest.fixture
27
+ def transform_yaml(tmp_path: Path) -> Path:
28
+ """A self-contained transform YAML with no interpolation variables."""
29
+ path = tmp_path / "transform.yaml"
30
+ path.write_text(TRANSFORM_YAML_CONTENT)
31
+ return path
32
+
33
+
34
+ @pytest.fixture
35
+ def resolved_processor_config() -> DictConfig:
36
+ """Default processor config with all interpolations resolved."""
37
+ config = load_processor_config()
38
+ resolved = OmegaConf.to_container(config, resolve=True)
39
+ assert isinstance(resolved, dict)
40
+ return OmegaConf.create(resolved)
41
+
42
+
43
+ @pytest.fixture
44
+ def resolved_model_config() -> DictConfig:
45
+ """Default model config with all interpolations resolved."""
46
+ config = load_model_config()
47
+ resolved = OmegaConf.to_container(config, resolve=True)
48
+ assert isinstance(resolved, dict)
49
+ return OmegaConf.create(resolved)
50
+
51
+
52
+ class TestGetProcessorCustomConfig:
53
+ def test_with_transform_yaml_path(self, transform_yaml: Path) -> None:
54
+ """Transform YAML path → Processor with custom transform."""
55
+ processor = get_processor(config=transform_yaml, image_only=True)
56
+ assert isinstance(processor, Processor)
57
+ transform = processor._image_transform
58
+ assert isinstance(transform, tio.Compose)
59
+ assert len(transform.transforms) == 1
60
+ clamp = transform.transforms[0]
61
+ assert clamp.out_min == -500
62
+ assert clamp.out_max == 500
63
+
64
+ def test_with_transform_dictconfig(self) -> None:
65
+ """Transform DictConfig object → Processor with custom transform."""
66
+ config = OmegaConf.create(TRANSFORM_YAML_CONTENT)
67
+ processor = get_processor(config=config, image_only=True)
68
+ assert isinstance(processor, Processor)
69
+ transform = processor._image_transform
70
+ assert isinstance(transform, tio.Compose)
71
+ assert transform.transforms[0].out_min == -500
72
+
73
+ def test_with_full_processor_config(
74
+ self,
75
+ resolved_processor_config: DictConfig,
76
+ ) -> None:
77
+ """Full processor DictConfig → Processor matching that config."""
78
+ # Remove all but the first transform to distinguish from default (5 transforms)
79
+ resolved_processor_config.image_transform.transforms = (
80
+ resolved_processor_config.image_transform.transforms[:1]
81
+ )
82
+ processor = get_processor(
83
+ config=resolved_processor_config,
84
+ image_only=True,
85
+ )
86
+ assert isinstance(processor, Processor)
87
+ assert isinstance(processor._image_transform, tio.Compose)
88
+ assert len(processor._image_transform.transforms) == 1
89
+
90
+ def test_transform_yaml_wraps_with_default_tokenizer(
91
+ self,
92
+ transform_yaml: Path,
93
+ ) -> None:
94
+ """Transform-only config is wrapped with default tokenizer config."""
95
+ # Without image_only, the transform YAML should be wrapped into a full
96
+ # processor config that includes the default tokenizer.
97
+ processor = get_processor(config=transform_yaml)
98
+ assert isinstance(processor, Processor)
99
+ # Should have both custom transform and default tokenizer
100
+ assert isinstance(processor._image_transform, tio.Compose)
101
+ assert processor._text_tokenizer is not None
102
+
103
+ def test_default_unchanged(self) -> None:
104
+ """get_processor() without config still works (backward compat)."""
105
+ processor = get_processor(image_only=True)
106
+ assert isinstance(processor, Processor)
107
+ assert isinstance(processor._image_transform, tio.Compose)
108
+
109
+
110
+ class TestGetModelCustomConfig:
111
+ def test_with_config(self, resolved_model_config: DictConfig) -> None:
112
+ """Pass a model DictConfig → Model with that config."""
113
+ from colipri.model.multimodal import Model
114
+ from colipri.model.multimodal import get_model
115
+
116
+ model = get_model(pretrained=False, config=resolved_model_config)
117
+ assert isinstance(model, Model)
118
+
119
+ def test_default_unchanged(self) -> None:
120
+ """get_model(pretrained=False) without config still works."""
121
+ from colipri.model.multimodal import Model
122
+ from colipri.model.multimodal import get_model
123
+
124
+ model = get_model(pretrained=False)
125
+ assert isinstance(model, Model)