import contextlib
import copy
import os
import sys
from functools import partial

# Make the model directory importable before pulling in local modules.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)

import torch.utils.checkpoint
import torch.nn as nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from .resampler import create_resampler

from transformers.utils import logging

logger = logging.get_logger(__name__)
|
|
|
|
class InternLMXComposerForCausalLM(PreTrainedModel):
    """Vision-language causal LM combining an EVA-ViT image encoder, a
    perceiver-style resampler (``Qformer``) and an InternLM language model.

    Images are encoded to 256 query embeddings, projected into the LM's
    hidden space, and wrapped with frozen start/end marker embeddings so
    they can be spliced into the text embedding stream.

    NOTE: ``self.tokenizer`` is ``None`` after ``__init__`` and must be
    assigned externally before any text helper (``encode_text``,
    ``generate``, ``chat``) is called.
    """

    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

    # Default generation arguments. get_gen_args() deep-copies this before
    # merging per-call overrides, so the shared template is never mutated.
    gen_config = dict(
        num_beams=5,
        do_sample=True,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )

    def __init__(self, config):
        super().__init__(config)

        self.max_length = config.max_length
        print(f'Set max length to {self.max_length}')
        print('Init VIT ... ', end='')
        self.visual_encoder = create_eva_vit_g(img_size=448)
        # Vision-side LayerNorm is folded away; keep an Identity for
        # interface compatibility with checkpoints that reference it.
        self.ln_vision = nn.Identity()
        self.supports_gradient_checkpointing = True
        print('Done')
        print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer = create_resampler(num_query_token=256)
        print('Done')

        print('Init InternLM ... ', end='')
        # Frozen learned markers wrapped around image embeddings so the LM
        # can tell where image tokens begin and end.
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

        # Parse the major version robustly ("2.1.0+cu118" -> 2); the old
        # torch.__version__[0] char-indexing would break on a two-digit
        # major release.
        torch_major = int(torch.__version__.split('.')[0])
        if torch_major == 1:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            assert torch_major == 2
            # Build on the meta device (no memory allocated); real weights
            # are expected to be loaded afterwards from a checkpoint.
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)

        # Projects 4096-d resampler outputs into the LM embedding space.
        self.internlm_proj = nn.Linear(4096,
                                       self.internlm_model.config.hidden_size)
        print('Done')

        # CLIP-style preprocessing applied to PIL images in encode_img().
        self.vis_processor = transforms.Compose([
            transforms.Resize((448, 448),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Must be assigned by the caller before text encoding/decoding.
        self.tokenizer = None

    @property
    def eoh(self):
        """End-of-human sentinel appended after each user turn."""
        return '<TOKENS_UNUSED_0>'

    @property
    def eoa(self):
        """End-of-assistant sentinel; decoded text is truncated at it."""
        return '<TOKENS_UNUSED_1>'

    def get_input_embeddings(self):
        """Delegate to the wrapped language model's input embeddings."""
        return self.internlm_model.get_input_embeddings()

    def _set_gradient_checkpointing(self, module, value=False):
        # Propagate the checkpointing flag to every LM submodule; the
        # `module` argument is ignored deliberately (HF hook signature).
        if value:
            self.internlm_model.apply(
                partial(self.internlm_model._set_gradient_checkpointing,
                        value=True))

    def encode_img(self, image):
        """Encode an image into a sequence of LM-space embeddings.

        Args:
            image: a path string (loaded & preprocessed here, batch of 1)
                or an already-preprocessed batched tensor. ``None`` is
                passed through as ``None`` (text-only prompt).

        Returns:
            Tensor of shape (batch, 2 + num_query_tokens, hidden): the
            projected resampler output bracketed by the frozen
            image-start/image-end flag embeddings.
        """
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)

        device = image.device
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
        # NOTE(review): an all-ones attention mask used to be built here but
        # was never consumed; dropped, as the resampler takes only embeds.
        query_output = self.Qformer(image_embeds)
        inputs_internlm = self.internlm_proj(query_output)

        inputs_internlm = torch.cat([
            self.flag_image_start.expand(inputs_internlm.shape[0], -1, -1),
            inputs_internlm,
            self.flag_image_end.expand(inputs_internlm.shape[0], -1, -1)
        ],
                                    dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
        """Tokenize *text* and return its LM input embeddings (1, T, H)."""
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
        return text_embeds

    def decode_text(self, out_embeds):
        """Decode generated token ids, truncating at the first ``eoa``."""
        out_text = self.tokenizer.batch_decode(out_embeds,
                                               skip_special_tokens=True)[0]
        out_text = out_text.split(self.eoa)[0]
        return out_text

    def wrap_text(self, user_text, bot_text='', add_special=True):
        """Format one chat turn with the ``<|User|>``/``<|Bot|>`` template.

        When *add_special* is True the end-of-human sentinel is inserted
        after the user text.
        """
        eoh = self.eoh if add_special else ''
        text = f'<|User|>:{user_text}{eoh}\n<|Bot|>:{bot_text}'
        return text

    def get_gen_args(self, **kwargs):
        """Return generation kwargs: class defaults overridden by *kwargs*.

        The class-level template is deep-copied so callers can never
        mutate the shared defaults.
        """
        new_kargs = copy.deepcopy(self.gen_config)
        new_kargs.update(kwargs)
        return new_kargs

    def generate(self, text, image=None, **kwargs):
        """Single-turn generation; *image* may be a path, tensor or None.

        Extra kwargs override the default generation config.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)
        return out_text

    def chat(self, text, image=None, history=None, **kwargs):
        """Multi-turn generation.

        Returns ``(out_text, history)`` where *history* is a list of
        embedding tensors; pass it back in on the next call to continue
        the conversation. A caller-supplied *history* list is appended to
        in place.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)

        # Re-embed the cleaned reply (eoa stripped) together with a
        # special-token-free prompt so the stored history is a single
        # clean embedding sequence for this turn.
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self,
                    text_embeds,
                    img_embeds=None,
                    history=None,
                    add_special=True):
        """Assemble the full prompt embedding sequence.

        Layout: ``[*history] '<|User|>:' [image] [text] '<|Bot|>:'`` with
        the end-of-human sentinel before ``<|Bot|>:`` when *add_special*
        is True. BOS is added only on the opening segment of a brand-new
        conversation.
        """
        if add_special:
            prompt_segs = ['<|User|>:', f'{self.eoh}\n<|Bot|>:']
        else:
            prompt_segs = ['<|User|>:', '<|Bot|>:']
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            # Special (BOS) tokens only for the very first segment of a
            # fresh conversation; mid-conversation segments get none.
            add_special_tokens = (history is None) and (i == 0)
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if img_embeds is None:
            # Zero-length placeholder keeps the concatenation uniform for
            # text-only prompts.
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_seg_embeds = [
            prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds
|
|
|
|