from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from .configuration_qqmm import QQMMConfig


def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with a static cache, the mask should be as long as the static cache,
            to account for the 0 padding, i.e. the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        causal_mask = attention_mask
    else:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
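# Illustrative use of the helper above (a minimal sketch): a batch of 2 sequences of length 4,
# with the last position of the second sample padded out.
#   attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
#   mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
#       attention_mask, sequence_length=4, target_length=4, dtype=torch.bfloat16,
#       device=attention_mask.device, min_dtype=torch.finfo(torch.bfloat16).min,
#       cache_position=torch.arange(4), batch_size=2,
#   )
# mask_4d has shape (2, 1, 4, 4); disallowed positions hold the dtype minimum, allowed positions hold 0.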


def padcat_sequences(sequences, value=0, pad_side='right'):
    """Pad a list of `(batch, length, ...)` tensors along dim 1 to a common length and concatenate them along dim 0."""
    if all(s is None for s in sequences):
        return None
    max_l = max(s.size(1) for s in sequences)
    sequences_ = []
    for seq in sequences:
        if seq.size(1) != max_l:
            pad_len = max_l - seq.size(1)
            pad_len = (0, pad_len) if pad_side == 'right' else (pad_len, 0)
            seq = F.pad(seq, pad_len, value=value)
        sequences_.append(seq)

    sequences = torch.cat(sequences_)

    return sequences
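# Illustrative use of padcat_sequences (a minimal sketch): concatenate two batches with different lengths.
#   a = torch.ones(2, 3, dtype=torch.long)     # (batch=2, length=3)
#   b = torch.zeros(1, 5, dtype=torch.long)    # (batch=1, length=5)
#   out = padcat_sequences([a, b], value=0, pad_side='right')   # shape (3, 5); `a` is right-padded with 0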


class QQMMPreTrainedModel(PreTrainedModel):
    config_class = QQMMConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True


class QQMMForCausalLM(QQMMPreTrainedModel):

    def __init__(self,
                 config,
                 qwen2_5_vl_model=None):

        super().__init__(config)
        if qwen2_5_vl_model is None:
            # Propagate the requested attention implementation to the wrapped model's config so that
            # the inner Qwen2.5-VL model is built with the same attention backend.
            if config._attn_implementation_internal is not None:
                config.model_config._attn_implementation = config._attn_implementation_internal
            model = Qwen2_5_VLForConditionalGeneration(config.model_config)
        else:
            model = qwen2_5_vl_model
        self.qwen2_5_vl_model = model
        self.post_init()
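    # Illustrative construction (a minimal sketch): wrap an already-loaded Qwen2.5-VL model, where
    # `qqmm_config` is an assumed QQMMConfig instance matching that model.
    #   base = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    #   qqmm = QQMMForCausalLM(qqmm_config, qwen2_5_vl_model=base)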
    def make_diy_mask(self, input_ids, attention_mask, embed_token_id, im_start_id, im_end_id):
        """Build an inverted 4D attention mask in which every position after the embed token is blocked
        from attending to the image span (the second `im_start` ... `im_end` occurrence)."""
        if len(attention_mask.shape) == 2:
            # Expand the 2D padding mask into an inverted 4D causal mask first.
            sequence_length = attention_mask.shape[1]
            target_length = attention_mask.shape[1]
            dtype = torch.bfloat16
            device = input_ids.device
            min_dtype = torch.finfo(dtype).min
            cache_position = torch.arange(0, sequence_length, device=attention_mask.device)
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=target_length,
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=attention_mask.shape[0],
            )
        else:
            dtype = torch.bfloat16
            min_dtype = torch.finfo(dtype).min
        # Index of the first embed token per sample (sequence length if no embed token is present).
        mask = input_ids == embed_token_id
        embed_index = torch.argmax(mask.float(), dim=1)
        embed_index[embed_index == 0] = input_ids.shape[1]
        embed_index = embed_index.view(-1)
        # Index of the second image-start token (the first occurrence is scattered away).
        mask = input_ids == im_start_id
        im_start_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_start_index_tmp, value=False)
        im_start_index = torch.argmax(mask.float(), dim=1).view(-1)
        # Index of the second image-end token.
        mask = input_ids == im_end_id
        im_end_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_end_index_tmp, value=False)
        im_end_index = torch.argmax(mask.float(), dim=1).view(-1)
        # Block attention from every position after the embed token to the image span.
        for b in range(attention_mask.shape[0]):
            attention_mask[b, 0, embed_index[b] + 1:, im_start_index[b]:im_end_index[b] + 2] = min_dtype
        return attention_mask
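    # Illustrative call of `make_diy_mask` (a minimal sketch; the token ids below are placeholders for
    # the tokenizer's actual special-token ids, not values defined in this module):
    #   diy_mask = model.make_diy_mask(
    #       input_ids, attention_mask,
    #       embed_token_id=EMBED_TOKEN_ID,   # hypothetical id of the embed marker token
    #       im_start_id=IM_START_ID,         # hypothetical id of the image-start token
    #       im_end_id=IM_END_ID,             # hypothetical id of the image-end token
    #   )
    #   outputs = model(input_ids=input_ids, attention_mask=diy_mask)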

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        embed_token_id: Optional[int] = None,
        return_emb: Optional[bool] = False,
        cal_loss: Optional[bool] = False,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:

        if pixel_values is not None and pixel_values.shape[0] == 0:
            pixel_values = None
            image_grid_thw = None
        output_attentions = output_attentions if output_attentions is not None else self.qwen2_5_vl_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.qwen2_5_vl_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.qwen2_5_vl_model.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.qwen2_5_vl_model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.qwen2_5_vl_model.visual.dtype)
                image_embeds = self.qwen2_5_vl_model.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.qwen2_5_vl_model.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                mask = input_ids == self.qwen2_5_vl_model.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.qwen2_5_vl_model.visual.dtype)
                video_embeds = self.qwen2_5_vl_model.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.qwen2_5_vl_model.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.qwen2_5_vl_model.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # RoPE deltas are cached on the wrapped Qwen2.5-VL model so that cached decoding steps reuse them.
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # Calculate the multimodal RoPE index once, in the pre-fill stage only.
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.qwen2_5_vl_model.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.qwen2_5_vl_model.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.qwen2_5_vl_model.rope_deltas = rope_deltas
            # Otherwise reuse the pre-computed rope deltas to get the correct position ids.
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.qwen2_5_vl_model.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is the int 0
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.qwen2_5_vl_model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        if labels is not None:
            # Ignore embed-marker tokens in the loss.
            mask = labels == embed_token_id
            labels[mask] = -100

        logits = self.qwen2_5_vl_model.lm_head(hidden_states)

        if return_emb:
            assert labels is not None, 'labels must be provided to obtain embed'
            # Take the hidden state of the token right before the first embed token
            # (or of the last token when no embed token is present).
            hidden_index = torch.argmax(mask.float(), dim=1)
            hidden_index[hidden_index == 0] = labels.shape[1]
            hidden_states = torch.gather(
                hidden_states,
                dim=1,
                index=(hidden_index - 1).view(hidden_index.shape[0], 1, 1).repeat(1, 1, hidden_states.shape[-1]),
            )
            emb = hidden_states[:, 0, :].contiguous()
        else:
            emb = None

        loss = None
        if labels is not None and cal_loss:
            # Upcast to float for a more numerically stable loss.
            logits = logits.float()
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            if (shift_labels < 0).all().item():
                # Every target is ignored (-100), so the loss is defined as 0.
                loss = 0.0
            else:
                # Flatten the tokens and compute the token-level cross entropy.
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.qwen2_5_vl_model.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism.
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        else:
            outputs = Qwen2_5_VLCausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.qwen2_5_vl_model.rope_deltas,
            )
            if emb is not None:
                outputs['emb'] = emb

            return outputs
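    # Illustrative `forward` call for embedding extraction (a minimal sketch; `EMBED_TOKEN_ID` is a
    # placeholder for the id of the embed marker token appended to the prompt):
    #   out = model(
    #       input_ids=input_ids,
    #       attention_mask=attention_mask,
    #       labels=input_ids.clone(),
    #       embed_token_id=EMBED_TOKEN_ID,
    #       return_emb=True,
    #   )
    #   sentence_embedding = out['emb']   # shape (batch_size, hidden_size)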

    @torch.no_grad()
    def generate(self, input_ids, *args, **kwargs) -> Union[GenerateOutput, torch.LongTensor]:
        return self.qwen2_5_vl_model.generate(input_ids, *args, **kwargs)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.qwen2_5_vl_model.model.enable_input_require_grads()

    def get_input_embeddings(self):
        return self.qwen2_5_vl_model.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qwen2_5_vl_model.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.qwen2_5_vl_model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.qwen2_5_vl_model.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.qwen2_5_vl_model.model = decoder

    def get_decoder(self):
        return self.qwen2_5_vl_model.model
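# Illustrative end-to-end usage (a minimal sketch; the checkpoint path is a placeholder and `inputs`
# is assumed to have been prepared by the matching Qwen2.5-VL processor, i.e. it contains input_ids,
# attention_mask and, for image prompts, pixel_values and image_grid_thw):
#   model = QQMMForCausalLM.from_pretrained("path/to/qqmm-checkpoint", torch_dtype=torch.bfloat16)
#   generated = model.generate(**inputs, max_new_tokens=128)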