Visual Document Retrieval
ColPali
Safetensors
English
modernvbert
vidore-experimental
vidore
paultltc commited on
Commit
67929cf
·
verified ·
1 Parent(s): 894b7ab

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +2 -0
  2. config.json +81 -22
  3. configuration_modernvbert.py +53 -171
  4. model.safetensors +2 -2
  5. modeling_modernvbert.py +439 -347
README.md CHANGED
@@ -14,6 +14,8 @@ tags:
14
  pipeline_tag: visual-document-retrieval
15
  ---
16
 
 
 
17
  # ModernVBERT
18
 
19
  ![bg](https://cdn-uploads.huggingface.co/production/uploads/6720a87e392e9cea0187fde6/nRa7iE30dqCUHGblnK8GQ.png)
 
14
  pipeline_tag: visual-document-retrieval
15
  ---
16
 
17
+ TESTING INTEGRATION TO TRANSFORMERS. PLEASE USE ModernVBERT/modernvbert.
18
+
19
  # ModernVBERT
20
 
21
  ![bg](https://cdn-uploads.huggingface.co/production/uploads/6720a87e392e9cea0187fde6/nRa7iE30dqCUHGblnK8GQ.png)
config.json CHANGED
@@ -1,39 +1,98 @@
1
  {
2
- "additional_vocab_size": 40,
3
- "architectures": [
4
- "ModernVBertForMaskedLM"
5
- ],
6
- "auto_map": {
7
- "AutoConfig": "configuration_modernvbert.ModernVBertConfig",
8
- "AutoModel": "modeling_modernvbert.ModernVBertModel",
9
- "AutoModelForMaskedLM": "modeling_modernvbert.ModernVBertForMaskedLM"
10
- },
11
- "hidden_size": 768,
12
  "image_token_id": 50407,
13
  "initializer_range": 0.02,
14
- "max_position_embeddings": 8192,
15
  "model_type": "modernvbert",
16
- "output_attentions": false,
17
  "pixel_shuffle_factor": 4,
18
- "qk_layer_norms": false,
19
  "text_config": {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  "hidden_size": 768,
 
 
21
  "intermediate_size": 1152,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  "mlp_bias": false,
 
 
 
 
 
23
  "num_hidden_layers": 22,
24
- "text_model_name": "jhu-clsp/ettin-encoder-150m",
25
- "vocab_size": 50368
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  },
27
- "tie_word_embeddings": false,
28
- "torch_dtype": "float32",
29
- "transformers_version": null,
30
  "vision_config": {
31
- "embed_dim": 768,
 
 
32
  "image_size": 512,
33
  "intermediate_size": 3072,
 
 
 
 
34
  "num_hidden_layers": 12,
35
- "patch_size": 16,
36
- "vision_model_name": "google/siglip2-base-patch16-512"
37
  },
38
- "vocab_size": 50368
39
  }
 
1
  {
 
 
 
 
 
 
 
 
 
 
2
  "image_token_id": 50407,
3
  "initializer_range": 0.02,
 
4
  "model_type": "modernvbert",
 
5
  "pixel_shuffle_factor": 4,
 
6
  "text_config": {
7
+ "_name_or_path": "ettin-encoder-150m",
8
+ "architectures": [
9
+ "ModernBertForMaskedLM"
10
+ ],
11
+ "attention_bias": false,
12
+ "attention_dropout": 0.0,
13
+ "causal_mask": false,
14
+ "classifier_activation": "gelu",
15
+ "classifier_bias": false,
16
+ "classifier_dropout": 0.0,
17
+ "classifier_pooling": "mean",
18
+ "cls_token_id": 50281,
19
+ "decoder_bias": true,
20
+ "deterministic_flash_attn": false,
21
+ "dtype": "float32",
22
+ "embedding_dropout": 0.0,
23
+ "global_attn_every_n_layers": 3,
24
+ "global_rope_theta": 160000.0,
25
+ "gradient_checkpointing": false,
26
+ "hidden_activation": "gelu",
27
  "hidden_size": 768,
28
+ "initializer_cutoff_factor": 2.0,
29
+ "initializer_range": 0.02,
30
  "intermediate_size": 1152,
31
+ "is_causal": false,
32
+ "layer_norm_eps": 1e-05,
33
+ "layer_types": [
34
+ "full_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "full_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "full_attention",
44
+ "sliding_attention",
45
+ "sliding_attention",
46
+ "full_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "full_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "full_attention"
56
+ ],
57
+ "local_attention": 128,
58
+ "local_rope_theta": 160000.0,
59
+ "max_position_embeddings": 7999,
60
  "mlp_bias": false,
61
+ "mlp_dropout": 0.0,
62
+ "model_type": "modernbert",
63
+ "norm_bias": false,
64
+ "norm_eps": 1e-05,
65
+ "num_attention_heads": 12,
66
  "num_hidden_layers": 22,
67
+ "position_embedding_type": "sans_pos",
68
+ "repad_logits_with_grad": false,
69
+ "rope_parameters": {
70
+ "full_attention": {
71
+ "rope_theta": 160000.0,
72
+ "rope_type": "default"
73
+ },
74
+ "sliding_attention": {
75
+ "rope_theta": 160000.0,
76
+ "rope_type": "default"
77
+ }
78
+ },
79
+ "sparse_pred_ignore_index": -100,
80
+ "sparse_prediction": false,
81
+ "vocab_size": 50408
82
  },
83
+ "transformers_version": "5.0.0.dev0",
 
 
84
  "vision_config": {
85
+ "attention_dropout": 0.0,
86
+ "hidden_act": "gelu_pytorch_tanh",
87
+ "hidden_size": 768,
88
  "image_size": 512,
89
  "intermediate_size": 3072,
90
+ "layer_norm_eps": 1e-06,
91
+ "model_type": "siglip_vision_model",
92
+ "num_attention_heads": 12,
93
+ "num_channels": 3,
94
  "num_hidden_layers": 12,
95
+ "patch_size": 16
 
96
  },
97
+ "tie_word_embeddings": false
98
  }
configuration_modernvbert.py CHANGED
@@ -4,157 +4,49 @@
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_modernvbert.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- import os
8
- from typing import Any, Union
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  from ...configuration_utils import PretrainedConfig
11
- from ..modernbert import ModernBertConfig
12
- from ..siglip import SiglipConfig
13
-
14
-
15
- class ModernVBertTextConfig(PretrainedConfig):
16
- r"""
17
- This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT
18
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
19
- defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
20
-
21
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
- documentation from [`PretrainedConfig`] for more information.
23
- """
24
-
25
- model_type = "modernvbert_text"
26
-
27
- def __init__(
28
- self,
29
- text_model_name="jhu-clsp/ettin-encoder-150m",
30
- hidden_size=768,
31
- num_hidden_layers=22,
32
- intermediate_size=1152,
33
- mlp_bias=False,
34
- vocab_size=50368,
35
- **kwargs,
36
- ):
37
- super().__init__(
38
- text_model_name=text_model_name,
39
- hidden_size=hidden_size,
40
- num_hidden_layers=num_hidden_layers,
41
- intermediate_size=intermediate_size,
42
- mlp_bias=mlp_bias,
43
- vocab_size=vocab_size,
44
- **kwargs,
45
- )
46
-
47
- @classmethod
48
- def from_base_model(
49
- cls,
50
- text_model_name,
51
- **kwargs,
52
- ):
53
- text_config = ModernBertConfig.from_pretrained(text_model_name)
54
- if hasattr(text_config, "text_config"):
55
- text_config = text_config.text_config
56
-
57
- return cls(
58
- text_model_name=text_model_name,
59
- hidden_size=text_config.hidden_size,
60
- num_hidden_layers=text_config.num_hidden_layers,
61
- intermediate_size=text_config.intermediate_size,
62
- mlp_bias=text_config.mlp_bias,
63
- vocab_size=text_config.vocab_size,
64
- **kwargs,
65
- )
66
-
67
-
68
- class ModernVBertVisionConfig(PretrainedConfig):
69
- r"""
70
- This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision encoder part of the ModernVBERT
71
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
72
- defaults will yield a similar configuration to that of the SigLIP.
73
-
74
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
75
- documentation from [`PretrainedConfig`] for more information.
76
- """
77
-
78
- model_type = "modernvbert_vision"
79
-
80
- attribute_map = {
81
- "hidden_size": "embed_dim",
82
- }
83
-
84
- def __init__(
85
- self,
86
- vision_model_name="google/siglip2-base-patch16-512",
87
- embed_dim=768,
88
- image_size=512,
89
- patch_size=16,
90
- num_hidden_layers=12,
91
- intermediate_size=3072,
92
- **kwargs,
93
- ):
94
- super().__init__(
95
- vision_model_name=vision_model_name,
96
- embed_dim=embed_dim,
97
- image_size=image_size,
98
- patch_size=patch_size,
99
- num_hidden_layers=num_hidden_layers,
100
- intermediate_size=intermediate_size,
101
- **kwargs,
102
- )
103
-
104
- @classmethod
105
- def from_base_model(
106
- cls,
107
- vision_model_name,
108
- **kwargs,
109
- ):
110
- vision_config = SiglipConfig.from_pretrained(vision_model_name)
111
- if hasattr(vision_config, "vision_config"):
112
- vision_config = vision_config.vision_config
113
-
114
- return cls(
115
- vision_model_name=vision_model_name,
116
- embed_dim=vision_config.hidden_size,
117
- image_size=vision_config.image_size,
118
- patch_size=vision_config.patch_size,
119
- num_hidden_layers=vision_config.num_hidden_layers,
120
- intermediate_size=vision_config.intermediate_size,
121
- **kwargs,
122
- )
123
 
124
 
125
  class ModernVBertConfig(PretrainedConfig):
126
  r"""
127
- This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
128
  instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
 
129
 
130
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
131
  See the documentation for [`PretrainedConfig`] for more details.
132
 
133
  Args:
134
- text_config (`PretrainedConfig` or `dict`, optional):
135
- Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
136
- default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
137
- vision_config (`PretrainedConfig` or `dict`, optional):
138
- Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
139
- default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
140
- image_token_id (`int`, optional, defaults to 128257):
141
- Token id reserved for image tokens inserted into the text stream.
142
- vocab_size (`int`, optional, defaults to 128256):
143
- Vocabulary size used by the text embeddings.
144
- tie_word_embeddings (`bool`, optional, defaults to `False`):
145
- Whether to tie input token embeddings and output token embeddings.
146
- pixel_shuffle_factor (`int`, optional, defaults to 4):
147
- Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
148
- additional_vocab_size (`int`, optional, defaults to 0):
149
- Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
150
- pad_token_id (`int`, optional):
151
- Padding token id.
152
- initializer_range (`float`, optional, defaults to 0.02):
153
- Stddev used for weight initialization.
154
 
155
  Example:
156
  ```python
157
- >>> from modernvbert import ModernVBertConfig
158
 
159
  >>> # Initializing configuration
160
  >>> configuration = ModernVBertConfig()
@@ -162,7 +54,7 @@ class ModernVBertConfig(PretrainedConfig):
162
  >>> # Initializing a model from the configuration (model class is implemented in
163
  >>> # `modernvbert.modeling_modernvbert`)
164
 
165
- >>> from modernvbert import ModernVBertModel
166
  >>> model = ModernVBertModel(configuration)
167
 
168
  >>> # Accessing the model configuration
@@ -170,56 +62,46 @@ class ModernVBertConfig(PretrainedConfig):
170
  ```"""
171
 
172
  model_type = "modernvbert"
173
- sub_configs: dict[str, Any] = {"text_config": ModernVBertTextConfig, "vision_config": ModernVBertVisionConfig}
174
 
175
  def __init__(
176
  self,
177
  text_config=None,
178
  vision_config=None,
179
- image_token_id: int = 50407,
180
- initializer_range=0.02,
181
- vocab_size=50368,
182
- pad_token_id=None,
183
- pixel_shuffle_factor=4,
184
- additional_vocab_size=0,
 
185
  **kwargs,
186
  ):
187
- super().__init__(**kwargs)
 
 
 
188
 
189
  if text_config is None:
190
- text_config = self.sub_configs["text_config"].from_base_model("jhu-clsp/ettin-encoder-150m")
191
  elif isinstance(text_config, dict):
192
- text_config = self.sub_configs["text_config"].from_dict(text_config)
193
  self.text_config = text_config
194
 
195
  if vision_config is None:
196
- vision_config = self.sub_configs["vision_config"].from_base_model("google/siglip2-base-patch16-512")
197
  elif isinstance(vision_config, dict):
198
- vision_config = self.sub_configs["vision_config"].from_dict(vision_config)
199
  self.vision_config = vision_config
200
 
201
- self.initializer_range = initializer_range
202
- self.image_token_id = image_token_id
203
- self.pad_token_id = pad_token_id
204
  self.pixel_shuffle_factor = pixel_shuffle_factor
205
- self.vocab_size = vocab_size
206
- self.additional_vocab_size = additional_vocab_size
207
- self.hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
208
-
209
- @classmethod
210
- def from_pretrained_models(
211
- cls,
212
- text_model_name: Union[str, os.PathLike],
213
- vision_model_name: Union[str, os.PathLike],
214
- **kwargs,
215
- ) -> "PretrainedConfig":
216
- text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
217
- vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
218
- return cls(
219
- text_config=text_model_config,
220
- vision_config=vision_model_config,
221
- **kwargs,
222
- )
223
 
224
 
225
- __all__ = ["ModernVBertConfig", "ModernVBertTextConfig", "ModernVBertVisionConfig"]
 
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_modernvbert.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from typing import Any, Literal
22
 
23
  from ...configuration_utils import PretrainedConfig
24
+ from ..auto import CONFIG_MAPPING, AutoConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  class ModernVBertConfig(PretrainedConfig):
28
  r"""
29
+ This is the configuration class to store the configuration of a [`ModernVBert`] model. It is used to
30
  instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
31
+ e.g. [ModernVBERT/modernvbert](https://huggingface.co/ModernVBERT/modernvbert).
32
 
33
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
34
  See the documentation for [`PretrainedConfig`] for more details.
35
 
36
  Args:
37
+ text_config (`AutoConfig`, *optional*): Configuration for the text encoder.
38
+ vision_config (`ModernVBertVisionConfig`, *optional*): Configuration for the vision encoder.
39
+ image_token_id (`int | None`, *optional*, defaults to 50407): The token id reserved for image tokens inserted into the text stream.
40
+ pixel_shuffle_factor (`int | None`, *optional*, defaults to 4): Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
41
+ initializer_range (`float | None`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
42
+ initializer_cutoff_factor (`float | None`, *optional*, defaults to 2.0): The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
43
+ classifier_pooling (`Literal["cls", "mean"]`, *optional*, defaults to `"cls"`): The pooling strategy to use for classification tasks.
44
+ classifier_dropout (`float | None`, *optional*, defaults to 0.0): The dropout probability for the classification head.
45
+ classifier_bias (`bool | None`, *optional*, defaults to `False`): Whether to add a bias term to the classification head.
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  Example:
48
  ```python
49
+ >>> from transformers import ModernVBertConfig
50
 
51
  >>> # Initializing configuration
52
  >>> configuration = ModernVBertConfig()
 
54
  >>> # Initializing a model from the configuration (model class is implemented in
55
  >>> # `modernvbert.modeling_modernvbert`)
56
 
57
+ >>> from transformers import ModernVBertModel
58
  >>> model = ModernVBertModel(configuration)
59
 
60
  >>> # Accessing the model configuration
 
62
  ```"""
63
 
64
  model_type = "modernvbert"
65
+ sub_configs: dict[str, Any] = {"text_config": AutoConfig, "vision_config": AutoConfig}
66
 
67
  def __init__(
68
  self,
69
  text_config=None,
70
  vision_config=None,
71
+ image_token_id: int | None = 50407,
72
+ pixel_shuffle_factor: int | None = 4,
73
+ initializer_range: float | None = 0.02,
74
+ initializer_cutoff_factor: float | None = 2.0,
75
+ classifier_pooling: Literal["cls", "mean"] = "cls",
76
+ classifier_dropout: float | None = 0.0,
77
+ classifier_bias: bool | None = False,
78
  **kwargs,
79
  ):
80
+ if classifier_pooling not in ["cls", "mean"]:
81
+ raise ValueError(
82
+ f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {classifier_pooling}.'
83
+ )
84
 
85
  if text_config is None:
86
+ text_config = CONFIG_MAPPING["modernbert"]()
87
  elif isinstance(text_config, dict):
88
+ text_config = CONFIG_MAPPING["modernbert"](**text_config)
89
  self.text_config = text_config
90
 
91
  if vision_config is None:
92
+ vision_config = CONFIG_MAPPING["siglip_vision_model"]()
93
  elif isinstance(vision_config, dict):
94
+ vision_config = CONFIG_MAPPING["siglip_vision_model"](**vision_config)
95
  self.vision_config = vision_config
96
 
 
 
 
97
  self.pixel_shuffle_factor = pixel_shuffle_factor
98
+ self.initializer_range = initializer_range
99
+ self.initializer_cutoff_factor = initializer_cutoff_factor
100
+ self.classifier_pooling = classifier_pooling
101
+ self.classifier_dropout = classifier_dropout
102
+ self.classifier_bias = classifier_bias
103
+
104
+ super().__init__(image_token_id=image_token_id, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
+ __all__ = ["ModernVBertConfig"]
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7258bef432f17b1010de6f93d9651d8298da83cea3430dac26b6caf43864162c
3
- size 1165468824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d38dafdb2bc949c08f0fd320fd515479e2e93f2b849dd177f89cc0362571de7
3
+ size 1165471416
modeling_modernvbert.py CHANGED
@@ -4,115 +4,48 @@
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_modernvbert.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from dataclasses import dataclass
8
- from typing import Optional, Union
9
 
10
  import torch
11
  import torch.nn as nn
12
- import torch.nn.functional as F
13
- from torch.nn import CrossEntropyLoss
14
-
15
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
16
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput
 
 
 
 
 
 
17
  from ...modeling_utils import PreTrainedModel
18
  from ...processing_utils import Unpack
19
- from ...utils import auto_docstring, can_return_tuple
20
- from ..modernbert import ModernBertConfig, ModernBertForMaskedLM, ModernBertModel
21
- from ..siglip import SiglipVisionConfig, SiglipVisionModel
22
  from .configuration_modernvbert import ModernVBertConfig
23
 
24
 
25
- class DecoupledEmbedding(nn.Embedding):
26
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
27
- """
28
- Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
29
- In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
30
- If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
31
- """
32
-
33
- def __init__(
34
- self,
35
- num_embeddings,
36
- num_additional_embeddings,
37
- embedding_dim,
38
- partially_freeze=False,
39
- device=None,
40
- dtype=None,
41
- padding_idx=None,
42
- **kwargs,
43
- ) -> None:
44
- """
45
- num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`.
46
- partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
47
-
48
- Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these.
49
- """
50
- if padding_idx is not None and padding_idx > num_embeddings:
51
- raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
52
-
53
- super().__init__(
54
- num_embeddings=num_embeddings,
55
- embedding_dim=embedding_dim,
56
- device=device,
57
- dtype=dtype,
58
- padding_idx=padding_idx,
59
- **kwargs,
60
- )
61
- self.num_embeddings = num_embeddings
62
- self.num_additional_embeddings = num_additional_embeddings
63
- self.partially_freeze = partially_freeze
64
-
65
- if partially_freeze:
66
- self.weight.requires_grad_(False)
67
-
68
- if self.num_additional_embeddings > 0:
69
- self.additional_embedding = nn.Embedding(
70
- num_embeddings=num_additional_embeddings,
71
- embedding_dim=embedding_dim,
72
- device=device,
73
- dtype=dtype,
74
- )
75
-
76
- def forward(self, input_ids):
77
- """
78
- we have 2 embeddings, with different indices - one pretrained self.weight and another
79
- self.additional_embedding.weight that is being trained.
80
-
81
- in order to make a lookup of the input ids, we:
82
- 1. find out the indices of the entries belonging to the 2nd embedding
83
- 2. extract those values while subtracting the size of the first embedding (num_embeddings),
84
- since the 2nd embedding starts from 0 and not num_embeddings
85
- 3. perform the 2nd embedding lookup
86
- 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
87
- 5. perform the 1st embedding lookup
88
- 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
89
-
90
- note: for the 1st embedding lookup we could have looked up only the low indices and not do
91
- the padding, but then we have to create a new tensor and populate it with 2 tensors that are
92
- spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
93
- complex case if it's any faster, given that seqlens are usually relatively short it's
94
- probably not faster or if faster not by much - but might be a good idea to measure.
95
-
96
- """
97
- if self.num_additional_embeddings == 0:
98
- return super().forward(input_ids)
99
-
100
- input_ids = input_ids.clone()
101
- additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
102
- input_ids_additional_vocab = input_ids[additional_vocab_indices]
103
- additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
104
-
105
- # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
106
- input_ids[additional_vocab_indices] = 0
107
- full_vector = F.embedding(input_ids, self.weight)
108
- full_vector[additional_vocab_indices] = additional_embeddings # overwrite the records with high indices
109
- return full_vector
110
-
111
-
112
  @dataclass
113
  class ModernVBertBaseModelOutput(BaseModelOutput):
114
  """
115
- Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding).
116
  Args:
117
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
118
  Sequence of hidden-states at the output of the last layer of the model.
@@ -134,15 +67,15 @@ class ModernVBertBaseModelOutput(BaseModelOutput):
134
  """
135
 
136
  last_hidden_state: torch.FloatTensor = None
137
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
138
- attentions: Optional[tuple[torch.FloatTensor]] = None
139
- image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
140
 
141
 
142
  @dataclass
143
  class ModernVBertMaskedLMOutput(MaskedLMOutput):
144
  """
145
- Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding).
146
  Args:
147
  loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
148
  Masked language modeling (MLM) loss.
@@ -163,22 +96,11 @@ class ModernVBertMaskedLMOutput(MaskedLMOutput):
163
  image_hidden_states of the model produced by the vision encoder
164
  """
165
 
166
- loss: Optional[torch.FloatTensor] = None
167
  logits: torch.FloatTensor = None
168
- hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
169
- attentions: Optional[tuple[torch.FloatTensor, ...]] = None
170
- image_hidden_states: Optional[torch.FloatTensor] = None
171
-
172
-
173
- class ModernVBertSimpleMLP(nn.Module):
174
- """A simple linear projection layer to project the vision hidden states to the text hidden states."""
175
-
176
- def __init__(self, input_size, output_size):
177
- super().__init__()
178
- self.proj = nn.Linear(input_size, output_size, bias=False)
179
-
180
- def forward(self, x):
181
- return self.proj(x)
182
 
183
 
184
  class ModernVBertConnector(nn.Module):
@@ -190,148 +112,186 @@ class ModernVBertConnector(nn.Module):
190
  def __init__(self, config):
191
  super().__init__()
192
  self.pixel_shuffle_factor = config.pixel_shuffle_factor
193
- self.modality_projection = ModernVBertSimpleMLP(
194
- input_size=config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
195
- output_size=config.text_config.hidden_size,
 
196
  )
197
 
198
- def pixel_shuffle(self, x, pixel_shuffle_factor):
199
- bsz, seq, embed_dim = x.size()
200
- height = width = int(seq**0.5)
201
- x = x.view(bsz, height, width, embed_dim)
202
- x = x.view(bsz, height, int(width / pixel_shuffle_factor), embed_dim * pixel_shuffle_factor)
203
- x = x.permute(0, 2, 1, 3)
204
- x = x.reshape(
205
- bsz,
 
 
206
  int(width / pixel_shuffle_factor),
207
  int(height / pixel_shuffle_factor),
208
  embed_dim * (pixel_shuffle_factor**2),
209
  )
210
- x = x.permute(0, 2, 1, 3)
211
- return x.reshape(bsz, int(seq / (pixel_shuffle_factor**2)), embed_dim * (pixel_shuffle_factor**2))
 
 
212
 
213
  def forward(self, image_hidden_states):
214
  image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
215
  return self.modality_projection(image_hidden_states)
216
 
217
 
 
218
  class ModernVBertPreTrainedModel(PreTrainedModel):
219
- config_class = ModernVBertConfig
220
  base_model_prefix = "model"
 
221
  supports_gradient_checkpointing = True
222
- _supports_flash_attn_2 = True
 
 
 
 
 
 
 
223
  _supports_sdpa = True
 
 
 
 
224
 
 
225
  def _init_weights(self, module):
226
- std = getattr(self.config, "initializer_range", 0.02)
227
- if isinstance(module, (nn.Linear, nn.Conv2d)):
228
- module.weight.data.normal_(mean=0.0, std=std)
229
- if module.bias is not None:
230
- module.bias.data.zero_()
231
- elif isinstance(module, nn.Embedding):
232
- module.weight.data.normal_(mean=0.0, std=std)
233
- if module.padding_idx is not None:
234
- module.weight.data[module.padding_idx].zero_()
235
-
 
236
 
237
- @auto_docstring
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  class ModernVBertModel(ModernVBertPreTrainedModel):
 
 
 
 
 
239
  def __init__(self, config: ModernVBertConfig):
240
  super().__init__(config)
 
 
 
241
 
242
  # init components
243
- self.vision_model = ModernVBertModel.init_vision_model(config)
244
  self.connector = ModernVBertConnector(config)
245
- self.text_model = ModernVBertModel.init_language_model(config)
246
-
247
- # set the correct dtype for vision and text models
248
- self.vision_model.to(self.dtype)
249
- self.text_model.to(self.dtype)
250
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
251
 
252
  self.image_seq_len = int(
253
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
254
  / (config.pixel_shuffle_factor**2)
255
  )
 
256
 
257
  self.post_init()
258
 
259
- @staticmethod
260
- def init_vision_model(config: ModernVBertConfig):
261
- vision_model_config = SiglipVisionConfig.from_pretrained(
262
- config.vision_config.vision_model_name,
263
- _attn_implementation=config._attn_implementation,
264
- )
265
- vision_model = SiglipVisionModel(vision_model_config).vision_model
266
- return vision_model
267
-
268
- @staticmethod
269
- def init_language_model(config: ModernVBertConfig):
270
- text_model_config = ModernBertConfig.from_pretrained(
271
- config.text_config.text_model_name,
272
- _attn_implementation=config._attn_implementation,
273
- )
274
- text_model = ModernBertModel(text_model_config)
275
- embed_layer = DecoupledEmbedding(
276
- num_embeddings=text_model_config.vocab_size,
277
- num_additional_embeddings=config.additional_vocab_size,
278
- embedding_dim=config.hidden_size,
279
- partially_freeze=getattr(config, "freeze_config", {"freeze_text_layers": False})["freeze_text_layers"],
280
- padding_idx=config.pad_token_id,
281
- )
282
- text_model.set_input_embeddings(embed_layer)
283
- return text_model
284
-
285
- # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
286
- def enable_input_require_grads(self):
287
- """
288
- Enables the gradients for the input embeddings.
289
 
290
- This is useful for lora when using gradient checkpointing.
291
- c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
292
 
293
- Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
 
 
294
  """
 
 
 
 
 
 
 
 
 
295
 
296
- def get_lowest_module(module):
297
- if len(list(module.children())) == 0:
298
- # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
299
- return module
300
- else:
301
- # Recursively call the function on each child module
302
- return get_lowest_module(list(module.children())[0])
303
-
304
- def make_inputs_require_grads(module, input, output):
305
- output.requires_grad_(True)
306
 
307
- self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
308
- self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
309
- make_inputs_require_grads
 
310
  )
 
311
 
312
- # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
313
- def disable_input_require_grads(self):
314
- self._text_require_grads_hook.remove()
315
- self._vision_require_grads_hook.remove()
 
 
316
 
317
- def get_input_embeddings(self):
318
- return self.text_model.get_input_embeddings()
319
 
320
- def set_input_embeddings(self, value):
321
- self.text_model.set_input_embeddings(value)
322
 
 
 
 
 
323
  def get_image_features(
324
- self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
325
- ):
326
- """
327
- Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
328
- Encodes images into continuous embeddings that can be forwarded to the language model.
329
-
330
- Args:
331
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
332
- The tensors corresponding to the input images.
333
- pixel_attention_mask (`torch.LongTensor`, *optional*):
334
- The attention mask indicating padded regions in the image.
335
  """
336
  batch_size, num_images, num_channels, height, width = pixel_values.shape
337
  pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
@@ -341,8 +301,8 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
341
  nb_values_per_image = pixel_values.shape[1:].numel()
342
  real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
343
 
344
- if not any(real_images_inds):
345
- real_images_inds[0] = True
346
 
347
  pixel_values = pixel_values[real_images_inds].contiguous()
348
  # Handle the vision attention mask
@@ -356,60 +316,24 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
356
  # Remove padding images from the mask
357
  pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
358
  pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
359
-
360
  patch_size = self.config.vision_config.patch_size
361
  patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
362
  patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
363
  patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
364
 
365
  # Get sequence from the vision encoder
366
- image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
367
- image_hidden_states = image_hidden_states.last_hidden_state
368
-
369
- return image_hidden_states
370
-
371
- def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states):
372
- """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
373
-
374
- This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
375
- The merging happens as follows:
376
- - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
377
- - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
378
- We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
379
- - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
380
- - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
381
- """
382
-
383
- _, patch_size, _ = image_hidden_states.shape
384
-
385
- if input_ids is None:
386
- image_mask = inputs_embeds == self.get_input_embeddings()(
387
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
388
- )
389
- image_mask = image_mask[..., 0] # slice off the hidden dim
390
- else:
391
- image_mask = input_ids == self.config.image_token_id
392
-
393
- # Assert that the input <image> tokens are valid (i.e. multiple of patch_size)
394
- num_image_tokens = image_mask.sum(dim=1)
395
- if not torch.all(num_image_tokens % patch_size == 0):
396
- raise ValueError("Number of <image> tokens not divisible by patch_size.")
397
-
398
- blocks_per_sample = num_image_tokens // patch_size
399
-
400
- offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
401
- block_offset = offsets[:-1]
402
- row_cum = image_mask.cumsum(dim=-1)
403
- chunk_idx = (row_cum - 1) // patch_size
404
- local_idx = (row_cum - 1) % patch_size
405
- block_idx = block_offset.unsqueeze(1) + chunk_idx
406
 
407
- image_embeds = torch.zeros_like(inputs_embeds)
408
- image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
 
409
 
410
- return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
411
 
412
- @can_return_tuple
413
  @auto_docstring(
414
  custom_intro="""
415
  Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
@@ -420,53 +344,38 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
420
  discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
421
  image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
422
  """,
423
- checkpoint="modernvbert/ModernVBert",
424
  )
425
  def forward(
426
  self,
427
  input_ids: torch.LongTensor = None,
428
- attention_mask: Optional[torch.Tensor] = None,
429
- position_ids: Optional[torch.LongTensor] = None,
430
- inputs_embeds: Optional[torch.FloatTensor] = None,
431
- pixel_values: Optional[torch.FloatTensor] = None,
432
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
433
- image_hidden_states: Optional[torch.FloatTensor] = None,
434
- output_attentions: Optional[bool] = None,
435
- output_hidden_states: Optional[bool] = None,
436
- return_dict: Optional[bool] = None,
437
- **kwargs: Unpack[FlashAttentionKwargs],
438
- ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
439
  r"""
440
  pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
441
  Mask to avoid performing attention on padding pixel indices.
442
  image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
443
  The hidden states of the image encoder after modality projection.
444
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
445
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
446
- config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
447
- ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
448
  """
449
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
450
- output_hidden_states = (
451
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
452
- )
453
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
454
 
455
  if inputs_embeds is None:
456
  inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
457
 
458
  # Images processing
459
  if pixel_values is not None:
460
- # Vision encoder pass
461
  image_hidden_states = self.get_image_features(
462
  pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
463
- )
464
- # Modality projection & resampling
465
- image_hidden_states = self.connector(image_hidden_states)
466
 
467
  # Merge image and text embeddings
468
  if image_hidden_states is not None:
469
- image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
470
  inputs_embeds = self.inputs_merger(
471
  input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
472
  )
@@ -476,9 +385,6 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
476
  inputs_embeds=inputs_embeds,
477
  attention_mask=attention_mask,
478
  position_ids=position_ids,
479
- output_attentions=output_attentions,
480
- output_hidden_states=output_hidden_states,
481
- return_dict=return_dict,
482
  **kwargs,
483
  )
484
 
@@ -490,40 +396,41 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
490
  )
491
 
492
 
493
- class ModernVBertLMHead(nn.Module):
494
- def __init__(self, config):
495
  super().__init__()
496
- pretrained_config = ModernBertConfig.from_pretrained(config.text_config.text_model_name)
497
- pretrained_model = ModernBertForMaskedLM(pretrained_config)
498
- self.head = pretrained_model.head
499
- self.decoder = pretrained_model.decoder
500
 
501
- def forward(self, hidden_states):
502
- return self.decoder(self.head(hidden_states))
503
 
504
 
505
  @auto_docstring
506
  class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
507
- _tied_weights_keys = ["lm_head.decoder.weight", "model.text_model.embeddings.word_embeddings.weight"]
508
 
509
  def __init__(self, config):
510
  super().__init__(config)
511
- self.in_features = config.hidden_size
512
- self.out_additional_features = config.additional_vocab_size
513
- self.vocab_size = config.vocab_size
514
  self.model = ModernVBertModel(config)
515
- self.lm_head = ModernVBertLMHead(config)
516
- if self.out_additional_features > 0:
517
- self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
518
- self.lm_head.to(self.dtype)
519
  self.post_init()
520
 
521
- # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
522
- def disable_input_require_grads(self):
523
- self._text_require_grads_hook.remove()
524
- self._vision_require_grads_hook.remove()
525
 
526
- @can_return_tuple
 
 
 
527
  @auto_docstring(
528
  custom_intro="""
529
  Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
@@ -534,23 +441,20 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
534
  discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
535
  image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
536
  """,
537
- checkpoint="modernvbert/ModernVBert",
538
  )
539
  def forward(
540
  self,
541
  input_ids: torch.LongTensor = None,
542
- attention_mask: Optional[torch.Tensor] = None,
543
- position_ids: Optional[torch.LongTensor] = None,
544
- inputs_embeds: Optional[torch.FloatTensor] = None,
545
- pixel_values: Optional[torch.FloatTensor] = None,
546
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
547
- image_hidden_states: Optional[torch.FloatTensor] = None,
548
- output_attentions: Optional[bool] = None,
549
- output_hidden_states: Optional[bool] = None,
550
- return_dict: Optional[bool] = None,
551
- labels: Optional[torch.LongTensor] = None,
552
- **kwargs: Unpack[FlashAttentionKwargs],
553
- ) -> Union[tuple, ModernVBertMaskedLMOutput]:
554
  r"""
555
  pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
556
  Mask to avoid performing attention on padding pixel indices.
@@ -558,16 +462,92 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
558
  The hidden states of the image encoder after modality projection.
559
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
560
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
561
- config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
562
- ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
563
  """
564
 
565
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
566
- output_hidden_states = (
567
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
 
568
  )
569
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  outputs = self.model(
572
  input_ids=input_ids,
573
  attention_mask=attention_mask,
@@ -576,35 +556,147 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
576
  pixel_values=pixel_values,
577
  pixel_attention_mask=pixel_attention_mask,
578
  image_hidden_states=image_hidden_states,
579
- output_attentions=output_attentions,
580
- output_hidden_states=output_hidden_states,
581
- return_dict=return_dict,
582
  **kwargs,
583
  )
584
- hidden_states = outputs[0]
 
 
 
 
 
 
 
 
 
585
 
586
- logits = self.lm_head(hidden_states)
 
 
 
 
587
 
588
- if self.out_additional_features > 0:
589
- proj_states = self.lm_head.head(hidden_states)
590
- additional_features = self.additional_fc(proj_states)
591
- logits = torch.cat((logits, additional_features), -1)
592
 
593
  loss = None
594
  if labels is not None:
595
- loss = CrossEntropyLoss()(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
 
597
- if not return_dict:
598
- output = (logits,) + outputs[2:]
599
- return ((loss,) + output) if loss is not None else output
600
 
601
- return ModernVBertMaskedLMOutput(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  loss=loss,
603
- logits=logits.float(),
604
  hidden_states=outputs.hidden_states,
605
  attentions=outputs.attentions,
606
- image_hidden_states=outputs.image_hidden_states,
607
  )
608
 
609
 
610
- __all__ = ["ModernVBertPreTrainedModel", "ModernVBertModel", "ModernVBertForMaskedLM"]
 
 
 
 
 
 
 
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_modernvbert.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
  from dataclasses import dataclass
 
23
 
24
  import torch
25
  import torch.nn as nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ... import initialization as init
29
+ from ...activations import ACT2FN
30
+ from ...modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPooling,
33
+ MaskedLMOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
  from ...modeling_utils import PreTrainedModel
38
  from ...processing_utils import Unpack
39
+ from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check
40
+ from ...utils.generic import can_return_tuple, check_model_inputs
41
+ from ..auto import AutoModel
42
  from .configuration_modernvbert import ModernVBertConfig
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @dataclass
46
  class ModernVBertBaseModelOutput(BaseModelOutput):
47
  """
48
+ Base class for ModernVBERT model's outputs.
49
  Args:
50
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
51
  Sequence of hidden-states at the output of the last layer of the model.
 
67
  """
68
 
69
  last_hidden_state: torch.FloatTensor = None
70
+ hidden_states: tuple[torch.FloatTensor] | None = None
71
+ attentions: tuple[torch.FloatTensor] | None = None
72
+ image_hidden_states: tuple[torch.FloatTensor] | None = None
73
 
74
 
75
  @dataclass
76
  class ModernVBertMaskedLMOutput(MaskedLMOutput):
77
  """
78
+ Base class for ModernVBERT model's outputs with masked language modeling loss.
79
  Args:
80
  loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
81
  Masked language modeling (MLM) loss.
 
96
  image_hidden_states of the model produced by the vision encoder
97
  """
98
 
99
+ loss: torch.FloatTensor | None = None
100
  logits: torch.FloatTensor = None
101
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
102
+ attentions: tuple[torch.FloatTensor, ...] | None = None
103
+ image_hidden_states: torch.FloatTensor | None = None
 
 
 
 
 
 
 
 
 
 
 
104
 
105
 
106
  class ModernVBertConnector(nn.Module):
 
112
  def __init__(self, config):
113
  super().__init__()
114
  self.pixel_shuffle_factor = config.pixel_shuffle_factor
115
+ self.modality_projection = nn.Linear(
116
+ config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
117
+ config.text_config.hidden_size,
118
+ bias=False,
119
  )
120
 
121
+ def pixel_shuffle(self, image_hidden_states, pixel_shuffle_factor):
122
+ batch_size, seq_length, embed_dim = image_hidden_states.size()
123
+ height = width = int(seq_length**0.5)
124
+ image_hidden_states = image_hidden_states.view(batch_size, height, width, embed_dim)
125
+ image_hidden_states = image_hidden_states.view(
126
+ batch_size, height, int(width / pixel_shuffle_factor), embed_dim * pixel_shuffle_factor
127
+ )
128
+ image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
129
+ image_hidden_states = image_hidden_states.reshape(
130
+ batch_size,
131
  int(width / pixel_shuffle_factor),
132
  int(height / pixel_shuffle_factor),
133
  embed_dim * (pixel_shuffle_factor**2),
134
  )
135
+ image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
136
+ return image_hidden_states.reshape(
137
+ batch_size, int(seq_length / (pixel_shuffle_factor**2)), embed_dim * (pixel_shuffle_factor**2)
138
+ )
139
 
140
  def forward(self, image_hidden_states):
141
  image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
142
  return self.modality_projection(image_hidden_states)
143
 
144
 
145
+ @auto_docstring
146
  class ModernVBertPreTrainedModel(PreTrainedModel):
147
+ config: ModernVBertConfig
148
  base_model_prefix = "model"
149
+ input_modalities = ("image", "text")
150
  supports_gradient_checkpointing = True
151
+ _no_split_modules = [
152
+ "ModernBertEmbeddings",
153
+ "ModernBertEncoderLayer",
154
+ "SiglipEncoderLayer",
155
+ "SiglipMultiheadAttentionPoolingHead",
156
+ ]
157
+ _skip_keys_device_placement = "past_key_values"
158
+ _supports_flash_attn = True
159
  _supports_sdpa = True
160
+ _supports_flex_attn = False
161
+ _supports_attention_backend = True
162
+ config_class = ModernVBertConfig
163
+ _can_record_outputs = {"image_hidden_states": ModernVBertConnector}
164
 
165
+ @torch.no_grad()
166
  def _init_weights(self, module):
167
+ super()._init_weights(module)
168
+
169
+ def init_weight(module: nn.Module, std: float):
170
+ cutoff_factor = getattr(self.config, "initializer_cutoff_factor", 2.0)
171
+ init.trunc_normal_(
172
+ module.weight,
173
+ mean=0.0,
174
+ std=std,
175
+ a=-cutoff_factor * std,
176
+ b=cutoff_factor * std,
177
+ )
178
 
179
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
180
+ if module.bias is not None:
181
+ init.zeros_(module.bias)
182
+
183
+ if isinstance(module, ModernVBertConnector):
184
+ out_std = self.config.initializer_range / math.sqrt(2.0 * self.config.text_config.num_hidden_layers)
185
+ init_weight(module.modality_projection, out_std)
186
+ elif isinstance(module, ModernVBertForMaskedLM):
187
+ out_std = self.config.initializer_range / math.sqrt(2.0 * self.config.text_config.num_hidden_layers)
188
+ init_weight(module.lm_head, out_std)
189
+ elif isinstance(
190
+ module,
191
+ (
192
+ ModernVBertForSequenceClassification,
193
+ ModernVBertForTokenClassification,
194
+ ),
195
+ ):
196
+ final_out_std = self.config.initializer_range / math.sqrt(self.config.text_config.hidden_size)
197
+ init_weight(module.classifier, final_out_std)
198
+
199
+
200
+ @auto_docstring(
201
+ custom_intro="""
202
+ ModernVBertModel is a model that combines a vision encoder (SigLIP) and a text encoder (ModernBert).
203
+
204
+ ModernVBert is the base model of the visual retriver ColModernVBert, and was introduced in the following paper:
205
+ [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
206
+ """
207
+ )
208
  class ModernVBertModel(ModernVBertPreTrainedModel):
209
+ """
210
+ A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
211
+ in forward. Instead, we override inputs_merger here with custom logic.
212
+ """
213
+
214
  def __init__(self, config: ModernVBertConfig):
215
  super().__init__(config)
216
+ self.padding_idx = self.config.text_config.pad_token_id
217
+ self.vocab_size = self.config.text_config.vocab_size
218
+ self.vision_model = AutoModel.from_config(config.vision_config)
219
 
220
  # init components
 
221
  self.connector = ModernVBertConnector(config)
222
+ self.text_model = AutoModel.from_config(config.text_config)
 
 
 
 
 
223
 
224
  self.image_seq_len = int(
225
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
226
  / (config.pixel_shuffle_factor**2)
227
  )
228
+ self.image_token_id = self.config.image_token_id
229
 
230
  self.post_init()
231
 
232
+ def get_input_embeddings(self):
233
+ return self.text_model.get_input_embeddings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ def set_input_embeddings(self, value):
236
+ self.text_model.set_input_embeddings(value)
237
 
238
+ def inputs_merger(
239
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
240
+ ):
241
  """
242
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
243
+ The merging happens as follows:
244
+ - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
245
+ - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
246
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
247
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
248
+ - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
249
+ """
250
+ _, patch_size, _ = image_hidden_states.shape
251
 
252
+ if input_ids is None:
253
+ image_mask = inputs_embeds == self.get_input_embeddings()(
254
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
255
+ )
256
+ image_mask = image_mask[..., 0] # slice off the hidden dim
257
+ else:
258
+ image_mask = input_ids == self.config.image_token_id
 
 
 
259
 
260
+ num_image_tokens = image_mask.sum(dim=1)
261
+ torch_compilable_check(
262
+ torch.all(num_image_tokens % patch_size == 0),
263
+ "At least one sample has <image> tokens not divisible by patch_size.",
264
  )
265
+ blocks_per_sample = num_image_tokens // patch_size
266
 
267
+ offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
268
+ block_offset = offsets[:-1]
269
+ row_cum = image_mask.cumsum(dim=-1)
270
+ chunk_idx = (row_cum - 1) // patch_size
271
+ local_idx = (row_cum - 1) % patch_size
272
+ block_idx = block_offset.unsqueeze(1) + chunk_idx
273
 
274
+ image_embeds = torch.zeros_like(inputs_embeds)
275
+ image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
276
 
277
+ merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
278
+ return merged_embeds
279
 
280
+ @can_return_tuple
281
+ @auto_docstring(
282
+ custom_intro="Encodes images into continuous embeddings that can be forwarded to the language model."
283
+ )
284
  def get_image_features(
285
+ self,
286
+ pixel_values: torch.FloatTensor,
287
+ pixel_attention_mask: torch.LongTensor | None = None,
288
+ **kwargs: Unpack[TransformersKwargs],
289
+ ) -> tuple | BaseModelOutputWithPooling:
290
+ r"""
291
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
292
+ The tensors corresponding to the input images.
293
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
294
+ The attention mask indicating padded regions in the image.
 
295
  """
296
  batch_size, num_images, num_channels, height, width = pixel_values.shape
297
  pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
 
301
  nb_values_per_image = pixel_values.shape[1:].numel()
302
  real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
303
 
304
+ # If no images, leave one empty image.
305
+ real_images_inds[0] |= ~torch.any(real_images_inds)
306
 
307
  pixel_values = pixel_values[real_images_inds].contiguous()
308
  # Handle the vision attention mask
 
316
  # Remove padding images from the mask
317
  pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
318
  pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
 
319
  patch_size = self.config.vision_config.patch_size
320
  patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
321
  patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
322
  patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
323
 
324
  # Get sequence from the vision encoder
325
+ image_outputs = self.vision_model(
326
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True, **kwargs
327
+ )
328
+ image_hidden_states = image_outputs.last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
+ # Modality projection & resampling
331
+ image_features = self.connector(image_hidden_states)
332
+ image_outputs.pooler_output = image_features
333
 
334
+ return image_outputs
335
 
336
+ @check_model_inputs
337
  @auto_docstring(
338
  custom_intro="""
339
  Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
 
344
  discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
345
  image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
346
  """,
347
+ checkpoint="ModernVBERT/modernvbert",
348
  )
349
  def forward(
350
  self,
351
  input_ids: torch.LongTensor = None,
352
+ attention_mask: torch.Tensor | None = None,
353
+ position_ids: torch.LongTensor | None = None,
354
+ inputs_embeds: torch.FloatTensor | None = None,
355
+ pixel_values: torch.FloatTensor | None = None,
356
+ pixel_attention_mask: torch.BoolTensor | None = None,
357
+ image_hidden_states: torch.FloatTensor | None = None,
358
+ **kwargs: Unpack[TransformersKwargs],
359
+ ) -> tuple | ModernVBertBaseModelOutput:
 
 
 
360
  r"""
361
  pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
362
  Mask to avoid performing attention on padding pixel indices.
363
  image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
364
  The hidden states of the image encoder after modality projection.
 
 
 
 
365
  """
 
 
 
 
 
366
 
367
  if inputs_embeds is None:
368
  inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
369
 
370
  # Images processing
371
  if pixel_values is not None:
 
372
  image_hidden_states = self.get_image_features(
373
  pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
374
+ ).pooler_output
 
 
375
 
376
  # Merge image and text embeddings
377
  if image_hidden_states is not None:
378
+ image_hidden_states = image_hidden_states.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
379
  inputs_embeds = self.inputs_merger(
380
  input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
381
  )
 
385
  inputs_embeds=inputs_embeds,
386
  attention_mask=attention_mask,
387
  position_ids=position_ids,
 
 
 
388
  **kwargs,
389
  )
390
 
 
396
  )
397
 
398
 
399
+ class ModernVBertPredictionHead(nn.Module):
400
+ def __init__(self, config: ModernVBertConfig):
401
  super().__init__()
402
+ self.config = config
403
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
404
+ self.act = ACT2FN[config.classifier_activation]
405
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
406
 
407
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
408
+ return self.norm(self.act(self.dense(hidden_states)))
409
 
410
 
411
  @auto_docstring
412
  class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
413
+ _tied_weights_keys = {"lm_head.weight": "model.text_model.embeddings.tok_embeddings.weight"}
414
 
415
  def __init__(self, config):
416
  super().__init__(config)
417
+
418
+ self.vocab_size = config.text_config.vocab_size
419
+
420
  self.model = ModernVBertModel(config)
421
+ self.projection_head = ModernVBertPredictionHead(config.text_config)
422
+ self.lm_head = nn.Linear(config.text_config.hidden_size, self.vocab_size, bias=config.text_config.decoder_bias)
423
+
424
+ # Initialize weights and apply final processing
425
  self.post_init()
426
 
427
+ def get_output_embeddings(self):
428
+ return self.lm_head
 
 
429
 
430
+ def set_output_embeddings(self, new_embeddings):
431
+ self.lm_head = new_embeddings
432
+
433
+ @check_model_inputs
434
  @auto_docstring(
435
  custom_intro="""
436
  Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
 
441
  discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
442
  image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
443
  """,
444
+ checkpoint="ModernVBERT/modernvbert",
445
  )
446
  def forward(
447
  self,
448
  input_ids: torch.LongTensor = None,
449
+ attention_mask: torch.Tensor | None = None,
450
+ position_ids: torch.LongTensor | None = None,
451
+ inputs_embeds: torch.FloatTensor | None = None,
452
+ pixel_values: torch.FloatTensor | None = None,
453
+ pixel_attention_mask: torch.BoolTensor | None = None,
454
+ image_hidden_states: torch.FloatTensor | None = None,
455
+ labels: torch.LongTensor | None = None,
456
+ **kwargs: Unpack[TransformersKwargs],
457
+ ) -> tuple | ModernVBertMaskedLMOutput:
 
 
 
458
  r"""
459
  pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
460
  Mask to avoid performing attention on padding pixel indices.
 
462
  The hidden states of the image encoder after modality projection.
463
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
464
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
465
+ text_config.]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
466
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., text_config.]`.
467
  """
468
 
469
+ outputs = self.model(
470
+ input_ids=input_ids,
471
+ attention_mask=attention_mask,
472
+ position_ids=position_ids,
473
+ inputs_embeds=inputs_embeds,
474
+ pixel_values=pixel_values,
475
+ pixel_attention_mask=pixel_attention_mask,
476
+ image_hidden_states=image_hidden_states,
477
+ **kwargs,
478
  )
479
+ hidden_states = outputs[0]
480
+
481
+ logits = self.lm_head(self.projection_head(hidden_states))
482
 
483
+ loss = None
484
+ if labels is not None:
485
+ criterion = CrossEntropyLoss()
486
+ loss = criterion(logits.view(-1, self.vocab_size), labels.view(-1))
487
+
488
+ return ModernVBertMaskedLMOutput(
489
+ loss=loss,
490
+ logits=logits,
491
+ hidden_states=outputs.hidden_states,
492
+ attentions=outputs.attentions,
493
+ image_hidden_states=outputs.image_hidden_states,
494
+ )
495
+
496
+
497
@auto_docstring(
    custom_intro="""
    The ModernVBert Model with a sequence classification head on top that performs pooling.
    """
)
class ModernVBertForSequenceClassification(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone + prediction head + dropout + linear classifier over the pooled representation.
        self.model = ModernVBertModel(config)
        self.head = ModernVBertPredictionHead(config.text_config)
        self.drop = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | SequenceClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The hidden states of the image encoder after modality projection.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            **kwargs,
        )
        last_hidden_state = outputs[0]

        # Pool the sequence either with the [CLS] token or a mask-aware mean over tokens.
        if self.config.classifier_pooling == "cls":
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
            device = input_ids.device if input_ids is not None else inputs_embeds.device

            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                dim=1, keepdim=True
            )

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels / label dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
611
 
 
 
 
612
 
613
@auto_docstring(
    custom_intro="""
    The ModernVBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernVBertForTokenClassification(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Backbone + prediction head + dropout + per-token linear classifier.
        self.model = ModernVBertModel(config)
        self.head = ModernVBertPredictionHead(config.text_config)
        self.drop = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | TokenClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The hidden states of the image encoder after modality projection.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`; tokens with the label `-100` are ignored by the loss.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            **kwargs,
        )
        last_hidden_state = outputs[0]

        # Per-token classification: head -> dropout -> linear over every position.
        last_hidden_state = self.head(last_hidden_state)
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores label -100 by default (ignore_index=-100).
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
694
 
695
 
696
+ __all__ = [
697
+ "ModernVBertPreTrainedModel",
698
+ "ModernVBertModel",
699
+ "ModernVBertForMaskedLM",
700
+ "ModernVBertForSequenceClassification",
701
+ "ModernVBertForTokenClassification",
702
+ ]