| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
|
|
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from typing import Optional, Tuple, Union |
|
|
| from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower |
| from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig |
|
|
|
|
class MultiBackboneChannelConcatenationVisionModel(nn.Module):
    """A vision model wrapper that concatenates channels from multiple backbones.

    The wrapper builds a single ``MultiBackboneChannelConcatenationVisionTower``
    and exposes a small HuggingFace-style surface on top of it: ``forward``
    returns a ``BaseModelOutputWithPooling``, and ``config``, ``dtype``,
    ``device``, ``hidden_size``, ``num_patches`` and ``dummy_feature`` are
    delegated directly to the tower.

    Args:
        config (MultiBackboneChannelConcatenationVisionModelConfig): Model
            configuration. ``vision_tower``, ``grid_size``,
            ``convnext_img_size`` and ``normalize_type`` are read from it, and
            the whole object is also passed to the tower as ``args``.
        raw_config: Forwarded unchanged to the vision tower as ``raw_config``.

    Attributes:
        vision_model (MultiBackboneChannelConcatenationVisionTower): The vision
            tower that performs the channel concatenation.

    Notes:
        **This class inherits from ``torch.nn.Module`` only, not from
        ``transformers.PreTrainedModel``.** The ``config_class`` and
        ``main_input_name`` class attributes mirror the names used by
        ``PreTrainedModel`` but carry no inherited behavior here.
    """

    config_class = MultiBackboneChannelConcatenationVisionModelConfig
    main_input_name = "pixel_values"

    def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
        super().__init__()

        # Note: ``config`` itself is not stored on ``self``; the ``config``
        # property below reads the tower's config instead.
        self.vision_model = MultiBackboneChannelConcatenationVisionTower(
            vision_tower=config.vision_tower,
            args=config,
            grid_size=config.grid_size,
            convnext_img_size=config.convnext_img_size,
            normalize_type=config.normalize_type,
            raw_config=raw_config,
        )

    def get_input_embeddings(self):
        """Return the input embeddings of the first backbone in the tower."""
        return self.vision_model.vision_towers[0].get_input_embeddings()

    def forward(
        self,
        pixel_values,
        return_dict: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode ``pixel_values`` with the multi-backbone vision tower.

        Args:
            pixel_values: Input images, passed unchanged to ``self.vision_model``.
            return_dict: Must resolve to a truthy value; ``None`` falls back to
                ``self.config.use_return_dict``. Tuple outputs are not supported.
            output_hidden_states: Must be falsy (``False`` or ``None``);
                per-layer hidden states are not supported.

        Returns:
            BaseModelOutputWithPooling: ``last_hidden_state`` holds the tower
            output; ``pooler_output``, ``hidden_states`` and ``attentions`` are
            always ``None``.

        Raises:
            ValueError: If a tuple output or hidden states are requested.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Explicit checks rather than ``assert``: asserts are removed under
        # ``python -O``, which would silently ignore the unsupported options.
        if not return_dict:
            raise ValueError("We only support return_dict")
        if output_hidden_states:
            raise ValueError("We do not support output_hidden_states")

        features = self.vision_model(pixel_values)

        return BaseModelOutputWithPooling(
            last_hidden_state=features,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def dummy_feature(self):
        """Delegate to ``self.vision_model.dummy_feature``."""
        return self.vision_model.dummy_feature

    @property
    def dtype(self):
        """Delegate to ``self.vision_model.dtype``."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Delegate to ``self.vision_model.device``."""
        return self.vision_model.device

    @property
    def config(self):
        """Delegate to ``self.vision_model.config``."""
        return self.vision_model.config

    @property
    def hidden_size(self):
        """Delegate to ``self.vision_model.hidden_size``."""
        return self.vision_model.hidden_size

    @property
    def num_patches(self):
        """Delegate to ``self.vision_model.num_patches``."""
        return self.vision_model.num_patches
|
|