| import math |
| from typing import Optional, Tuple |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel |
| from transformers.modeling_outputs import Seq2SeqLMOutput |
| from .config import LiteVit5Config |
|
|
|
|
class LiteVit5ForConditionalGeneration(PreTrainedModel):
    """
    LiteVit5 model for vision-to-text generation tasks.

    Combines a frozen SigLIP vision encoder with a CodeT5 decoder and LM head.
    Each sample is expected to arrive as 5 views of the image (four quadrant
    crops plus the full image), i.e. ``pixel_values`` is [B*5, 3, 512, 512].
    """

    config_class = LiteVit5Config
    base_model_prefix = "litevit5"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Frozen fp16 vision tower; kept in eval mode with gradients disabled.
        self.vision_model = SiglipVisionModel.from_pretrained(
            "google/siglip2-base-patch16-512",
            dtype=torch.float16
        )
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Borrow CodeT5's decoder, LM head and right-shift helper; the CodeT5
        # text encoder is discarded (vision features replace it).
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5-base",
            dtype=torch.float16
        )
        self.seq2seq_decoder = seq2seq_model.decoder
        self.seq2seq_lm_head = seq2seq_model.lm_head
        self._shift_right = seq2seq_model._shift_right

        # Fusion head: a 2x2/stride-2 conv downsamples the stitched 64x64
        # quadrant grid to 32x32 (= 1024 tokens), which are concatenated
        # channel-wise with the 1024 full-image tokens and fused back to 768.
        self.downsampler = nn.Conv2d(768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16)
        self.fuse = nn.Linear(768 * 2, 768).half()
        # Learned positional embedding over the 1024 fused tokens.
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True)
        self.linear_projection = nn.Linear(768, 768).half()

        self.post_init()

    def get_encoder(self):
        """Return the vision encoder for the model."""
        return self.vision_model

    def get_decoder(self):
        """Return the seq2seq decoder."""
        return self.seq2seq_decoder

    def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode image inputs into vision features for the decoder.

        Args:
            pixel_values: Input images of shape [B*5, 3, 512, 512]. The five
                views per sample are assumed ordered [top-left, top-right,
                bottom-left, bottom-right, full image] — TODO confirm against
                the data pipeline.

        Returns:
            Encoded vision features of shape [B, 1024, 768].
        """
        # The vision tower was loaded in fp16; match its dtype.
        pixel_values = pixel_values.half()

        batch_size = pixel_values.size(0) // 5
        scale = 5
        num_patches = 32  # 512px / patch16 -> 32x32 patch grid per view

        # Encoder is frozen, so never build a graph through it.
        with torch.no_grad():
            vision_model_outputs = self.vision_model(pixel_values=pixel_values)
            vision_hidden_states = vision_model_outputs.last_hidden_state

        # [B*5, 1024, 768] -> [B, 5, 1024, 768]
        vision_hidden_states = vision_hidden_states.view(batch_size, scale, *vision_hidden_states.shape[1:])

        # The four quadrant crops, reshaped back onto their 32x32 patch grids.
        quarters = vision_hidden_states[:, :4]
        quarters = quarters.view(batch_size, 4, num_patches, num_patches, -1)

        # Stitch quadrants into one 64x64 grid, then channels-first for conv.
        upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)
        lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)
        pooled_image = torch.cat([upper, lower], dim=1)
        pooled_image = pooled_image.permute(0, 3, 1, 2)

        # Downsample 64x64 -> 32x32 and flatten to a 1024-token sequence.
        pooled32 = self.downsampler(pooled_image)
        pooled_tok = pooled32.flatten(2).transpose(1, 2)

        # Tokens from the fifth (full-image) view: [B, 1024, 768].
        full_image = vision_hidden_states[:, 4]

        # Fuse the high-resolution (stitched) and global (full-image) tokens.
        concat = torch.cat([pooled_tok, full_image], dim=-1)
        fused = self.fuse(concat)

        # Add learned positions and project into the decoder's feature space.
        fused = fused + self.pos_embedding
        vision_hidden_states = self.linear_projection(fused)

        return vision_hidden_states

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Seq2SeqLMOutput:
        """
        Forward pass for the model.

        Args:
            pixel_values: Vision input images, [B*5, 3, 512, 512]
            input_ids: Target token IDs; shifted right to form decoder inputs
            labels: Target token IDs for the LM loss (-100 = ignore)
            decoder_input_ids: Explicit decoder input IDs (used in generation)
            past_key_values: Cached key/values for incremental decoding
            attention_mask: Attention mask for decoder inputs

        Returns:
            Seq2SeqLMOutput with loss, logits, and generation-related outputs
        """
        encoder_hidden_states = self._encode_vision(pixel_values)

        if decoder_input_ids is None:
            if input_ids is not None:
                # Teacher forcing: shift targets right (T5 convention).
                decoder_input_ids = self._shift_right(input_ids)
            elif labels is not None:
                # Fix: a training call passing only `labels` previously fell
                # through to a length-1 start token and crashed in the loss.
                decoder_input_ids = self._shift_right(labels)
            else:
                # No text at all: begin decoding from the start token.
                start_token_id = self._get_decoder_start_token_id()
                decoder_input_ids = torch.full(
                    (pixel_values.shape[0] // 5, 1),
                    start_token_id,
                    dtype=torch.long,
                    device=pixel_values.device
                )

        decoder_outputs = self.seq2seq_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        )

        sequence_output = decoder_outputs[0]
        lm_logits = self.seq2seq_lm_head(sequence_output)

        loss = None
        if labels is not None:
            # -100 marks padded label positions to be ignored by the loss.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Prepare inputs for generation (HuggingFace generate() hook)."""
        # With a cache, only the newest token must be fed to the decoder.
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        # NOTE(review): forward() takes no `encoder_outputs` argument and
        # always re-encodes `pixel_values`, which this dict does not carry —
        # the stock HF generation loop likely cannot drive this model; the
        # custom generate() below is the working path. TODO confirm.
        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ):
        """Encode pixel values to get encoder outputs (HF generate() hook)."""
        # NOTE(review): the outputs are stored as a bare 1-tuple rather than
        # a ModelOutput; downstream HF code indexing `.last_hidden_state`
        # would fail — only relevant if super().generate() is ever used.
        if "encoder_outputs" not in model_kwargs:
            encoder_outputs = self._encode_vision(inputs_tensor)
            model_kwargs["encoder_outputs"] = (encoder_outputs,)

        return model_kwargs

    def generate(
        self,
        pixel_values: torch.Tensor,
        max_length: int = 1024,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs
    ) -> torch.LongTensor:
        """
        Generate text from image inputs (greedy or multinomial sampling).

        Args:
            pixel_values: Input images [B*5, 3, 512, 512]
            max_length: Maximum number of generated tokens
            num_beams: Number of beams for beam search (1 = greedy) TODO: Not implemented
            temperature: Sampling temperature (logits divided when != 1.0)
            do_sample: Whether to use sampling instead of greedy argmax

        Returns:
            Generated token sequences, including the leading start token.
        """
        encoder_hidden_states = self._encode_vision(pixel_values)
        batch_size = pixel_values.shape[0] // 5

        # Every sequence starts from the decoder start token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self._get_decoder_start_token_id(),
            dtype=torch.long,
            device=pixel_values.device
        )

        past_key_values = None
        # Fix: track which sequences have *ever* emitted EOS. The original
        # stopped only when all sequences produced EOS at the same step, so
        # decoding almost always ran to max_length.
        finished = torch.zeros(batch_size, 1, dtype=torch.bool, device=pixel_values.device)

        for _ in range(max_length):
            with torch.no_grad():
                # With a cache, feed only the most recent token.
                decoder_outputs = self.seq2seq_decoder(
                    input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
                    encoder_hidden_states=encoder_hidden_states,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

            past_key_values = decoder_outputs.past_key_values

            # Only the last position's logits matter for the next token.
            hidden_states = decoder_outputs[0][:, -1:, :]
            lm_logits = self.seq2seq_lm_head(hidden_states)

            if temperature != 1.0:
                lm_logits = lm_logits / temperature

            if do_sample:
                probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)

            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

            # Stop once every sequence has produced EOS at least once.
            finished = finished | (next_token == self.config.eos_token_id)
            if finished.all():
                break

        return decoder_input_ids

    def _get_decoder_start_token_id(self) -> int:
        """
        Get the decoder start token ID, falling back to the pad token.

        Fix: use an explicit None check — T5-family configs use start id 0,
        which the previous `x or y` expression treated as missing.
        """
        if self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        return self.config.pad_token_id
|
|