anananan116 committed on
Commit
f9ab04a
·
verified ·
1 Parent(s): d203298

upload model files

Browse files
Files changed (5) hide show
  1. configuration_clip.py +453 -0
  2. configuration_llama.py +246 -0
  3. modeling_VLM.py +186 -0
  4. modeling_llama.py +1259 -0
  5. visual_modeling.py +1128 -0
configuration_clip.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CLIP model configuration"""
16
+
17
+ import os
18
+ from collections import OrderedDict
19
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from transformers.processing_utils import ProcessorMixin
24
+ from transformers.utils import TensorType
25
+
26
+ from transformers.configuration_utils import PretrainedConfig
27
+ from transformers.onnx import OnnxConfig
28
+ from transformers.utils import logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
class CLIPTextConfig(PretrainedConfig):
    r"""Configuration for a [`CLIPTextModel`] (the text tower of CLIP).

    Instantiating this class with no arguments yields a configuration
    comparable to the text encoder of
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).
    Inherits from [`PretrainedConfig`]; see its documentation for the shared
    serialization/loading behavior.

    Args:
        vocab_size (`int`, *optional*, defaults to 49408):
            Size of the token vocabulary understood by the text model.
        hidden_size (`int`, *optional*, defaults to 512):
            Width of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Width of the feed-forward (MLP) layers inside each encoder block.
        projection_dim (`int`, *optional*, defaults to 512):
            Output size of the text and vision projection heads.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of transformer encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Attention heads per encoder layer.
        max_position_embeddings (`int`, *optional*, defaults to 77):
            Maximum sequence length the model supports.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            Activation in the encoder and pooler; `"gelu"`, `"relu"`,
            `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the layer-normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal initializer for weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            Global scaling factor for weight init (kept at 1; used for
            initialization testing).
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 49406):
            Beginning-of-stream token id.
        eos_token_id (`int`, *optional*, defaults to 49407):
            End-of-stream token id.

    Example:

    ```python
    >>> from transformers import CLIPTextConfig, CLIPTextModel

    >>> configuration = CLIPTextConfig()
    >>> model = CLIPTextModel(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "clip_text_model"

    def __init__(
        self,
        vocab_size=49408,
        hidden_size=512,
        intermediate_size=2048,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=8,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        # NOTE: this pad id deliberately differs from `CLIPTokenizer`'s default
        # and from openai/clip; see
        # https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a text config, unwrapping the `text_config` of a composite CLIP config if needed."""
        cls._set_token_in_kwargs(kwargs)

        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # When pointed at a full CLIPConfig, extract the text sub-config.
        if cfg.get("model_type") == "clip":
            cfg = cfg["text_config"]

        if "model_type" in cfg and hasattr(cls, "model_type") and cfg["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {cfg['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(cfg, **kwargs)
150
+
151
+
152
class CLIPVisionConfig(PretrainedConfig):
    r"""Configuration for a [`CLIPVisionModel`] (the vision tower of CLIP).

    Instantiating this class with no arguments yields a configuration
    comparable to the vision encoder of
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).
    Inherits from [`PretrainedConfig`]; see its documentation for the shared
    serialization/loading behavior.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward (MLP) layers inside each encoder block.
        projection_dim (`int`, *optional*, defaults to 512):
            Output size of the text and vision projection heads.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of transformer encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Attention heads per encoder layer.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        image_size (`int`, *optional*, defaults to 224):
            Input image resolution.
        patch_size (`int`, *optional*, defaults to 32):
            Resolution of each image patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            Activation in the encoder and pooler; `"gelu"`, `"relu"`,
            `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the layer-normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal initializer for weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            Global scaling factor for weight init (kept at 1; used for
            initialization testing).

    Example:

    ```python
    >>> from transformers import CLIPVisionConfig, CLIPVisionModel

    >>> configuration = CLIPVisionConfig()
    >>> model = CLIPVisionModel(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a vision config, unwrapping the `vision_config` of a composite CLIP config if needed."""
        cls._set_token_in_kwargs(kwargs)

        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # When pointed at a full CLIPConfig, extract the vision sub-config.
        if cfg.get("model_type") == "clip":
            cfg = cfg["vision_config"]

        if "model_type" in cfg and hasattr(cls, "model_type") and cfg["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {cfg['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(cfg, **kwargs)
259
+
260
+
261
class CLIPConfig(PretrainedConfig):
    r"""Composite configuration for a [`CLIPModel`].

    Bundles a [`CLIPTextConfig`] and a [`CLIPVisionConfig`] together with the
    shared projection / logit-scale hyper-parameters. Instantiating with the
    defaults yields a configuration comparable to
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).
    Inherits from [`PretrainedConfig`]; see its documentation for the shared
    serialization/loading behavior.

    Args:
        text_config (`dict`, *optional*):
            Options used to initialize the [`CLIPTextConfig`].
        vision_config (`dict`, *optional*):
            Options used to initialize the [`CLIPVisionConfig`].
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            Initial value of the *logit_scale* parameter (matches the original
            CLIP implementation).
        kwargs (*optional*):
            Extra keyword arguments forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import CLIPConfig, CLIPModel

    >>> configuration = CLIPConfig()
    >>> model = CLIPModel(configuration)
    >>> configuration = model.config

    >>> # A CLIPConfig can also be built from the two sub-configs
    >>> from transformers import CLIPTextConfig, CLIPVisionConfig

    >>> config_text = CLIPTextConfig()
    >>> config_vision = CLIPVisionConfig()
    >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
    ```"""

    model_type = "clip"

    def __init__(
        self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
    ):
        # Legacy `*_config_dict` kwargs (pre commit `8827e1b2`) are popped
        # before `super().__init__` so they are never serialized back out.
        legacy_text = kwargs.pop("text_config_dict", None)
        legacy_vision = kwargs.pop("vision_config_dict", None)

        super().__init__(**kwargs)

        # For backward compatibility the legacy dicts are not assigned
        # directly: their values override the corresponding entries of
        # `[text|vision]_config`, with an informational log on any mismatch.
        if legacy_text is not None:
            if text_config is None:
                text_config = {}

            # Fully materialized result of the legacy text dict.
            merged_text = CLIPTextConfig(**legacy_text).to_dict()

            # Log every key whose value disagrees between the two sources.
            for name, val in merged_text.items():
                if name in text_config and val != text_config[name] and name not in ["transformers_version"]:
                    if name in legacy_text:
                        # Explicitly given in `text_config_dict`.
                        msg = (
                            f"`{name}` is found in both `text_config_dict` and `text_config` but with different values. "
                            f'The value `text_config_dict["{name}"]` will be used instead.'
                        )
                    else:
                        # Value comes from default arguments (extra caution).
                        msg = (
                            f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
                            f'value `text_config["{name}"]` will be overridden.'
                        )
                    logger.info(msg)

            # Legacy values win over `text_config`.
            text_config.update(merged_text)

        if legacy_vision is not None:
            if vision_config is None:
                vision_config = {}

            # Fully materialized result of the legacy vision dict.
            merged_vision = CLIPVisionConfig(**legacy_vision).to_dict()
            # `id2label` keys must be strings for comparison/serialization.
            if "id2label" in merged_vision:
                merged_vision["id2label"] = {
                    str(name): val for name, val in merged_vision["id2label"].items()
                }

            # Log every key whose value disagrees between the two sources.
            for name, val in merged_vision.items():
                if name in vision_config and val != vision_config[name] and name not in ["transformers_version"]:
                    if name in legacy_vision:
                        # Explicitly given in `vision_config_dict`.
                        msg = (
                            f"`{name}` is found in both `vision_config_dict` and `vision_config` but with different "
                            f'values. The value `vision_config_dict["{name}"]` will be used instead.'
                        )
                    else:
                        # Value comes from default arguments (extra caution).
                        msg = (
                            f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
                            f'The value `vision_config["{name}"]` will be overridden.'
                        )
                    logger.info(msg)

            # Legacy values win over `vision_config`.
            vision_config.update(merged_vision)

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")

        self.text_config = CLIPTextConfig(**text_config)
        self.vision_config = CLIPVisionConfig(**vision_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value
        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
        r"""
        Instantiate a [`CLIPConfig`] (or a derived class) from already-built
        text and vision model configurations.

        Returns:
            [`CLIPConfig`]: An instance of a configuration object
        """

        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
408
+
409
+
410
class CLIPOnnxConfig(OnnxConfig):
    """ONNX export description (inputs, outputs, dummy data) for CLIP."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis names for the text and image inputs."""
        spec = OrderedDict()
        spec["input_ids"] = {0: "batch", 1: "sequence"}
        spec["pixel_values"] = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        spec["attention_mask"] = {0: "batch", 1: "sequence"}
        return spec

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis names for the similarity logits and embeddings."""
        spec = OrderedDict()
        spec["logits_per_image"] = {0: "batch"}
        spec["logits_per_text"] = {0: "batch"}
        spec["text_embeds"] = {0: "batch"}
        spec["image_embeds"] = {0: "batch"}
        return spec

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance when validating the exported graph against torch.
        return 1e-4

    def generate_dummy_inputs(
        self,
        processor: "ProcessorMixin",
        batch_size: int = -1,
        seq_length: int = -1,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build combined dummy inputs: tokenizer text plus image-processor pixels."""
        text_inputs = super().generate_dummy_inputs(
            processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
        )
        image_inputs = super().generate_dummy_inputs(
            processor.image_processor, batch_size=batch_size, framework=framework
        )
        return {**text_inputs, **image_inputs}

    @property
    def default_onnx_opset(self) -> int:
        return 14
configuration_llama.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.modeling_rope_utils import rope_config_validation
24
+
25
+
26
class BaseConfig(PretrainedConfig):
    r"""LLaMA-style model configuration (defaults approximate LLaMA-7B).

    Inherits from [`PretrainedConfig`]; see its documentation for the shared
    serialization/loading behavior.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Size of the token vocabulary.
        hidden_size (`int`, *optional*, defaults to 4096):
            Width of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Width of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Attention heads per decoder layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads for grouped-query attention. Equal to
            `num_attention_heads` means MHA, 1 means MQA, anything in between
            is GQA (see https://arxiv.org/pdf/2305.13245.pdf). Defaults to
            `num_attention_heads` when unset.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation used in the decoder MLPs.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Maximum supported sequence length (Llama 1: 2048, Llama 2: 4096,
            CodeLlama: 16384).
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal initializer for weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS-normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model returns the last key/value attention caches;
            only relevant when `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental: tensor-parallelism rank used during pretraining,
            needed for exact reproducibility of pretraining results
            (see https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output embeddings share weights.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base period of the rotary position embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration. Expected keys include `rope_type`
            (one of 'default', 'linear', 'dynamic', 'yarn', 'longrope',
            'llama3') and, depending on the type: `factor`,
            `original_max_position_embeddings`, `attention_factor`,
            `beta_fast`, `beta_slow`, `short_factor`, `long_factor`,
            `low_freq_factor`, `high_freq_factor`. Validated by
            `rope_config_validation`; a legacy `type` key is mirrored to
            `rope_type` for backward compatibility.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether q/k/v/o projections carry a bias term.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether up_proj/down_proj/gate_proj carry a bias term.
        head_dim (`int`, *optional*):
            Attention head dimension; defaults to
            `hidden_size // num_attention_heads`.

    ```python
    >>> from transformers import LlamaModel, LlamaConfig

    >>> configuration = LlamaConfig()
    >>> model = LlamaModel(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Backward compatibility: older checkpoints omit num_key_value_heads,
        # which implies plain multi-head attention.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads

        # Validate rotary position embedding parameters. BC: a legacy 'type'
        # field is mirrored into 'rope_type' before validation.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
207
+
208
class VLMConfig(BaseConfig):
    """Configuration for the VLM: a LLaMA-style text backbone plus a
    CLIP-style vision tower and optional LoRA fine-tuning settings.

    Args:
        lora (`bool`, *optional*, defaults to `False`):
            Whether LoRA adapters are enabled.
        task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`):
            PEFT task type string.
        lora_rank (`int`, *optional*, defaults to 256):
            Rank of the LoRA decomposition.
        lora_alpha (`int`, *optional*):
            LoRA scaling alpha; defaults to `lora_rank` when `None`.
        lora_modules (`list[str]`, *optional*):
            Module names to wrap with LoRA; defaults to
            `['q', 'k', 'v', 'embed_tokens', 'lm_head']`.
        pretrained_model (`str`, *optional*):
            Hub id of the base language model.
        hugging_face_token (`str`, *optional*):
            Auth token for gated checkpoints.
        adjust_embedding_len (`int`, *optional*):
            If set, target length for the embedding table.
        special_token_map (`dict`, *optional*):
            Mapping of special token names to (token, id) entries; required by
            the model code (see `modeling_VLM`).
        flashattention (`bool`, *optional*, defaults to `False`):
            If `True`, request the "flash_attention_2" attention
            implementation.
        encoded_image_dimention (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoded image features. NOTE: the parameter
            name is misspelled ("dimention") but is kept for checkpoint
            compatibility.
        num_patches (`int`, *optional*, defaults to 64):
            Number of image patches fed to the language model.
        visual_config (`dict`, *optional*):
            Configuration options for the vision tower.
        load_vision_model (`bool`, *optional*, defaults to `False`):
            Whether to load pretrained vision-tower weights.
        pretrained_vision_model (`str`, *optional*):
            Hub id of the vision tower.
    """

    def __init__(
        self,
        lora=False,
        task_type="CAUSAL_LM",
        lora_rank=256,
        lora_alpha=None,
        # Fix: the original default was a mutable list literal, which Python
        # evaluates once and shares across every instance — any in-place
        # mutation of `config.lora_modules` would leak into other configs.
        # A `None` sentinel materializes a fresh list per instance.
        lora_modules=None,
        pretrained_model="meta-llama/Llama-3.2-1B-Instruct",
        hugging_face_token=None,
        adjust_embedding_len=None,
        special_token_map=None,
        flashattention=False,
        encoded_image_dimention=1024,
        num_patches=64,
        visual_config=None,
        load_vision_model=False,
        pretrained_vision_model="openai/clip-vit-large-patch14-336",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.lora = lora
        self.task_type = task_type
        self.lora_rank = lora_rank
        # alpha defaults to the rank (scaling factor alpha/rank == 1).
        self.lora_alpha = lora_rank if lora_alpha is None else lora_alpha
        self.lora_modules = (
            ["q", "k", "v", "embed_tokens", "lm_head"] if lora_modules is None else list(lora_modules)
        )
        self.pretrained_model = pretrained_model
        self.hugging_face_token = hugging_face_token
        self.adjust_embedding_len = adjust_embedding_len
        self.special_token_map = special_token_map
        self.flashattention = flashattention
        if self.flashattention:
            # Route attention through FlashAttention-2 in the HF stack.
            self._attn_implementation = "flash_attention_2"
        self.encoded_image_dimention = encoded_image_dimention
        self.num_patches = num_patches
        self.visual_config = visual_config
        self.load_vision_model = load_vision_model
        self.pretrained_vision_model = pretrained_vision_model
modeling_VLM.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modeling_llama import AdapterMLP, DEFAULT_SYSTEM_PROMPT, LlamaForCausalLM
2
+ from .configuration_llama import VLMConfig
3
+ from .visual_modeling import CLIPModel
4
+ import torch
5
+ from torch import nn
6
+ from transformers import PreTrainedModel, PreTrainedTokenizer, AutoProcessor, GenerationMixin
7
+
8
class VLMPretrainedModel(PreTrainedModel):
    """PreTrainedModel wiring for the VLM: config class, device-placement
    hints, and default weight initialization."""

    config_class = VLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["LlamaDecoderLayer", "Block"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        # Linear/Embedding weights ~ N(0, initializer_range); linear biases and
        # the padding-embedding row are zeroed.
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        elif isinstance(module, nn.Embedding) and module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
25
+
26
class AtriVLM(VLMPretrainedModel, GenerationMixin):
    """Vision-language model: CLIP image features are projected by an MLP
    adapter and spliced into a Llama decoder's input embeddings at the
    positions of `<Image_Token>` placeholders."""

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        # Special-token ids are mandatory: image embeddings are written into
        # the token stream at the `Image_Token` positions.
        if config.special_token_map:
            self.image_start_token_id = config.special_token_map['Image'][1]
            self.image_end_token_id = config.special_token_map['Image_End'][1]
            self.caption_token_id = config.special_token_map['Caption'][1]
            self.image_token_id = config.special_token_map['Image_Token'][1]
        else:
            raise ValueError("Special token map not found")
        self.image_adapter = AdapterMLP(config)
        self.num_patches = config.num_patches
        self.processor = AutoProcessor.from_pretrained(config.pretrained_vision_model).image_processor
        # Literal token strings used when formatting prompts for generation.
        self.img_place_holder = "<IMGPLH>"
        self.img_start_token = "<IMAGE>"
        self.img_end_token = "<IMAGE_END>"
        self.image_token = "<Image_Token>"
        self.decoder = LlamaForCausalLM(config)
        if config.load_vision_model:
            self.visual = CLIPModel(config.visual_config)
        else:
            self.visual = None

    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.decoder.set_input_embeddings(value)

    def _embed_with_image(self, input_ids, encoded_image):
        """Embed `input_ids` and overwrite the `<Image_Token>` positions with
        adapter-projected image features.

        Assumes each image contributes `num_patches` feature rows matching the
        number of image-token slots in `input_ids` — TODO confirm with callers.
        """
        encoded_image = encoded_image.to(self.decoder.get_input_embeddings().weight.dtype)
        # Project vision features into the decoder's embedding space.
        processed_image = self.image_adapter(encoded_image)
        token_embeddings = self.decoder.get_input_embeddings()(input_ids)
        image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
        token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))
        return token_embeddings

    def forward(self, input_ids=None, encoded_image=None, labels=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """
        Forward pass that combines image and text embeddings.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)
            labels (torch.LongTensor): Labels for computing the language modeling loss
            past_key_values: Cached keys/values; when present the image was
                already consumed, so only plain token embeddings are computed.
        """
        if not past_key_values and (encoded_image is not None):
            token_embeddings = self._embed_with_image(input_ids, encoded_image)
        else:
            token_embeddings = self.decoder.get_input_embeddings()(input_ids)
        # Call the decoder's native forward with the (possibly spliced) embeddings.
        outputs = self.decoder._native_forward(
            inputs_embeds=token_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def prepare_input_ids_for_generation(self, prompts, images, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """
        Prepare input ids and image features for generation.

        Args:
            prompts (List[str]): Text prompts, optionally containing "<IMGPLH>"
            images (List[Image]): Images corresponding to the prompts
            tokenizer: Tokenizer instance
            system_prompt (str): System prompt to be prepended

        Returns:
            dict: input_ids, attention_mask, and encoded_image (None when no
            images were supplied).
        """
        # Encode every image through the vision tower.
        encoded_features = []
        for image in images:
            pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].to(
                self.visual.vision_model.embeddings.patch_embedding.weight.device
            )
            encoded_features.append(self.visual.encode_image(pixel_values))
        # Fix: the old code called `.size(0)` on a plain (possibly empty) list,
        # raising AttributeError for text-only requests. Normalize up front.
        encoded_image = torch.cat(encoded_features, dim=0) if encoded_features else None

        formatted_prompts = []
        for prompt in prompts:
            # Expand the placeholder into <IMAGE> + num_patches image tokens + <IMAGE_END>.
            if self.img_place_holder in prompt:
                image_token_sequence = (
                    f"{self.img_start_token}" +
                    f"{self.image_token}" * self.num_patches +
                    f"{self.img_end_token}"
                )
                formatted_prompt = prompt.replace(self.img_place_holder, image_token_sequence)
            else:
                formatted_prompt = prompt

            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_prompt},
            ]
            formatted_prompts.append(
                tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
            )

        tokenized_output = tokenizer(
            formatted_prompts,
            padding=True,
            return_tensors="pt",
            padding_side="left",  # left padding: new tokens are generated on the right
        )

        return {
            "input_ids": tokenized_output["input_ids"],
            "attention_mask": tokenized_output["attention_mask"],
            "encoded_image": encoded_image,
        }

    def prepare_for_generation(self, input_ids, encoded_image, **kwargs):
        """
        Prime the KV cache by running the image-conditioned prompt once.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)

        Returns:
            past_key_values: cache to reuse for subsequent decoding steps.
        """
        token_embeddings = self._embed_with_image(input_ids, encoded_image)
        outputs = self.decoder._native_forward(
            inputs_embeds=token_embeddings,
            use_cache=True,
            **kwargs
        )
        return outputs.past_key_values
modeling_llama.py ADDED
@@ -0,0 +1,1259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
43
+ from transformers.utils import (
44
+ add_code_sample_docstrings,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from .configuration_llama import BaseConfig, VLMConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
57
+ _CONFIG_FOR_DOC = "VLMConfig"
58
+ DEFAULT_SYSTEM_PROMPT = "You are a powerful visual assistant."
59
+
60
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm, equivalent to T5LayerNorm: the input is
    scaled by 1/rms (computed in float32 for stability) and multiplied by a
    learned per-channel gain; no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Upcast before the reduction so low-precision inputs keep accuracy.
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
78
+
79
+
80
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
81
+
82
+
83
class LlamaRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables for rotary position embeddings (RoPE).

    Supports the legacy per-argument construction (deprecated, config=None)
    and the newer config-driven construction, as well as "dynamic" RoPE
    variants that recompute `inv_freq` when the sequence grows past the
    cached length.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[BaseConfig] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            # Legacy path: parameters passed individually are repackaged into
            # `rope_kwargs` for the rope-init function.
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Dispatch to the init function registered for this RoPE flavor.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        # Non-persistent: recomputed at load time rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic variants can reset to the original frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Returns (cos, sin), each shaped (batch, seq_len, dim), cast to x's dtype.
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
168
+
169
+
170
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """Deprecation shim: linear-scaled RoPE is now handled by
    `LlamaRotaryEmbedding` itself. Credits to the Reddit user /u/kaiokendev."""

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
        )
        # Force the linear rope flavor, then defer entirely to the base class.
        kwargs.update(rope_type="linear")
        super().__init__(*args, **kwargs)
180
+
181
+
182
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """Deprecation shim: dynamic-NTK-scaled RoPE is now handled by
    `LlamaRotaryEmbedding` itself. Credits to the Reddit users /u/bloc97 and
    /u/emozilla."""

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
            "__init__)."
        )
        # Force the dynamic rope flavor, then defer entirely to the base class.
        kwargs.update(rope_type="dynamic")
        super().__init__(*args, **kwargs)
193
+
194
+
195
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
200
+
201
+
202
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos/sin are unsqueezed so they broadcast
            against q and k. Use 1 for [batch, heads, seq, head_dim] layouts
            and 2 for [batch, seq, heads, head_dim] layouts.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate_half(t):
        # (a, b) -> (-b, a) over the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed
227
+
228
+
229
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)),
    with an optional sliced path for pretraining tensor parallelism."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            # Tensor-parallel path: apply each weight slice separately, then
            # concatenate (gate/up) or sum (down) the partial results.
            # Fix: the slice width was previously named `slice`, shadowing the builtin.
            slice_size = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice_size, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice_size, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice_size, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice_size, dim=2)
            down_proj = sum(
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            )
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj
261
+
262
+
263
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
273
+
274
+
275
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (non-fused) implementation with grouped-query attention: K/V use
    `num_key_value_heads` heads that are repeated to match the query heads.
    """

    def __init__(self, config: BaseConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Falls back to hidden_size/num_heads when the config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head (GQA).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Tensor-parallel projection path: split weights, project slice by
            # slice, and concatenate the partial projections.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # Reshape to (batch, heads, seq, head_dim) for attention.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand K/V heads to match the number of query heads (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            # Tensor-parallel output projection: project each slice and sum.
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
398
+
399
+
400
+ class LlamaFlashAttention2(LlamaAttention):
401
+ """
402
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
403
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
404
+ flash attention and deal with padding tokens in case the input contains any of them.
405
+ """
406
+
407
+ def __init__(self, *args, **kwargs):
408
+ super().__init__(*args, **kwargs)
409
+
410
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
411
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
412
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
413
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
414
+
415
+ def forward(
416
+ self,
417
+ hidden_states: torch.Tensor,
418
+ attention_mask: Optional[torch.LongTensor] = None,
419
+ position_ids: Optional[torch.LongTensor] = None,
420
+ past_key_value: Optional[Cache] = None,
421
+ output_attentions: bool = False,
422
+ use_cache: bool = False,
423
+ cache_position: Optional[torch.LongTensor] = None,
424
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
425
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
426
+ if isinstance(past_key_value, StaticCache):
427
+ raise ValueError(
428
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
429
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
430
+ )
431
+
432
+ output_attentions = False
433
+
434
+ bsz, q_len, _ = hidden_states.size()
435
+
436
+ query_states = self.q_proj(hidden_states)
437
+ key_states = self.k_proj(hidden_states)
438
+ value_states = self.v_proj(hidden_states)
439
+
440
+ # Flash attention requires the input to have the shape
441
+ # batch_size x seq_length x head_dim x hidden_dim
442
+ # therefore we just need to keep the original shape
443
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
444
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
445
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
446
+
447
+ if position_embeddings is None:
448
+ logger.warning_once(
449
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
450
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
451
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
452
+ "removed and `position_embeddings` will be mandatory."
453
+ )
454
+ cos, sin = self.rotary_emb(value_states, position_ids)
455
+ else:
456
+ cos, sin = position_embeddings
457
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
458
+
459
+ if past_key_value is not None:
460
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
461
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
462
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
463
+
464
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
465
+ # to be able to avoid many of these transpose/reshape/view.
466
+ query_states = query_states.transpose(1, 2)
467
+ key_states = key_states.transpose(1, 2)
468
+ value_states = value_states.transpose(1, 2)
469
+
470
+ dropout_rate = self.attention_dropout if self.training else 0.0
471
+
472
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
473
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
474
+ # cast them back in the correct dtype just to be sure everything works as expected.
475
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
476
+ # in fp32. (LlamaRMSNorm handles it correctly)
477
+
478
+ input_dtype = query_states.dtype
479
+ if input_dtype == torch.float32:
480
+ if torch.is_autocast_enabled():
481
+ target_dtype = torch.get_autocast_gpu_dtype()
482
+ # Handle the case where the model is quantized
483
+ elif hasattr(self.config, "_pre_quantization_dtype"):
484
+ target_dtype = self.config._pre_quantization_dtype
485
+ else:
486
+ target_dtype = self.q_proj.weight.dtype
487
+
488
+ logger.warning_once(
489
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
490
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
491
+ f" {target_dtype}."
492
+ )
493
+
494
+ query_states = query_states.to(target_dtype)
495
+ key_states = key_states.to(target_dtype)
496
+ value_states = value_states.to(target_dtype)
497
+
498
+ attn_output = _flash_attention_forward(
499
+ query_states,
500
+ key_states,
501
+ value_states,
502
+ attention_mask,
503
+ q_len,
504
+ position_ids=position_ids,
505
+ dropout=dropout_rate,
506
+ sliding_window=getattr(self, "sliding_window", None),
507
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
508
+ is_causal=self.is_causal,
509
+ )
510
+
511
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
512
+ attn_output = self.o_proj(attn_output)
513
+
514
+ if not output_attentions:
515
+ attn_weights = None
516
+
517
+ return attn_output, attn_weights, past_key_value
518
+
519
+
520
class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # SDPA cannot return attention probabilities, so fall back to the eager (parent)
        # implementation whenever the caller requests them.
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project hidden states to Q/K/V.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (bsz, num_heads, q_len, head_dim) for attention.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # RoPE: prefer externally computed (cos, sin); computing from position_ids is deprecated.
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand KV heads to match the number of query heads (grouped-query attention).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Trim the mask to the actual key length (cache may be pre-allocated longer).
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # Back to (bsz, q_len, hidden_size), then the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        # Attention weights are never materialized on the SDPA path.
        return attn_output, None, past_key_value
617
+
618
+
619
# Maps `config._attn_implementation` to the attention class instantiated by LlamaDecoderLayer.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}
624
+
625
+
626
class LlamaDecoderLayer(nn.Module):
    """One pre-norm transformer decoder layer: self-attention then an MLP, each with a residual connection."""

    def __init__(self, config: BaseConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention implementation (eager / flash_attention_2 / sdpa) is picked from the config.
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): token position indices (legacy RoPE path).
            past_key_value (`Cache`, *optional*): cached past key and value projection states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned to speed up decoding.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored; used for FSDP and other code-injecting methods.

        Returns:
            Tuple of `(hidden_states[, attn_weights][, present_key_value])` depending on the flags.
        """
        # --- Self-attention sub-block (pre-norm + residual) ---
        skip = hidden_states
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = skip + attn_out

        # --- Feed-forward sub-block (pre-norm + residual) ---
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        result = (hidden_states,)
        if output_attentions:
            result += (attn_weights,)
        if use_cache:
            result += (present_key_value,)

        return result
704
+
705
+
706
# Shared class-level documentation injected into model classes via `add_start_docstrings`.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BaseConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
721
+
722
+
723
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    """Base class wiring up HF plumbing (config class, cache/attention capability flags, weight init)."""

    config_class = VLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device when the model is sharded.
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range); zero biases and padding rows."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
749
+
750
+
751
# Shared `forward` documentation injected via `add_start_docstrings_to_model_forward`.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
824
+
825
+
826
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: BaseConfig
    """

    def __init__(self, config: BaseConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Rotary embeddings are computed once per forward and shared by every decoder layer.
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack; see LLAMA_INPUTS_DOCSTRING (injected by the decorator) for argument details."""
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if cache_position is None:
            # Positions of the new tokens start right after what is already cached.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache is the last element; its index depends on whether attentions were returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build (or skip building) the 4D causal mask appropriate for the configured attention backend."""
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D padding mask as-is; `None` means "no padding anywhere".
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        # Key length: static caches are pre-allocated to a fixed size; otherwise derive it from
        # the provided mask or from cache length + new tokens.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start fully masked, then open up the causal (lower-triangular) region.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Mask out key positions beyond each query's cache position.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Merge the 2D padding mask into the causal mask: positions masked by both get min_dtype.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
1103
+
1104
+
1105
+ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
1106
+ _tied_weights_keys = ["lm_head.weight"]
1107
+
1108
+ def __init__(self, config):
1109
+ super().__init__(config)
1110
+ self.model = LlamaModel(config)
1111
+ self.vocab_size = config.vocab_size
1112
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1113
+
1114
+ # Initialize weights and apply final processing
1115
+ self.post_init()
1116
+
1117
+ def get_input_embeddings(self):
1118
+ return self.model.embed_tokens
1119
+
1120
+ def set_input_embeddings(self, value):
1121
+ self.model.embed_tokens = value
1122
+
1123
+ def get_output_embeddings(self):
1124
+ return self.lm_head
1125
+
1126
+ def set_output_embeddings(self, new_embeddings):
1127
+ self.lm_head = new_embeddings
1128
+
1129
+ def set_decoder(self, decoder):
1130
+ self.model = decoder
1131
+
1132
+ def get_decoder(self):
1133
+ return self.model
1134
+
1135
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def _native_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Text-only causal-LM forward: run the decoder, project the hidden states
        through ``lm_head``, and optionally compute the LM loss.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to config defaults when the output flags are not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel path: shard the LM head weight along the vocab dim
            # and concatenate the per-shard logits back together.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            # NOTE: with num_logits_to_keep == 0 the slice [:, -0:, :] covers the full sequence.
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

        if not return_dict:
            # Legacy tuple output: (loss?, logits, *decoder extras).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1227
+
1228
class AdapterMLP(nn.Module):
    """Gated MLP adapter mapping encoded image features into the LM hidden size.

    Mirrors the Llama gate/up/down MLP shape, except the input dimension is
    ``config.encoded_image_dimention`` rather than ``config.hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(config.encoded_image_dimention, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.encoded_image_dimention, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        tp = self.config.pretraining_tp
        if tp <= 1:
            # Standard gated path: down(act(gate(x)) * up(x)).
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # Tensor-parallel path: split each projection weight into `tp` shards,
        # apply the shards independently, and sum the partial down projections.
        shard = self.intermediate_size // tp
        gate_shards = self.gate_proj.weight.split(shard, dim=0)
        up_shards = self.up_proj.weight.split(shard, dim=0)
        down_shards = self.down_proj.weight.split(shard, dim=1)

        gate_out = torch.cat([F.linear(x, w) for w in gate_shards], dim=-1)
        up_out = torch.cat([F.linear(x, w) for w in up_shards], dim=-1)

        gated = (self.act_fn(gate_out) * up_out).split(shard, dim=2)
        return sum(F.linear(gated[i], down_shards[i]) for i in range(tp))
visual_modeling.py ADDED
@@ -0,0 +1,1128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CLIP model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ import torch.nn.functional as F
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
28
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
31
+ from transformers.utils import (
32
+ ModelOutput,
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ is_flash_attn_2_available,
37
+ is_flash_attn_greater_or_equal_2_10,
38
+ logging,
39
+ replace_return_docstrings,
40
+ torch_int,
41
+ )
42
+ try:
43
+ from configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
44
+ except ImportError:
45
+ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ # General docstring
55
+ _CONFIG_FOR_DOC = "CLIPConfig"
56
+ _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
60
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
61
+
62
+
63
+ # contrastive loss function, adapted from
64
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
65
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over a similarity matrix whose matched pairs lie on the diagonal."""
    targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, targets)
67
+
68
+
69
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Symmetric CLIP loss: average of the text->image and image->text contrastive losses."""
    text_to_image = contrastive_loss(similarity)
    image_to_text = contrastive_loss(similarity.t())
    return 0.5 * (text_to_image + image_to_text)
73
+
74
+
75
+ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
76
+ """
77
+ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
78
+ model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
79
+ """
80
+ square_tensor = torch.pow(tensor, 2)
81
+ sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
82
+ normed_tensor = torch.pow(sum_tensor, 0.5)
83
+ return normed_tensor
84
+
85
+
86
@dataclass
class CLIPVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None; only the ones requested/produced are populated.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
113
+
114
+
115
@dataclass
class CLIPTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None; only the ones requested/produced are populated.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
142
+
143
+
144
@dataclass
class CLIPOutput(ModelOutput):
    """
    Combined output of the full CLIP model (text tower + vision tower + similarity logits).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        # Recurse into the nested model outputs so the tuple form is fully flattened.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
179
+
180
+
181
class CLIPVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with learned absolute positions for the CLIP vision tower."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned [CLS]-style token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patch projection (stride == kernel == patch_size).
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        # Class-token position is kept as-is; only patch positions are interpolated.
        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Reshape the flat patch positions back onto their original square grid,
        # then bicubically resize to the new grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
            )
        # Cast the input to the conv weight dtype (e.g. when running in fp16).
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
262
+
263
+
264
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the CLIP text tower."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # Derive the sequence length from whichever input form was provided.
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # Default to positions [0, seq_length) when not supplied by the caller.
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        return inputs_embeds + self.position_embedding(position_ids)
295
+
296
+
297
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5  # 1/sqrt(head_dim), applied to queries
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, embed) -> (bsz, num_heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj (scale is folded into the queries up front)
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Collapse batch and heads into one dim so attention runs as a batched matmul.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit akward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz*heads, tgt, head_dim) -> (bsz, tgt, embed_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
399
+
400
+
401
class CLIPFlashAttention2(CLIPAttention):
    """
    CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Flash attention cannot return per-head attention weights, so this flag
        # is forced off and the second return value is always None.
        output_attentions = False

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Causality is signalled by the presence of causal_attention_mask rather
        # than by materializing the mask itself.
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=causal_attention_mask is not None,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
486
+
487
+
488
class CLIPSdpaAttention(CLIPAttention):
    """
    SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from CLIPAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if output_attentions:
            # SDPA cannot return attention weights; fall back to the eager parent implementation.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
                'be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )

        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask

        bsz, tgt_len, embed_dim = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, seq, embed) -> (bsz, num_heads, seq, head_dim) for SDPA.
        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            scale=self.scale,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        # Attention weights are never returned on the SDPA path.
        return attn_output, None
559
+
560
+
561
# Maps the `config._attn_implementation` string to the concrete attention class
# used by `CLIPEncoderLayer`. All three implementations expose the same
# `forward(hidden_states, attention_mask, causal_attention_mask, output_attentions)`
# interface and differ only in the attention kernel they dispatch to.
CLIP_ATTENTION_CLASSES = {
    "eager": CLIPAttention,
    "sdpa": CLIPSdpaAttention,
    "flash_attention_2": CLIPFlashAttention2,
}
566
+
567
+
568
class CLIPMLP(nn.Module):
    """Position-wise feed-forward block used inside each CLIP encoder layer.

    Projects `hidden_size -> intermediate_size`, applies the activation chosen
    by `config.hidden_act`, then projects back to `hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand, apply the configured non-linearity, then project back down.
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
581
+
582
+
583
class CLIPEncoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention then MLP, each wrapped
    in a residual connection."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`): additive mask of size
                `(batch, 1, tgt_len, src_len)` with large negative values at padded positions.
            causal_attention_mask (`torch.FloatTensor`): additive causal mask of the same shape.
            output_attentions (`bool`, *optional*):
                When `True`, the attention weights are appended to the returned tuple.

        Returns:
            Tuple of `(hidden_states,)` or `(hidden_states, attn_weights)`.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        residual = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + attn_out

        # --- MLP sub-block (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = residual + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
631
+
632
+
633
class CLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Every std below is multiplied by `initializer_factor`, so the whole
        # init scheme can be scaled with a single config knob.
        factor = self.config.initializer_factor
        if isinstance(module, CLIPTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, CLIPAttention):
            factor = self.config.initializer_factor
            # In/out projection stds shrink with depth (the (2 * num_layers)**-0.5
            # term) so residual-stream variance stays bounded in deep stacks.
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPMLP):
            factor = self.config.initializer_factor
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, CLIPModel):
            # The two projection heads are initialized relative to their input width.
            nn.init.normal_(
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )

        # LayerNorm and Linear biases are reset unconditionally, regardless of
        # which branch above (if any) handled `module`.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
685
+
686
+
687
# The raw-string templates below are injected into model classes/methods via the
# `add_start_docstrings*` decorators; they are user-facing documentation text.
CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation shared by the text-only entry points.
CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation shared by the vision-only entry points.
CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the combined text + vision entry points.
CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
789
+
790
+
791
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded inputs to run through the layer stack.
            attention_mask (`torch.Tensor`, *optional*):
                Additive padding mask forwarded unchanged to every layer.
            causal_attention_mask (`torch.Tensor`, *optional*):
                Additive causal mask forwarded unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether to collect per-layer attention weights.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states before every layer and after the last one.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Resolve unspecified flags from the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Route through torch gradient checkpointing when enabled during
            # training; otherwise call the layer directly.
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns = all_self_attns + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns
            )
        return tuple(v for v in (hidden_states, all_hidden_states, all_self_attns) if v is not None)
886
+
887
class CLIPVisionTransformer(nn.Module):
    """Vision backbone: patch/position embeddings, a pre-LN encoder stack, and
    CLS-token pooling with a final LayerNorm."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: the misspelled attribute name "pre_layrnorm" is intentional —
        # renaming it would break pretrained checkpoint loading.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Resolve unspecified flags from the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs = self.encoder(
            inputs_embeds=self.pre_layrnorm(embedded),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking the CLS token (position 0), then normalize it.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
944
+
945
+
946
@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        # The whole model is just the vision transformer; this wrapper exists
        # so it can be loaded/saved as a standalone pretrained model.
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The patch-embedding convolution is the input "embedding" layer.
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Pure delegation: all work happens in the wrapped vision transformer.
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
1005
+
1006
+
1007
+ @add_start_docstrings(CLIP_START_DOCSTRING)
1008
+ class CLIPModel(CLIPPreTrainedModel):
1009
+ config_class = CLIPConfig
1010
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
1011
+
1012
+ def __init__(self, config: CLIPConfig):
1013
+ super().__init__(config)
1014
+
1015
+ if not isinstance(config.text_config, CLIPTextConfig):
1016
+ raise TypeError(
1017
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
1018
+ f" {type(config.text_config)}."
1019
+ )
1020
+
1021
+ if not isinstance(config.vision_config, CLIPVisionConfig):
1022
+ raise TypeError(
1023
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1024
+ f" {type(config.vision_config)}."
1025
+ )
1026
+
1027
+ text_config = config.text_config
1028
+ vision_config = config.vision_config
1029
+
1030
+ self.projection_dim = config.projection_dim
1031
+ self.text_embed_dim = text_config.hidden_size
1032
+ self.vision_embed_dim = vision_config.hidden_size
1033
+
1034
+ vision_model = CLIPVisionModel._from_config(vision_config)
1035
+ self.vision_model = vision_model.vision_model
1036
+
1037
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1038
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
1039
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1040
+
1041
+ # Initialize weights and apply final processing
1042
+ self.post_init()
1043
+ self.reference_embedding = None
1044
+ self.cossim = nn.CosineSimilarity(dim=-1)
1045
+
1046
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The last hidden state of the vision transformer (CLS token at index 0 followed by the patch
                tokens). Note: unlike upstream CLIP, this does NOT apply `visual_projection` or pooling —
                callers such as `encode_image` consume the full token sequence.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Return the raw last hidden state (element 0 in both tuple and
        # ModelOutput form), not the pooled/projected embedding.
        output = vision_outputs[0]
        return output
1094
+
1095
    @torch.no_grad()
    def set_reference_embedding(self, x):
        # Cache the CLS-token (index 0) hidden state of `x` as the reference
        # later compared against in `encode_image_w_similarity`. Runs without
        # gradient tracking.
        self.reference_embedding = self.get_image_features(x)[:, 0, :]
1098
+
1099
+ def encode_image(self, x, n_patches=64):
1100
+ image_embeds = self.get_image_features(x)
1101
+
1102
+ image_embeds = image_embeds[:, 1:, :]
1103
+ b, n, c = image_embeds.shape
1104
+ sqrt_n = int(n**0.5)
1105
+ image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)
1106
+ stride = int(sqrt_n // (n_patches ** 0.5))
1107
+ image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
1108
+ image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
1109
+
1110
+ return image_embeds
1111
+
1112
+ def encode_image_w_similarity(self, x, n_patches=64):
1113
+ image_embeds = self.get_image_features(x)
1114
+
1115
+ # Calculate cosine similarity with reference embedding before processing
1116
+ original_embeds = image_embeds[:, 0, :]
1117
+ cos = nn.CosineSimilarity(dim=-1)
1118
+ similarity = cos(original_embeds, self.reference_embedding)
1119
+
1120
+ image_embeds = image_embeds[:, 1:, :]
1121
+ b, n, c = image_embeds.shape
1122
+ sqrt_n = int(n**0.5)
1123
+ image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)
1124
+ stride = int(sqrt_n // (n_patches ** 0.5))
1125
+ image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
1126
+ image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
1127
+
1128
+ return image_embeds, similarity