| from dataclasses import dataclass |
| from typing import List, Optional, Tuple, Union |
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from transformers import PreTrainedModel |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| replace_return_docstrings, |
| ) |
| from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoConfig |
| from .configuration_wemm import WeMMConfig |
| from .vision_model import Idefics2VisionTransformer |
| from .connector import Idefics2Connector |
| from .image_processor import Idefics2ImageProcessor |
| from .modeling_downsampler import DownsamplerModel |
| from .modeling_projector import ProjectorModel |
| from .modeling_internlm2 import InternLM2ForCausalLM |
| from .tokenization_internlm2 import InternLM2Tokenizer |
| from peft import PeftModel |
| from peft import PeftConfig |
| import os |
| from PIL import Image |
| import numpy as np |
# Placeholder id written into ``input_ids`` wherever an image goes;
# prepare_inputs_labels_for_multimodal swaps it for visual embeddings.
IMAGE_TOKEN_INDEX: int = -200
# Prompt marker that mm_generate splits on to place the placeholder id.
DEFAULT_IMAGE_TOKEN: str = "<image>"
# Label value given to padding and visual-token positions in the labels
# built by prepare_inputs_labels_for_multimodal.
IGNORE_INDEX: int = -100
| from transformers import StoppingCriteria |
| from transformers import PreTrainedTokenizerFast, StoppingCriteriaList |
| import torch.nn.functional as F |
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded text ends with ``stop_word``.

    Only the first sequence of the batch (``input_ids[0]``) is decoded, and
    carriage returns / newlines are stripped from the decoded text before
    the suffix comparison. The result is a plain ``bool``.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Cached so each call only slices the tail it needs to compare.
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        decoded = self.tokenizer.decode(input_ids[0])
        for line_break in ('\r', '\n'):
            decoded = decoded.replace(line_break, '')
        # NOTE(review): an empty stop_word gives length 0, so the slice below
        # is the whole text and only matches when that text is empty.
        tail = decoded[-self.length:]
        return tail == self.stop_word
def get_stop_criteria(
        tokenizer,
        stop_words=None,
):
    """Build a StoppingCriteriaList with one StopWordStoppingCriteria per word.

    Args:
        tokenizer: forwarded to every StopWordStoppingCriteria, which uses it
            to decode the generated ids.
        stop_words: iterable of stop strings. ``None`` (the default) means no
            stop words. It replaces the former shared mutable ``[]`` default.

    Returns:
        A StoppingCriteriaList, empty when no stop words are given.
    """
    if stop_words is None:
        stop_words = []
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin/cos embeddings of the two grid coordinate planes.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half ``grid[1]``.
    With the grid built by get_2d_sincos_pos_embed
    (``np.meshgrid(grid_w, grid_h)``), ``grid[0]`` holds the column index.
    The first half therefore encodes the horizontal position, even though
    the original local name called it ``emb_h``.

    Args:
        embed_dim: total output channels; must be even.
        grid: sequence of two coordinate arrays.

    Returns:
        The two per-axis embeddings joined along the last axis.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, axis_coords)
        for axis_coords in (grid[0], grid[1])
    ]
    return np.concatenate(per_axis, axis=-1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions with fixed sine/cosine features.

    The first ``embed_dim // 2`` channels are ``sin(pos * omega_k)`` and the
    remaining ones ``cos(pos * omega_k)``, where
    ``omega_k = 1 / 10000 ** (k / (embed_dim / 2))``.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: array of positions of any shape. Leading size-1 axes are
            dropped while more than two axes remain (the 2-D helper passes
            a ``(1, H, W)`` slice). The last two axes are kept even when
            they have size 1.

    Returns:
        Floating-point array of shape ``pos.shape + (embed_dim,)`` after that
        leading-axis squeeze, e.g. ``(M, D)`` for 1-D input or ``(H, W, D)``
        for a ``(1, H, W)`` grid.
    """
    assert embed_dim % 2 == 0
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    pos = np.asarray(pos)
    # Only strip leading singleton axes. np.squeeze would also collapse a
    # size-1 height or width and break the (H, W, D) layout.
    while pos.ndim > 2 and pos.shape[0] == 1:
        pos = pos[0]
    # Outer product of positions and frequencies: shape (*pos.shape, D/2).
    out = pos[..., None] * omega

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    return np.concatenate([emb_sin, emb_cos], axis=-1)
| |
| |
| |
| |
| |
def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
    """Build a 2-D sine-cosine positional embedding for an h x w patch grid.

    Args:
        embed_dim: channels per position; split evenly between the two axes
            by get_2d_sincos_pos_embed_from_grid.
        grid_size_h: number of grid rows.
        grid_size_w: number of grid columns.
        cls_token: when True, the grid embedding is flattened to
            ``(grid_size_h * grid_size_w, embed_dim)`` and an all-zero row is
            prepended, giving ``(1 + grid_size_h * grid_size_w, embed_dim)``.
            Previously this path concatenated a 2-D row onto the 3-D grid
            embedding and always raised.

    Returns:
        Without ``cls_token``: the grid-shaped embedding produced by
        get_2d_sincos_pos_embed_from_grid (``(grid_size_h, grid_size_w,
        embed_dim)``). With ``cls_token``: the flattened form described above.
    """
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    # 'xy' indexing: both planes are (h, w); grid[0] holds the column index,
    # grid[1] the row index.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Flatten the grid first so the zero row has a matching rank.
        pos_embed = pos_embed.reshape(-1, embed_dim)
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def recover_navit_subimages_with_pos_emb(
        sub_image_hidden_states,
        attention_mask,
        num_sub_images,
        visual_embedding_group,
        pos_hidden_size,
        thumbnail_only=False):
    """Stitch per-sub-image patch features back into one padded grid per sample.

    The flat per-image features are regrouped into ``num_sub_images`` tiles
    per sample (``num_sub_images + 1`` when ``thumbnail_only is True``, in
    which case only the last tile of each sample is kept). The valid patch
    extent is read from the attention mask, the tiles are laid out as a
    ``_slice x _slice`` mosaic, and the result is padded to a common size. A
    matching 2-D sin/cos positional embedding is built for each sample.

    Args:
        sub_image_hidden_states: 3-D tensor unpacked as ``(N, L, D)``.
            Presumably ``L == H * W`` so it can be reshaped to
            ``(-1, num_sub_images, H, W, D)``. TODO confirm.
        attention_mask: 3-D tensor unpacked as ``(_, H, W)``; entries > 0
            mark valid patches.
        num_sub_images: tiles per sample; ``int(sqrt(...))`` tiles per side.
        visual_embedding_group: padded sizes are rounded up to a multiple of
            ``int(sqrt(visual_embedding_group))``.
        pos_hidden_size: channel count of the positional embedding.
        thumbnail_only: only the literal ``True`` (identity check) selects
            the last-image-per-sample path.

    Returns:
        Tuple ``(image_hidden_states, image_2d_pos, valid_image_token)``:
        stacked padded features ``(B, Hp, Wp, D)``, stacked padded positional
        embeddings ``(B, Hp, Wp, pos_hidden_size)``, and an int64 ``(B, 2)``
        tensor of the unpadded ``[full_h, full_w]`` per sample.
    """
    # Tiles per side of the square mosaic (e.g. 4 -> 2).
    _slice = int(np.sqrt(num_sub_images))
    N, L, D = sub_image_hidden_states.shape
    _, H, W = attention_mask.shape
    if thumbnail_only is True:
        # One extra image per sample; the slicing below keeps only the last
        # one (presumably the thumbnail).
        num_sub_images += 1
    # NOTE(review): without thumbnail_only, exactly num_sub_images images per
    # sample are assumed; if the processor also emits a thumbnail the
    # grouping would be off -- confirm against the image processor.
    sub_image_hidden_states = sub_image_hidden_states.reshape(-1, num_sub_images, H, W, D)
    attention_mask = attention_mask.reshape(-1, num_sub_images, H, W)
    if thumbnail_only is True:
        sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
        attention_mask = attention_mask[:, -1:, :, :]
        _slice = 1

    def _infer_ori_image_patch_shape(sub_image_attention_mask):
        # Valid extent = last row/column containing a positive mask entry,
        # plus one. torch.max raises if the mask has no positive entry.
        ind_h, ind_w = torch.where(sub_image_attention_mask > 0)
        return torch.max(ind_h) + 1, torch.max(ind_w) + 1

    def _pad_to_same(image_hidden):
        # Pad an (h, w, C) grid on the bottom/right up to (H*_slice, W*_slice),
        # then further up to a multiple of the downsample stride, by
        # replicating edge values.
        _dtype = image_hidden.dtype
        visual_downsample_stride = int(np.sqrt(visual_embedding_group))
        full_h, full_w, _ = image_hidden.shape
        target_h, target_w = H * _slice, W * _slice

        to_pad_h = (target_h - full_h) + (
                visual_downsample_stride - target_h % visual_downsample_stride) % visual_downsample_stride
        to_pad_w = (target_w - full_w) + (
                visual_downsample_stride - target_w % visual_downsample_stride) % visual_downsample_stride

        # F.pad with a 4-tuple pads the last two dims, so move channels first.
        image_hidden = image_hidden.permute(2, 0, 1).unsqueeze(0)
        pad_size = (0, to_pad_w, 0, to_pad_h)

        # Padding runs in float32 and the result is cast back, presumably
        # because replicate padding is not available for every low-precision
        # dtype -- TODO confirm.
        image_hidden = F.pad(image_hidden.to(torch.float32), pad_size, mode='replicate').squeeze(0).permute(1, 2, 0)
        return image_hidden.to(_dtype)

    image_hidden_states = list()
    valid_image_token = list()
    image_2d_pos = list()
    for batch_id in range(len(sub_image_hidden_states)):
        # The extent is taken from the sample's first tile only; all tiles of a
        # sample are assumed to share it.
        ori_h, ori_w = _infer_ori_image_patch_shape(attention_mask[batch_id][0])
        full_h, full_w = ori_h * _slice, ori_w * _slice

        # (tiles, ori_h, ori_w, D) -> (_slice, ori_h, _slice, ori_w, D) -> (full_h, full_w, D):
        # tiles are laid out row-major into the mosaic.
        this_image_hidden = sub_image_hidden_states[batch_id][:, 0:ori_h, 0:ori_w, :] \
            .view(_slice, _slice, ori_h, ori_w, D).permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)
        pos_emb = get_2d_sincos_pos_embed(pos_hidden_size, grid_size_h=full_h,
                                          grid_size_w=full_w)
        pos_emb = torch.tensor(pos_emb, dtype=this_image_hidden.dtype, device=this_image_hidden.device)
        image_hidden_states.append(_pad_to_same(this_image_hidden))
        image_2d_pos.append(_pad_to_same(pos_emb))
        valid_image_token.append([full_h, full_w])
    image_hidden_states = torch.stack(image_hidden_states)
    image_2d_pos = torch.stack(image_2d_pos)
    valid_image_token = torch.tensor(valid_image_token, dtype=torch.int64)
    return image_hidden_states, image_2d_pos, valid_image_token
def visiual_token_downsample(
        visual_downsampler,
        image_hidden_states,
        valid_image_token,
        visual_embedding_group,
        image_2d_pos):
    """Optionally add positional embeddings, downsample, and rescale valid sizes.

    Args:
        visual_downsampler: callable applied to the (position-augmented)
            hidden states.
        image_hidden_states: visual features.
        valid_image_token: tensor of valid token counts/extents to rescale.
        visual_embedding_group: the valid sizes are divided by
            ``sqrt(visual_embedding_group)`` and rounded up.
        image_2d_pos: tensor added to the features first, or None to skip.

    Returns:
        ``(downsampled_features, rescaled_valid_tokens)``; the latter is int64.
    """
    features = image_hidden_states if image_2d_pos is None else image_hidden_states + image_2d_pos
    downsampled = visual_downsampler(features)
    stride = np.sqrt(visual_embedding_group)
    rescaled_valid = torch.ceil(valid_image_token / stride).to(torch.int64)
    return downsampled, rescaled_valid
def merge_native_qformer(
        clip_embeddings_native_patch,
        valid_image_token_shape,
        clip_embeddings_qformer,
        visual_source_spliter,
        num_sub_images):
    """Concatenate native-patch and Q-Former embeddings per sample with separators.

    Each sample becomes::

        [sep(10), native patches (h*w rows), sep(11),
         sep(0), qformer chunk 0, sep(1), sep(2), qformer chunk 1, sep(3), ...]

    where ``sep(k)`` is ``visual_source_spliter`` applied to index ``k``. The
    Q-Former tokens are cut into ``num_sub_images + 1`` equal chunks; any
    remainder rows are dropped by the integer division.

    Args:
        clip_embeddings_native_patch: per-sample grids indexed ``[b][:h, :w, :]``.
        valid_image_token_shape: per-sample ``(h, w)`` valid extent.
        clip_embeddings_qformer: per-sample Q-Former token sequences.
        visual_source_spliter: embedding module exposing ``weight``.
        num_sub_images: chunk count minus one (the extra chunk presumably
            belongs to the thumbnail -- TODO confirm).

    Returns:
        List with one 2-D tensor per sample.
    """
    assert clip_embeddings_native_patch.size(0) == valid_image_token_shape.size(0) == clip_embeddings_qformer.size(0)

    def _separator(index):
        # Index tensor is created on CPU, then moved next to the embedding table.
        return visual_source_spliter(torch.tensor([index]).to(visual_source_spliter.weight.device))

    num_chunks = num_sub_images + 1

    def _wrap_qformer_chunks(qformer_emb):
        chunk_len = int(qformer_emb.size(0) // num_chunks)
        pieces = []
        for chunk_id in range(num_chunks):
            start = chunk_id * chunk_len
            pieces.extend([
                _separator(2 * chunk_id),
                qformer_emb[start:start + chunk_len],
                _separator(2 * chunk_id + 1),
            ])
        return torch.cat(pieces, dim=0)

    merged = []
    for batch_id, native_grid in enumerate(clip_embeddings_native_patch):
        h, w = valid_image_token_shape[batch_id]
        native_tokens = native_grid[:h, :w, :].reshape(h * w, -1)
        wrapped_qformer = _wrap_qformer_chunks(clip_embeddings_qformer[batch_id])
        merged.append(torch.cat(
            [_separator(10), native_tokens, _separator(11), wrapped_qformer],
            dim=0))
    return merged
class WemmForConditionalGeneration(PreTrainedModel):
    """WeMM vision-language model: vision tower, visual adapters and InternLM2.

    Images go through ``vision_tower``. Two visual streams are derived from
    its output:
      * downsampled native patches (via ``downsampler``);
      * connector + projector tokens.
    The streams are merged with separator embeddings
    (``visual_source_spliter_emb``) and spliced into the language model's
    input embeddings in place of ``IMAGE_TOKEN_INDEX`` placeholders.
    """
    config_class = WeMMConfig

    def __init__(self, config: WeMMConfig):
        super().__init__(config)
        self.vision_tower = Idefics2VisionTransformer(config.vision_config)
        self.image_processor = Idefics2ImageProcessor(config.image_processor)
        self.connector = Idefics2Connector(config.connector_config)
        self.projector = ProjectorModel(config.projector_config)
        self.language_model = InternLM2ForCausalLM(config.text_config)
        # NOTE(review): the tokenizer is loaded from a fixed hub id, not from
        # the checkpoint being loaded, so construction needs that repo to be
        # reachable or cached.
        self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b", trust_remote_code=True, encode_special_tokens=True)
        self.downsampler = DownsamplerModel(config.downsampler_config)
        self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)
        # Default decoding settings for mm_generate: greedy, up to 512 tokens,
        # pad falls back to eos when the tokenizer defines no pad token.
        self.gen_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
        )
        self.do_image_splitting = config.do_image_splitting
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
        self.config = config

    def mm_generate(self, image_path, prompt, gen_config=None):
        """Generate an answer to ``prompt`` about a single image.

        Args:
            image_path: path to an image file (opened with PIL and converted to
                RGB), or an already-loaded image passed straight to the
                processor (presumably a PIL image -- confirm).
            prompt: user text. ``<image>`` and a newline are prepended and the
                result is wrapped in the ``<|im_start|>user ... <|im_end|>``
                chat template followed by the assistant header.
            gen_config: optional GenerationConfig; ``self.gen_config`` is used
                when None.

        Returns:
            The generated text, decoded with special tokens skipped and
            whitespace-stripped.
        """
        # Tile count assumed by both the native-patch and merge paths.
        num_sub_images = 4
        # Inputs follow the model's parameters instead of assuming the current
        # CUDA device (the previous code called .cuda() unconditionally).
        device = self.device

        prompt = "<image>" + '\n' + prompt
        prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
        if isinstance(image_path, str):
            # Context manager closes the underlying file; convert() returns a
            # fully loaded copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
        else:
            image = image_path

        navit980_images = self.image_processor([[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
        batch_size_navit = navit980_images['pixel_values'].shape[0]
        navit_pixel_values = navit980_images['navit_pixel_values'].to(device)
        navit_patch_attention_mask = navit980_images["pixel_attention_mask"].to(device)
        clip_visual_outputs = self.vision_tower(
            pixel_values=navit_pixel_values,
            patch_attention_mask=navit_patch_attention_mask,
        ).last_hidden_state

        # Native-patch stream: only the last image per sample
        # (thumbnail_only=True), with no positional embedding added before
        # downsampling (image_2d_pos=None).
        super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
            recover_navit_subimages_with_pos_emb(
                clip_visual_outputs, navit_patch_attention_mask, num_sub_images=num_sub_images,
                visual_embedding_group=1,
                pos_hidden_size=4096,
                thumbnail_only=True
            )
        clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
            self.downsampler,
            super_image_hidden_states, valid_image_token_shape,
            visual_embedding_group=1, image_2d_pos=None
        )

        # Connector/projector stream over all images, regrouped per sample.
        clip_embeddings_qformer = self.connector(
            clip_visual_outputs,
            attention_mask=navit_patch_attention_mask.view(navit_pixel_values.size(0), -1))
        hidden_size = clip_embeddings_qformer.shape[-1]
        clip_embeddings_qformer = clip_embeddings_qformer.view(batch_size_navit, -1, hidden_size)
        clip_embeddings_qformer = self.projector(clip_embeddings_qformer)

        merged_visual_embeddings = merge_native_qformer(
            clip_embeddings_native_patch,
            valid_image_token_shape,
            clip_embeddings_qformer,
            visual_source_spliter=self.visual_source_spliter_emb,
            num_sub_images=num_sub_images
        )

        # Tokenize the text around the single <image> marker; only the first
        # chunk gets the tokenizer's special tokens (e.g. BOS).
        chunk_encode = []
        for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)
        # Exactly one image is supported per call; a user prompt that itself
        # contains "<image>" is rejected here.
        assert len(chunk_encode) == 2
        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids, device=device).unsqueeze(0)

        pixel_values = None
        mm_inputs = self.prepare_inputs_labels_for_multimodal(
            llm=self.language_model, input_ids=ids, pixel_values=pixel_values, clip_embeddings=merged_visual_embeddings)
        generate_output = self.language_model.generate(
            **mm_inputs,
            generation_config=gen_config if gen_config is not None else self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria
        )
        predict = self.tokenizer.decode(
            generate_output[0], skip_special_tokens=True).strip()
        return predict

    def get_valid_visual_embedding(self, embedding, valid_token_shape):
        """Crop a visual token grid to its valid extent and flatten it.

        Args:
            embedding: grid indexed as ``[:h, :w, :]``.
            valid_token_shape: ``(h, w)`` pair, or None to return ``embedding``
                unchanged.

        Returns:
            ``embedding[:h, :w, :]`` reshaped to ``(h * w, -1)``, or
            ``embedding`` itself when no shape is given.
        """
        if valid_token_shape is None:
            return embedding
        h, w = valid_token_shape
        return embedding[:h, :w, :].reshape(h*w, -1)

    def prepare_inputs_labels_for_multimodal(
            self,
            llm: PreTrainedModel,
            input_ids: torch.LongTensor = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            pixel_values: Optional[torch.FloatTensor] = None,
            clip_embeddings: Optional[torch.FloatTensor] = None,
            hard_coded_max_len: Optional[int] = None,
            **kwargs):
        """Splice visual embeddings into the token embeddings of ``llm``.

        Every ``IMAGE_TOKEN_INDEX`` in ``input_ids`` is replaced by the next
        image's ``pixel_values`` rows and/or ``clip_embeddings`` rows. Images
        are consumed in order across the batch. A sample without image tokens
        still consumes one image, appended as a zero-length slice.

        Args:
            llm: model whose ``get_input_embeddings()`` embeds the text ids.
            input_ids: ``(batch, seq)`` ids containing IMAGE_TOKEN_INDEX markers.
            position_ids, attention_mask, labels: optional, same layout as
                ``input_ids``. When None they are synthesised internally and
                returned as None.
            past_key_values: passed through unchanged.
            pixel_values, clip_embeddings: per-image visual rows. When both
                are None, the inputs are returned unchanged with
                ``inputs_embeds=None``.
            hard_coded_max_len: optional cap on the padded sequence length.
            **kwargs: ``valid_image_token_shape`` (optional per-image ``(h, w)``)
                crops 3-D ``clip_embeddings`` grids via
                get_valid_visual_embedding.

        Returns:
            Dict with ``input_ids`` None, right-padded ``inputs_embeds``,
            ``labels`` (visual/pad positions set to IGNORE_INDEX),
            ``attention_mask``, ``position_ids``, ``past_key_values`` and, when
            ``pixel_values`` is given, an ``im_mask`` marking the pixel-value
            rows.

        Raises:
            ValueError: if a text-only sample has neither pixel values nor
                clip embeddings to attach.
        """
        if pixel_values is None and clip_embeddings is None:
            return {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'past_key_values': past_key_values,
                'inputs_embeds': None,
                'labels': labels
            }
        valid_image_token_shape = kwargs.get('valid_image_token_shape', None)

        def _clip_embedding_at(image_idx):
            # Shared by both branches below. The text-only branch used to index
            # valid_image_token_shape even when it was None (TypeError).
            if clip_embeddings is None:
                return None
            if valid_image_token_shape is None:
                return clip_embeddings[image_idx]
            return self.get_valid_visual_embedding(
                clip_embeddings[image_idx], valid_image_token_shape[image_idx])

        # Remember which optional inputs the caller supplied.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop masked-out positions per sample (sequences become ragged lists).
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        embed_tokens = llm.get_input_embeddings()
        new_inputs_embeds = []
        new_labels = []
        new_img_masks = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                cur_clip_emb = _clip_embedding_at(cur_image_idx)
                cur_inputs_embeds_1 = embed_tokens(cur_input_ids)
                # Zero-length slices of the visual inputs are concatenated,
                # presumably so the visual branch stays in the autograd graph
                # for text-only samples -- TODO confirm.
                if cur_clip_emb is not None and cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
                elif cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
                elif cur_clip_emb is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
                else:
                    raise ValueError(
                        f'no pixel values or clip embeddings available for image index {cur_image_idx}')
                new_inputs_embeds.append(cur_inputs_embeds)
                new_labels.append(labels[batch_idx])
                new_img_masks.append(torch.zeros(
                    cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
                cur_image_idx += 1
                continue

            # Split the ids/labels at each image marker, embed all text at once,
            # then interleave the text segments with the visual rows.
            image_token_indices = [-1] + torch.where(
                cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                start, end = image_token_indices[i] + 1, image_token_indices[i + 1]
                cur_input_ids_noim.append(cur_input_ids[start:end])
                cur_labels_noim.append(cur_labels[start:end])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_inputs_embeds = embed_tokens(torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)

            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_img_masks = []
            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_img_masks.append(torch.zeros(
                    cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool())
                if i < num_images:
                    cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                    cur_clip_emb = _clip_embedding_at(cur_image_idx)
                    cur_image_idx += 1

                    # Pixel-value rows are the only ones flagged in im_mask.
                    if cur_pixel_values is not None:
                        cur_new_inputs_embeds.append(cur_pixel_values)
                        cur_img_masks.append(torch.ones(
                            cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_pixel_values.shape[0], ),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))

                    if cur_clip_emb is not None:
                        cur_new_inputs_embeds.append(cur_clip_emb)
                        cur_img_masks.append(torch.zeros(
                            cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_clip_emb.shape[0],),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))
            new_inputs_embeds.append(torch.cat(cur_new_inputs_embeds))
            new_labels.append(torch.cat(cur_new_labels))
            new_img_masks.append(torch.cat(cur_img_masks))

        # Right-pad (and optionally truncate) every sample to a common length.
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        if hard_coded_max_len is not None:
            max_len = min(max_len, hard_coded_max_len)
        batch_size = len(new_inputs_embeds)
        new_inputs_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len),
                                       IGNORE_INDEX,
                                       dtype=new_labels[0].dtype,
                                       device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len),
                                     dtype=attention_mask.dtype,
                                     device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len),
                                   dtype=position_ids.dtype,
                                   device=position_ids.device)
        new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()
        for i, (cur_new_embed, cur_new_labels, cur_new_img_masks) in enumerate(
                zip(new_inputs_embeds, new_labels, new_img_masks)):
            cur_new_embed = cur_new_embed[:max_len]
            cur_new_labels = cur_new_labels[:max_len]
            cur_new_img_masks = cur_new_img_masks[:max_len]
            cur_len = cur_new_embed.shape[0]
            new_inputs_embeds_padded.append(
                torch.cat((cur_new_embed,
                           torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
                                       dtype=cur_new_embed.dtype,
                                       device=cur_new_embed.device)),
                          dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0,
                    cur_len,
                    dtype=position_ids.dtype,
                    device=position_ids.device)
                new_img_masks_padded[i, :cur_len] = cur_new_img_masks
        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

        # Only return the optional tensors the caller actually supplied.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        prepared_data = {
            'input_ids': None,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': new_inputs_embeds,
            'labels': new_labels,
        }
        if pixel_values is not None:
            prepared_data.update({'im_mask': new_img_masks_padded})
        return prepared_data
# Import-time registration with the transformers Auto* factories: the config
# class under the "wemm_hf" key, then the model class for that config.
# NOTE(review): transformers may reject the first call if WeMMConfig.model_type
# is not "wemm_hf" -- confirm in configuration_wemm.
AutoConfig.register("wemm_hf", WeMMConfig)
AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
|
|