| |
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models.layers import trunc_normal_ |
| from contextlib import suppress |
| import logging |
| from einops import rearrange |
| from peft import LoraConfig, get_peft_model |
| from bigmodelvis import Visualization |
|
|
| from .clip_encoder_hd import CLIPVisionTowerHD |
| from .conversation import get_conv_template |
| from .processors_conv import preprocess_qwen |
| from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel |
| from transformers.generation import GenerationConfig |
| from transformers import Qwen2Config, Qwen2ForCausalLM |
|
|
|
|
def get_autocast(precision, cache_enabled=True):
    """Return a zero-argument factory producing the autocast context for ``precision``.

    Args:
        precision: ``'amp_bfloat16'``, ``'amp_bf16'`` or ``'bf16'`` select CUDA
            autocast in bfloat16; ``'fp16'`` selects CUDA autocast in float16;
            ``'fp32'`` selects no autocast.
        cache_enabled: forwarded to ``torch.cuda.amp.autocast``.

    Returns:
        A callable that, when called with no arguments, yields a context manager.
        For ``'fp32'`` this is ``contextlib.suppress`` itself; ``suppress()`` with
        no exception types is a no-op context.

    Raises:
        ValueError: for any other ``precision`` value.
    """
    bf16_names = ("amp_bfloat16", "amp_bf16", "bf16")
    if precision in bf16_names:
        half_dtype = torch.bfloat16
    elif precision == 'fp16':
        half_dtype = torch.float16
    elif precision == 'fp32':
        return suppress
    else:
        raise ValueError('not supported precision: {}'.format(precision))

    def make_autocast():
        # Built lazily so the context is created fresh at each ``with`` site.
        return torch.cuda.amp.autocast(dtype=half_dtype, cache_enabled=cache_enabled)

    return make_autocast
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalizes in float32.

    The input is upcast to float32, normalized by the parent implementation,
    and the result is cast back to the input's original dtype (useful when the
    surrounding model runs in fp16).
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = super().forward(x.to(torch.float32))
        return normed.to(input_dtype)
|
|
|
|
class MLP(nn.Module):
    """Simple feed-forward network (FFN) made of stacked ``nn.Linear`` layers.

    Layer widths run ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``
    over ``num_layers`` linear layers; ReLU follows every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        last_index = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last_index:
                x = F.relu(x)
        return x
|
|
|
|
class InfMLLM_Unified_HD_Chat(PreTrainedModel):
    """Qwen2-based chat model that splices high-resolution image features into prompts.

    Images are encoded by ``CLIPVisionTowerHD``, projected to the LM hidden size
    by ``adapter_img`` and inserted right after the ``<|image|>`` placeholder
    token of each prompt; the resulting embeddings are passed to
    ``Qwen2ForCausalLM.generate``.
    """

    def __init__(self, config, debug=False):
        """Build tokenizer, language model, vision tower and adapter from ``config``.

        Args:
            config: model config; reads ``_name_or_path``, ``lm_config``,
                ``max_txt_len``, ``conv_style``, ``vision_config``, ``precision``
                and optionally ``apply_lemmatizer``.
            debug: accepted for interface compatibility; unused here.
        """
        super().__init__(config)

        # Language side: tokenizer, image placeholder token and Qwen2 LM.
        self.lm_tokenizer = AutoTokenizer.from_pretrained(
            config._name_or_path, use_fast=False, trust_remote_code=True)
        self.media_token_img = "<|image|>"
        # ``.item()`` only succeeds if the placeholder tokenizes to a single id.
        self.media_token_id_img = self.lm_tokenizer(
            self.media_token_img, return_tensors="pt", add_special_tokens=False).input_ids.item()
        self.lm_model = Qwen2ForCausalLM(config.lm_config)

        self.lm_tokenizer.model_max_length = config.max_txt_len

        self.template_name = config.conv_style
        self.preprocess_function = preprocess_qwen

        # Learnable separator / newline embeddings passed to the HD vision tower.
        # NOTE(review): the 4096 width is hard-coded; presumably it must match the
        # vision tower's concatenated feature width -- confirm before changing.
        self.separate = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.newline = nn.Parameter(torch.zeros([1, 1, 1, 4096]))

        # Vision side: HD CLIP tower (vision_select_layer=-2), an identity "norm",
        # and a Linear-GELU-Linear adapter into the LM hidden size.
        self.encoder_img = CLIPVisionTowerHD(config.vision_config, vision_select_layer=-2)
        self.encoder_img_ln = lambda x: x

        # Adapter input is 4x the tower's ``num_features`` (presumably 2x2 merged
        # patch features -- confirm against CLIPVisionTowerHD).
        self.adapter_img = nn.Sequential(
            nn.Linear(self.encoder_img.num_features * 4, self.lm_model.config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_model.config.hidden_size, self.lm_model.config.hidden_size)
        )

        self.config = config
        self.precision = config.precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None  # spaCy pipeline, loaded lazily by ``lemmatizer``

    def forward_encoder_img(self, image):
        """Encode images and project them into the LM embedding space.

        Args:
            image: a ``list`` of images in whatever form ``CLIPVisionTowerHD`` accepts.

        Returns:
            ``(image_embeds, image_split)``: adapter outputs and the second value
            returned by the vision tower (``generate`` uses it as the number of
            valid media tokens per image).
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            assert isinstance(image, list)
            image_embeds, image_split = self.encoder_img(image, self.separate, self.newline)
            image_embeds = self.encoder_img_ln(image_embeds)
            image_embeds = self.adapter_img(image_embeds)
        return image_embeds, image_split

    def _concat_embeds(self,
                       prompt_embeds, prompt_ids, prompt_masks,
                       labels=None, padding='left'):
        """Pad variable-length per-sample sequences to a common length and stack them.

        Args:
            prompt_embeds: list of per-sample embedding tensors (dim 0 = length).
            prompt_ids: list of per-sample token-id tensors of matching lengths.
            prompt_masks: list of per-sample attention-mask tensors.
            labels: optional list of per-sample label tensors.
            padding: ``'left'`` or ``'right'``: the side that receives padding.

        Returns:
            ``(embeds, ids, masks)``, plus ``labels`` when given, each batched.
            Padded positions hold the pad-token embedding, ``pad_token_id``,
            mask 0 and label -100.

        Raises:
            ValueError: if ``padding`` is neither ``'left'`` nor ``'right'``.
        """
        if padding not in ('left', 'right'):
            raise ValueError(f"padding must be 'left' or 'right', got {padding!r}")

        emb_lens = [len(emb) for emb in prompt_embeds]
        if len(set(emb_lens)) == 1:
            # All sequences already share a length: just stack.
            stacked = (torch.stack(prompt_embeds, dim=0),
                       torch.stack(prompt_ids, dim=0),
                       torch.stack(prompt_masks, dim=0))
            if labels is not None:
                return stacked + (torch.stack(labels, dim=0),)
            return stacked

        batch_size, max_len = len(emb_lens), max(emb_lens)
        pad_id = self.lm_tokenizer.pad_token_id
        ids_ref, masks_ref = prompt_ids[0], prompt_masks[0]

        pad_emb = self.lm_model.get_input_embeddings()(
            torch.tensor(pad_id, device=prompt_embeds[0].device))
        prompt_embeds_new = pad_emb.expand(batch_size, max_len, -1).clone()
        prompt_ids_new = torch.full((batch_size, max_len), pad_id,
                                    dtype=ids_ref.dtype, device=ids_ref.device)
        prompt_masks_new = torch.zeros((batch_size, max_len),
                                       dtype=masks_ref.dtype, device=masks_ref.device)
        if labels is not None:
            labels_new = torch.full((batch_size, max_len), -100,
                                    dtype=ids_ref.dtype, device=ids_ref.device)

        for i, seq_len in enumerate(emb_lens):
            # Explicit bounds: a ``-seq_len:`` slice would cover the whole row
            # when seq_len == 0.
            if padding == 'left':
                span = slice(max_len - seq_len, max_len)
            else:
                span = slice(0, seq_len)
            prompt_embeds_new[i, span] = prompt_embeds[i]
            prompt_ids_new[i, span] = prompt_ids[i]
            prompt_masks_new[i, span] = prompt_masks[i]
            if labels is not None:
                labels_new[i, span] = labels[i]

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _insert_media_feat(self,
                           prompt_embeds, prompt_ids, prompt_masks,
                           is_languages,
                           embeds_media, media_token_id,
                           index_list=None,
                           labels=None, len_media=None):
        """Insert media embeddings right after the first media token of each prompt.

        Args:
            prompt_embeds, prompt_ids, prompt_masks: per-sample embeddings,
                token ids and attention masks, indexable by batch position.
            is_languages: per-sample flags; inserted positions get attention
                mask 0 when true, 1 otherwise.
            embeds_media: media embeddings indexed by media position
                (presumably ``(num_media, tokens, dim)`` -- confirm).
            media_token_id: id of the placeholder token marking the insert point.
            index_list: batch positions that carry media, in media order; when
                ``None`` sample ``b`` uses ``embeds_media[b]``.
            labels: optional per-sample labels; inserted positions get -100.
            len_media: optional number of valid media tokens per media entry.

        Returns:
            Lists ``(embeds, ids, masks)`` (plus ``labels`` when given) with
            variable per-sample lengths; inserted ids are -100.

        Raises:
            ValueError: if a sample that should receive media has no ``media_token_id``.
        """
        prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new = [], [], [], []
        device = embeds_media[0].device

        if index_list is not None:
            assert len(index_list) == len(embeds_media)
        assert len(embeds_media) <= len(prompt_embeds)

        for b in range(len(prompt_embeds)):
            if index_list is not None and b not in index_list:
                # No media for this sample: pass it through unchanged.
                prompt_embeds_new.append(prompt_embeds[b])
                prompt_ids_new.append(prompt_ids[b])
                prompt_masks_new.append(prompt_masks[b])
                if labels is not None:
                    labels_new.append(labels[b])
                continue

            _idx = prompt_ids[b].tolist().index(media_token_id)
            b_media = index_list.index(b) if index_list is not None else b
            if len_media is not None:
                cur_embeds_media = embeds_media[b_media, :len_media[b_media]]
            else:
                cur_embeds_media = embeds_media[b_media]
            n_media = len(cur_embeds_media)

            # The placeholder token itself is kept; media goes right after it.
            head, tail = slice(None, _idx + 1), slice(_idx + 1, None)
            ignore = torch.full((n_media,), -100, dtype=torch.long, device=device)
            media_mask = torch.full((n_media,), 0 if is_languages[b] else 1,
                                    dtype=torch.long, device=device)

            prompt_embeds_new.append(torch.cat(
                [prompt_embeds[b][head], cur_embeds_media, prompt_embeds[b][tail]], dim=0))
            prompt_ids_new.append(torch.cat(
                [prompt_ids[b][head], ignore, prompt_ids[b][tail]], dim=0))
            if labels is not None:
                labels_new.append(torch.cat(
                    [labels[b][head], ignore, labels[b][tail]], dim=0))
            prompt_masks_new.append(torch.cat(
                [prompt_masks[b][head], media_mask, prompt_masks[b][tail]], dim=0))

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _build_prompts(self, conversations, special_prefix):
        """Render each conversation to a prompt string with the chat template.

        A leading non-human turn is dropped, ``<image>`` markers are removed from
        every turn, and ``special_prefix[i]`` is prepended to the first kept turn
        of conversation ``i``. The input turn dicts are NOT modified, so calling
        ``generate`` repeatedly on the same samples yields the same prompts.
        """
        conv = get_conv_template(self.template_name)
        roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

        prompts = []
        for i, source in enumerate(conversations):
            if roles[source[0]['from']] != conv.roles[0]:
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence['from']]
                assert role == conv.roles[j % 2], f'{i}'
                value = sentence['value'].replace("<image>", "").strip()
                if j == 0:
                    value = special_prefix[i] + value
                conv.append_message(role, value)
            prompts.append(conv.get_prompt())
        return prompts

    @torch.no_grad()
    def generate(
        self,
        samples,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        temperature=0.,
        return_prompts=False
    ):
        """Generate replies for a batch of conversations, optionally with images.

        Args:
            samples: dict with ``'conversations'`` (per sample, a list of
                ``{'from': 'human'|'gpt', 'value': str}`` turns), optional
                ``'images'`` (one entry per conversation, in order) and optional
                ``'apply_lemmatizer'`` flag. Not modified.
            num_beams, min_length: forwarded to ``lm_model.generate``.
            max_length: forwarded as ``max_new_tokens``.
            top_p, temperature: sampling is enabled only when ``temperature > 0``;
                both are forwarded only in that case.
            return_prompts: also return the rendered prompt strings.

        Returns:
            List of decoded, whitespace-stripped strings; ``(texts, prompts)``
            when ``return_prompts`` is true.
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            conversations = samples['conversations']
            is_languages = [False] * len(conversations)

            image_img = samples.get('images', None)
            index_img = list(range(len(image_img))) if image_img is not None else []

            # Defaults for text-only batches; replaced when images are encoded.
            embeds_img, len_img = None, None
            device = self.lm_model.get_input_embeddings().weight.device
            special_prefix = ["" for _ in range(len(conversations))]

            if (self.config.encoder_img is not None) and (image_img is not None) and len(index_img) > 0:
                for i in index_img:
                    special_prefix[i] = self.media_token_img + special_prefix[i]
                new_image_img = [image_img[index] for index in index_img]
                embeds_img, len_img = self.forward_encoder_img(new_image_img)
                device = embeds_img.device

            prompts = self._build_prompts(conversations, special_prefix)

            self.lm_tokenizer.padding_side = "left"
            bos = self.lm_tokenizer.bos_token
            prompt_text = [bos + t for t in prompts] if bos is not None else prompts

            prompt_tokens = self.lm_tokenizer(
                prompt_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=False
            ).to(device)

            prompt_embeds = self.lm_model.get_input_embeddings()(prompt_tokens.input_ids)
            prompt_masks = prompt_tokens.attention_mask
            prompt_ids = prompt_tokens.input_ids
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            if embeds_img is not None:
                prompt_embeds, prompt_ids, prompt_masks = self._insert_media_feat(prompt_embeds=prompt_embeds,
                                                                                  prompt_ids=prompt_ids,
                                                                                  prompt_masks=prompt_masks,
                                                                                  is_languages=is_languages,
                                                                                  embeds_media=embeds_img,
                                                                                  media_token_id=self.media_token_id_img,
                                                                                  index_list=index_img,
                                                                                  len_media=len_img)
                # Insertion yields per-sample lists of differing lengths; re-pad.
                # Text-only batches are already padded/batched by the tokenizer.
                prompt_embeds, prompt_ids, prompt_masks = self._concat_embeds(
                    prompt_embeds, prompt_ids, prompt_masks, padding="left")
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            do_sample = temperature > 0
            gen_kwargs = {'max_new_tokens': max_length}
            if do_sample:
                # temperature/top_p only matter when sampling.
                gen_kwargs.update(temperature=temperature, top_p=top_p)

            outputs = self.lm_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_masks,
                do_sample=do_sample,
                num_beams=num_beams,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                min_length=min_length,
                **gen_kwargs
            )
            output_text = [
                text.strip()
                for text in self.lm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
            ]

            if self._apply_lemmatizer or samples.get("apply_lemmatizer", False):
                output_text = self._lemmatize(output_text)

            if return_prompts:
                return output_text, prompts
            return output_text

    def _lemmatize(self, answers):
        """Replace NOUN and VERB tokens with their spaCy lemmas; other tokens keep their text."""
        def apply(answer):
            doc = self.lemmatizer(answer)
            words = [token.lemma_ if token.pos_ in ("NOUN", "VERB") else token.text
                     for token in doc]
            return " ".join(words)

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline.

        Raises:
            ImportError: if spaCy is not installed.
            OSError: if ``spacy.load`` cannot find ``en_core_web_sm``.
        """
        if self._lemmatizer is None:
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                # Library code must not terminate the interpreter: log the
                # install hint and propagate the original error.
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                raise
        return self._lemmatizer
|
|