| """PyTorch TraVisionLM"""
|
| import torch
|
| from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM
|
| from transformers.utils import logging, add_start_docstrings, ModelOutput
|
| from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
|
| from dataclasses import dataclass
|
| from typing import List, Optional, Tuple, Union
|
| from torch import nn
|
| from transformers.cache_utils import Cache
|
|
|
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
|
|
|
| from .configuration_travisionlm import TraVisionLMConfig
|
|
|
# Name of the configuration class, following the Hugging Face doc-template convention
# (not referenced elsewhere in this section of the file — presumably used by doc tooling).
_CONFIG_FOR_DOC = "TraVisionLMConfig"
|
|
|
@dataclass
class TraVisionCausalLMOutputWithPast(ModelOutput):
    """
    Base class for TraVision language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor`, scalar, *optional*, returned when `labels` is provided):
            Language modeling loss (mean next-token cross-entropy).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
            Tuple with one entry per decoder layer; each entry holds the cached key/value tensors of that layer
            (for GPT-2 style blocks: 2 tensors of shape `(batch_size, num_heads, sequence_length,
            embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the input of the first layer + one for the output of each layer,
            the last one after the final layer norm) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Image features produced by the vision encoder and multimodal projector. Not populated by
            `TraVisionForCausalLM.forward`; the field is kept for API compatibility.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
class TraVisionMultiModalProjector(nn.Module):
    """
    MLP projector mapping vision-tower features into the language model's hidden size.

    Layout: Linear(projection_dim -> 4 * projection_dim) -> GELU -> Linear(-> hidden_size) -> Dropout.
    """

    def __init__(self, config: TraVisionLMConfig, dropout=0.1):
        super().__init__()
        in_features = config.vision_config.projection_dim
        inner_features = 4 * in_features
        # The nn.Sequential layout is kept as-is so state-dict keys stay `net.0.*` / `net.2.*`.
        layers = [
            nn.Linear(in_features, inner_features, bias=True),
            nn.GELU(),
            nn.Linear(inner_features, config.hidden_size, bias=True),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, image_features):
        """Project `image_features` and cast the result back to the input dtype."""
        projected = self.net(image_features)
        return projected.to(image_features.dtype)
|
|
|
|
|
# Shared documentation prefix passed to `add_start_docstrings` on the model classes in this file.
TRAVISIONLM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`TraVisionLMConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
@add_start_docstrings(
    "The bare TraVision Model outputting raw hidden-states without any specific head on top.",
    TRAVISIONLM_START_DOCSTRING,
)
class TraVisionPreTrainedModel(PreTrainedModel):
    # Hugging Face integration metadata for TraVision models.
    config_class = TraVisionLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TraVisionMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    # NOTE: `_supports_sdpa` is the property defined below, which delegates to the language model.
    # A class-level `_supports_sdpa = True` here would be immediately rebound by that property,
    # so it has been removed as dead code.

    def _init_weights(self, module):
        """
        Initialise `module`'s parameters in place.

        The standard deviation is `config.initializer_range` when the top-level config defines it,
        otherwise `config.text_config.initializer_range`.

        - a `class_embedding` attribute (if present) is drawn from N(0, std);
        - `nn.Linear` / `nn.Conv2d` weights are drawn from N(0, std) and their biases zeroed;
        - `nn.Embedding` weights are drawn from N(0, std) and the `padding_idx` row (if any) zeroed.
        """
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.language_model._supports_sdpa
|
|
|
|
|
@add_start_docstrings(
    """The TraVisionLM model which consists of a vision backbone and a language model.""",
    TRAVISIONLM_START_DOCSTRING,
)
class TraVisionForCausalLM(TraVisionPreTrainedModel):
    def __init__(self, config: TraVisionLMConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        # -1 acts as a "no padding token" sentinel when the config does not define one.
        self.pad_token_id = -1 if config.pad_token_id is None else config.pad_token_id
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False

        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.vision_projector = TraVisionMultiModalProjector(config)

        language_model = AutoModelForCausalLM.from_config(
            config=config.text_config, attn_implementation=self._attn_implementation
        )

        # Re-expose the language model's tied-weight keys under this wrapper's `language_model.` prefix.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]

        self.language_model = language_model

        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the language model's output embedding (LM head) module."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the language model's output embedding (LM head) module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Delegate to the language model's `set_decoder`."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Delegate to the language model's `get_decoder`."""
        return self.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the language model."""
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        """
        Resize the language model's token embeddings and keep every cached vocabulary size in sync.

        Updates `config.text_config.vocab_size`, `config.vocab_size` and `self.vocab_size` to the new
        number of embeddings, and returns the resized embedding module.
        """
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
    ):
        """
        Scatter projected image features into the token sequence and build the matching 4D attention mask.

        Two significant modifications are made relative to the original PaliGemma code:

        1) The pad and eos tokens are set to be the same in TraVisionProcessor. Hence, only the features of
           padding positions (pad tokens that are NOT attended) are filtered out, using the attention mask.
        2) The features of both the prompts (called prefixes in PaliGemma) and the labels (called suffixes in
           PaliGemma) are added to the final embedding tensor, and a causal attention mask is applied to the
           tokens of both the prompts and the labels. All the image tokens are attended using a full-attention
           mask.

        NOTE: In the original PaliGemma implementation, only the suffix tokens are causally masked. See the
        PaliGemma paper (https://arxiv.org/pdf/2407.07726) for details.

        Args:
            image_features: projected image features; their rows are written, in order, into the positions where
                `input_ids == config.image_token_index` (via `masked_scatter`) after scaling by
                `1 / sqrt(config.hidden_size)`.
            inputs_embeds: text-side embeddings of shape `(batch, seq, embed_dim)` (the caller passes token +
                position embeddings).
            input_ids: `(batch, seq)` token ids, including the image placeholder tokens.
            attention_mask: `(batch, seq)` 0/1 padding mask. Required.
            labels: optional `(batch, seq)` labels.
            token_type_ids: optional `(batch, seq)`; key positions where it is 0 are un-masked for every query.
            cache_position: positions of the current tokens (used to size the causal mask).

        Returns:
            `(final_embedding, causal_mask, final_labels, position_ids)`. `final_labels` is None unless both
            `token_type_ids` and `labels` are given.
        """
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        min_dtype = torch.finfo(dtype).min
        pad_token_id = self.config.text_config.pad_token_id

        scaled_image_features = image_features / (self.config.hidden_size**0.5)

        image_mask = input_ids == self.config.image_token_index
        # Padding positions are pad tokens that are NOT attended. Since pad == eos, an attended pad token is a
        # genuine eos token and must be kept.
        pad_mask = (attention_mask == 0) & (input_ids == pad_token_id)
        # Positions kept as real tokens (== ~pad_mask). The comparisons must be parenthesised: `|` binds tighter
        # than `!=` in Python, so `attention_mask | input_ids != pad` would mean `(attention_mask | input_ids) != pad`.
        keep_mask = (attention_mask != 0) | (input_ids != pad_token_id)
        text_mask = ~image_mask & keep_mask

        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(device)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(device)

        # Text positions take their embeddings; image and padding positions start at zero.
        final_embedding = torch.zeros(batch_size, sequence_length, embed_dim, dtype=dtype, device=device)
        final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
        final_embedding = final_embedding.masked_scatter(
            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
        )
        # Zero padding positions last, so this also holds if an image position coincides with padding.
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        if attention_mask is not None:
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
        else:
            position_ids = None

        final_labels = None
        if token_type_ids is not None:
            # Additive causal mask over the current queries x all keys seen so far.
            target_length = cache_position[-1] + 1
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # contiguous copy so the slices below can be edited in place
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                # Full attention to every key with token_type_ids == 0 (per the notes above, the image tokens).
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
                # Re-mask padded keys (attention_mask == 0 where the causal mask was open).
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

            if labels is not None:
                final_labels = torch.full(
                    (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
                )
                # Padding positions get ignore_index; every kept position keeps its label.
                final_labels = torch.where(keep_mask, labels, final_labels)
        else:
            # Without token_type_ids: every attended position may attend to every other attended position.
            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
            causal_mask = torch.where(causal_mask == 0, min_dtype, 0).to(dtype)

        return final_embedding, causal_mask, final_labels, position_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TraVisionCausalLMOutputWithPast]:
        """
        Run the multimodal model through the language model's GPT-2 style blocks (`transformer.h`).

        When `pixel_values` are given and more than one token is processed, the vision tower and projector run and
        their features are merged into the image-token positions with a custom 4D attention mask (see
        `_merge_input_ids_with_image_features`). Otherwise the 2D `attention_mask` is expanded for the configured
        attention implementation.

        Returns a `TraVisionCausalLMOutputWithPast`, or a tuple `([loss,] logits, presents, hidden_states,
        attentions)` when `return_dict=False`. A loss is computed when `labels` are provided; on the image path
        this additionally requires `token_type_ids`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `pixel_values` must be
                merged but `input_ids` are missing.
        """
        # No KV cache is built when computing a loss.
        if labels is not None:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer = self.language_model.transformer
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Legacy tuple cache: one entry per block; None entries mean "nothing cached yet".
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(transformer.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # Keep the original 2D mask: the loss filters on it after `attention_mask` has been turned into a 4D mask.
        input_attention_mask = attention_mask

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Image features are merged only when more than one token is processed; single-token decoding steps skip
        # the vision tower.
        merge_image_features = pixel_values is not None and inputs_embeds.shape[1] != 1
        if merge_image_features and input_ids is None:
            raise ValueError("`input_ids` are required to locate the image tokens when `pixel_values` are given")

        if merge_image_features:
            # Must agree with the pad id used by `prepare_inputs_for_generation`; 0 was the previously
            # hard-coded value and is kept as the fallback when the text config defines no pad token.
            text_pad_id = self.config.text_config.pad_token_id
            if text_pad_id is None:
                text_pad_id = 0
            # Image tokens and non-final pad tokens get the placeholder value 1.
            position_ids_mask = torch.where(input_ids != self.config.image_token_index, position_ids, 1)
            position_ids_mask[:, :-1] = torch.where(
                input_ids[:, :-1] != text_pad_id, position_ids_mask[:, :-1], 1
            )
            # Subtract the number of placeholder entries. With the image tokens (and any left padding) at the start
            # of the sequence this makes text positions start at 0 — assumes position 1 is itself an image/pad
            # token (TODO confirm against the processor's layout).
            first_position_embed_locs = torch.sum(position_ids_mask == 1, dim=1)
            position_ids_mask.sub_(first_position_embed_locs[:, None])
            # Entries shifted below zero (image / pad positions) all use position embedding 1.
            position_emb_ids = torch.where(position_ids_mask >= 0, position_ids_mask, 1)
            position_embeds = transformer.wpe(position_emb_ids)
        else:
            # `view(-1, seq)` keeps a shared `(1, seq)` position row broadcastable over the batch.
            position_embeds = transformer.wpe(position_ids.view(-1, input_shape[-1]))

        hidden_states = inputs_embeds + position_embeds

        if merge_image_features:
            if pixel_values.dim() == 3:
                pixel_values = pixel_values.unsqueeze(dim=0)  # add a batch dimension for a single image

            image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
            image_features = self.vision_projector(image_outputs.last_hidden_state)

            if cache_position is None:
                cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            if labels is not None and token_type_ids is None:
                logger.warning_once(
                    "`labels` were passed without `token_type_ids`; no loss will be computed for this batch."
                )
            hidden_states, attention_mask, labels, _ = self._merge_input_ids_with_image_features(
                image_features, hidden_states, input_ids, attention_mask, labels, token_type_ids, cache_position
            )
        elif attention_mask is not None:
            # Text-only forward or single-token decoding step: expand the 2D padding mask for the configured
            # attention implementation.
            attention_mask = attention_mask.view(batch_size, -1)
            if self._attn_implementation == "flash_attention_2":
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self._attn_implementation == "sdpa" and not output_attentions:
                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask=attention_mask,
                    input_shape=(batch_size, input_shape[-1]),
                    inputs_embeds=inputs_embeds,
                    past_key_values_length=past_length,
                )
            else:
                # Additive (batch, 1, 1, keys) mask: 0 for attended keys, the dtype minimum for padded keys.
                attention_mask = attention_mask[:, None, None, :].to(dtype=self.dtype)
                attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

        hidden_states = transformer.drop(hidden_states)
        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for block, layer_past in zip(transformer.h, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = transformer.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.language_model.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Next-token prediction: the logits at position t are scored against the label at t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if input_attention_mask is not None:
                # Drop positions whose target token is padding according to the original 2D mask.
                shift_attention_mask = input_attention_mask[..., 1:]
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()

            loss_fct = nn.CrossEntropyLoss()
            # Reshape with the real logit width so it always matches the lm_head output size.
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits, presents, all_hidden_states, all_self_attentions)
            return (loss,) + output if loss is not None else output

        return TraVisionCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Build the keyword arguments for one generation step.

        On decoding steps (non-empty `past_key_values`, an `attention_mask`, and no explicit `position_ids`),
        position ids continue the text-only numbering used by `forward`: the newest token's position is the number
        of non-image, non-pad tokens before it. Once a cache exists, `input_ids` are trimmed to the tokens selected
        by `cache_position`.
        """
        if attention_mask is not None and position_ids is None and past_key_values:
            # Count text tokens over the full sequence; the final (newest) token always counts.
            is_text = input_ids != self.config.image_token_index
            is_text[:, :-1] &= input_ids[:, :-1] != self.config.text_config.pad_token_id
            last_index = is_text.sum(dim=1) - 1
            position_ids = last_index[:, None] + torch.arange(cache_position.shape[0], device=input_ids.device)

        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        # `inputs_embeds` are only usable on the very first step; afterwards ids are fed.
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "cache_position": cache_position,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs