| import io |
| import os |
|
|
| import requests |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| from PIL import Image, ImageFile |
| from torch.nn.utils import rnn |
| from types import SimpleNamespace |
| from peft import LoraConfig, TaskType, get_peft_model |
| from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig |
|
|
| import numpy as np |
| |
|
|
| from transformers import StoppingCriteria, StoppingCriteriaList |
|
|
| from .CLIP import load as load_clip |
| from .PROCESS import data |
| from .modeling_llama import LlamaForCausalLM |
| from .utils.pcl_utils import MEAN_COLOR_RGB, RandomCuboid, random_sampling |
| from .conversations import conversation_dict, default_conversation |
|
|
# Let PIL decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
| |
# Textual tags that mark visual content inside prompts, keyed by role and
# then by vision type ('image' or 'pcl'):
#   pos - placeholder removed from the first human turn,
#   sov - tag appended to the prompt start, right before the visual tokens,
#   eov - tag that prefixes the text following the visual tokens.
VISION_TAGS = dict(
    pos=dict(image='<image>', pcl='<pcl>'),
    sov=dict(image='<Img>', pcl='<Pcl>'),
    eov=dict(image='</Img>', pcl='</Pcl>'),
)

# Names of the modalities, exposed as attributes (e.g. ModalityType.VISION).
ModalityType = SimpleNamespace(**{
    'VISION': 'vision',
    'TEXT': 'text',
    'AUDIO': 'audio',
    'THERMAL': 'thermal',
    'DEPTH': 'depth',
    'IMU': 'imu',
})
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop token has appeared often enough.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stops=None, encounters=1):
        """
        :param stops: iterable of stop tokens; each one is compared
            element-wise against the generated ids. Defaults to no stops.
        :param int encounters: number of occurrences of a stop token that
            ends generation.
        """
        super().__init__()
        # Build a fresh list per instance; a mutable default argument would
        # be shared by every instance created without `stops`.
        self.stops = [] if stops is None else stops
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any stop token occurs at least ``ENCOUNTERS`` times.

        Every stop token is checked on its own, so an early entry of
        ``stops`` is no longer masked by the count of a later one.
        Extra keyword arguments are accepted and ignored.
        """
        first_sequence = input_ids[0]
        for stop in self.stops:
            occurrences = (stop == first_sequence).sum().item()
            if occurrences >= self.ENCOUNTERS:
                return True
        return False
|
|
|
|
class MyStoppingCriteria(StoppingCriteria):
    """Batch-aware stopping criterion based on stop-token suffixes.

    A sequence is finished once it ends with one of the stop patterns;
    generation stops when every sequence of the batch is finished.
    """

    def __init__(self, stops, input_ids):
        """
        :param stops: list of stop patterns, each a list of token ids.
        :param input_ids: tensor whose first dimension is the batch size
            (only ``shape[0]`` and, when present, ``device`` are read).
        """
        super().__init__()
        # Place the patterns next to the generation inputs instead of a
        # hard-coded 'cuda:0', which fails on CPU or on another GPU.
        device = getattr(input_ids, 'device', 'cuda:0')
        self.stops = [torch.tensor(stop).to(device) for stop in stops]
        # One flag per batch row; 1 means the row already hit a stop pattern.
        self.stop_flag = [0] * input_ids.shape[0]

    def check_stop(self, input_ids):
        """Return True if the 1-D id sequence ends with any stop pattern."""
        for stop in self.stops:
            pattern_len = len(stop)
            # A sequence shorter than the pattern cannot end with it.
            if input_ids.shape[-1] < pattern_len:
                continue
            tail = input_ids[-pattern_len:]
            if torch.all(stop.to(tail.device) == tail).item():
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Update the per-row flags and return True once all rows stopped."""
        all_stopped = True
        for row, sequence in enumerate(output_ids):
            if self.stop_flag[row] == 1:
                continue
            if self.check_stop(sequence):
                self.stop_flag[row] = 1
            else:
                all_stopped = False
        return all_stopped
|
|
|
|
def build_one_instance(tokenizer, conversation, vision_type='image'):
    """Tokenize one conversation into input ids and training targets.

    Human turns are masked with -100 in the targets; 'gpt' turns are
    supervised. The conversation passed in is left unmodified.

    :param tokenizer: callable returning an object with ``input_ids`` when
        invoked as ``tokenizer(text, add_special_tokens=False)``.
    :param list conversation: turns, each a dict with keys 'from'
        ('human' or 'gpt') and 'value'; the first turn must be 'human'.
    :param str vision_type: key into VISION_TAGS ('image' or 'pcl').
    :raises Exception: if a later turn has an unknown role.
    :return tuple: (text_list, input_ids, target_ids); the two id lists
        have equal length.
    """
    pos = VISION_TAGS['pos'][vision_type]
    eov = VISION_TAGS['eov'][vision_type]

    text_list = []
    input_ids, target_ids = [], []
    for index, turn in enumerate(conversation):
        role = turn['from']
        value = turn['value']
        if index == 0:
            assert role == 'human'
            # Strip the vision placeholder on a local copy so the caller's
            # conversation is not modified.
            value = value.replace(f'{pos}\n', '').replace(f'\n{pos}', '')
            text = f'{eov} ' + value + '\n### Assistant:'
            supervised = False
        elif role == 'human':
            text = 'Human: ' + value + '\n### Assistant:'
            supervised = False
        elif role == 'gpt':
            text = value + '\n###'
            supervised = True
        else:
            raise Exception('Wrong Role!!!')

        one_input_id = tokenizer(text, add_special_tokens=False).input_ids
        input_ids += one_input_id
        # -100 is the label value used here to exclude prompt tokens.
        target_ids += one_input_id if supervised else [-100] * len(one_input_id)
        text_list.append(text)

    assert len(input_ids) == len(target_ids)
    return text_list, input_ids, target_ids
|
|
|
|
def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len, vision_type='image'):
    """Tokenize, pad and truncate a batch of conversations.

    :param tokenizer: tokenizer providing ``pad_token_id``.
    :param batch_of_conversations: iterable of conversations accepted by
        ``build_one_instance``.
    :param int max_tgt_len: maximum number of tokens kept per sequence.
    :param str vision_type: forwarded to ``build_one_instance``.
    :return tuple: (input_ids, target_ids, attention_mask) tensors of equal
        shape; targets are padded with -100 and the mask is a long tensor
        that is 0 wherever the input id equals the pad id.
    """
    id_tensors, label_tensors = [], []
    for conversation in batch_of_conversations:
        built = build_one_instance(tokenizer, conversation, vision_type=vision_type)
        id_tensors.append(torch.LongTensor(built[1]))
        label_tensors.append(torch.LongTensor(built[2]))

    padded_ids = rnn.pad_sequence(id_tensors, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_labels = rnn.pad_sequence(label_tensors, batch_first=True, padding_value=-100)
    assert padded_ids.size() == padded_labels.size()

    # Truncate on the sequence dimension after padding.
    input_ids = padded_ids[:, :max_tgt_len]
    target_ids = padded_labels[:, :max_tgt_len]
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    assert attention_mask.size() == input_ids.size()
    return input_ids, target_ids, attention_mask.long()
|
|
|
|
def make_prompt_start(system_header=False, vision_type='image', task_type='normal'):
    """Build the text that precedes the visual tokens in a prompt.

    :param bool system_header: prepend a system message when True.
    :param str vision_type: key into VISION_TAGS['sov'].
    :param task_type: 'normal', or an iterable of keys of
        ``conversation_dict`` (one per sample).
    :return: a single string, or a list of strings (one per task) when a
        system header is requested and ``task_type`` is not 'normal'.
    """
    prompt_start = f'### Human: {VISION_TAGS["sov"][vision_type]}'
    if not system_header:
        return prompt_start
    if task_type == 'normal':
        return f"{default_conversation.system}\n\n" + prompt_start
    per_task_prompts = []
    for task in task_type:
        per_task_prompts.append(f"{conversation_dict[task]}\n\n" + prompt_start)
    return per_task_prompts
|
|
|
|
class LAMMPEFTModel(nn.Module):
    """LoRA for LLaMa model.

    Combines a frozen visual encoder, a linear projection (``llama_proj``)
    into the LLaMA embedding space, and a LoRA-wrapped LLaMA decoder.
    """

    def __init__(self, **args):
        """Build the encoder, the LoRA-wrapped decoder and the projection.

        Keys read from ``args``: 'encoder_ckpt_path', 'vicuna_ckpt_path',
        'vision_feature_type', 'num_vision_token', 'max_tgt_len', 'lora_r',
        'lora_alpha', 'lora_dropout', 'lora_target_modules', and the
        optional 'vision_type' (default 'image'), 'encoder_pretrain'
        (default 'clip') and 'system_header' (default False).

        :raises NotImplementedError: for an encoder or a vision feature
            type that this constructor cannot build.
        """
        super(LAMMPEFTModel, self).__init__()
        self.args = args

        # Optional client used to fetch 's3://' images; None disables it.
        self.client = None

        self.vision_type = args.get('vision_type', 'image')
        encoder_pretrain = args.get('encoder_pretrain', 'clip')
        self.encoder_pretrain = encoder_pretrain
        assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'

        # For CLIP, fall back to the named 'ViT-L/14' model when the
        # configured checkpoint is not a file on disk.
        encoder_ckpt_path = args['encoder_ckpt_path']
        if encoder_pretrain == 'clip' and not os.path.isfile(encoder_ckpt_path):
            encoder_ckpt_path = 'ViT-L/14'

        vicuna_ckpt_path = args['vicuna_ckpt_path']
        system_header = args.get('system_header', False)

        self.vision_feature_type = args['vision_feature_type']
        self.num_vision_token = args['num_vision_token']

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'Initializing [{encoder_pretrain}] visual encoder from {encoder_ckpt_path} [{device}]...')

        if encoder_pretrain == 'clip':
            clip_encoder, self.visual_preprocess = load_clip(encoder_ckpt_path, device=device)
            self.visual_encoder = clip_encoder.visual
            if self.vision_feature_type == 'global':
                self.vision_hidden_size = 768
                # Only one global token is available.
                self.num_vision_token = 1
            elif self.vision_feature_type == 'local':
                self.vision_hidden_size = 1024
                self.num_vision_token = min(self.num_vision_token, 256)
            else:
                raise NotImplementedError("{} not Implemented".format(self.vision_feature_type))
        else:
            # No other encoder is constructed here; fail with a clear error
            # instead of a later AttributeError on self.visual_encoder.
            raise NotImplementedError(f'Encoder_pretrain: {encoder_pretrain} Not Implemented')

        # The visual encoder stays frozen.
        for param in self.visual_encoder.parameters():
            param.requires_grad = False
        self.visual_encoder.eval()
        print('Visual encoder initialized.')

        print(f'Initializing language decoder from {vicuna_ckpt_path} ...')
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=self.args['lora_r'],
            lora_alpha=self.args['lora_alpha'],
            lora_dropout=self.args['lora_dropout'],
            target_modules=self.args['lora_target_modules'],
        )

        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
        # The end-of-sequence token doubles as the padding token.
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.padding_side = "right"
        print('Language decoder initialized.')

        self.llama_proj = nn.Linear(
            self.vision_hidden_size, self.llama_model.config.hidden_size
        )
        print('LLaMa projection layer initialized.')

        self.max_tgt_len = args['max_tgt_len']
        self.system_header = system_header
        # Use the current CUDA device when there is one, otherwise the CPU
        # (torch.cuda.current_device() fails on CPU-only machines).
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        else:
            self.device = torch.device('cpu')

    def encode_image(self, image_paths):
        """encode images to llama inputs

        :param tuple image_paths: (bsz, )
        :raises NotImplementedError: if the encoder is neither 'imagebind' nor 'clip'
        :return tensor, tensor: input feature to llama, attention mask to llama
        """
        if self.encoder_pretrain == 'imagebind':
            inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
            inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
            # Only the frozen encoder runs without gradient tracking.
            with torch.no_grad():
                embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['vision']
            inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)
        elif self.encoder_pretrain == 'clip':
            inputs = self.load_and_transform_vision_data_clip(image_paths, self.device)
            inputs = inputs.to(self.llama_model.dtype)
            inputs_llama = self.clip_encode_image(inputs)
        else:
            raise NotImplementedError("Encoder pretrain [{}] not implemented".format(self.encoder_pretrain))
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)
        return inputs_llama, atts_llama

    def my_encode_image(self, images):
        """encoder loaded image objects

        :param images: already loaded images, passed to ``data.transform_vision_data``
        :raises NotImplementedError: if the encoder is not 'clip'
        :return tensor, tensor: input feature to llama, attention mask to llama
        """
        if self.encoder_pretrain != 'clip':
            raise NotImplementedError("Encoder pretrain [{}] not implemented".format(self.encoder_pretrain))
        inputs = data.transform_vision_data(images, self.device)
        inputs_llama = self.clip_encode_image(inputs)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)
        return inputs_llama, atts_llama

    def encode_pcl(self, pcl_paths):
        """encode point clouds to llama inputs

        :param pcl_paths: paths of the point cloud files to load
        :raises NotImplementedError: unless ``vision_feature_type`` is 'local'
        :return tensor, tensor: input feature to llama, attention mask to llama
        """
        inputs = self.load_and_transform_pcl_data(pcl_paths, self.device)
        inputs = inputs.to(self.llama_model.dtype)
        if self.vision_feature_type == 'global':
            raise NotImplementedError("Global feature not implemented for pcl")
        if self.vision_feature_type != 'local':
            raise NotImplementedError("{} not Implemented".format(self.vision_feature_type))
        # Only the frozen encoder runs without gradient tracking.
        with torch.no_grad():
            # The second element of the encoder output is used as features.
            embeddings = self.visual_encoder(inputs)[1][:, :self.num_vision_token]
        image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(self.llama_model.dtype)
        inputs_llama = self.llama_proj(image_embeds).reshape(
            -1, self.num_vision_token, self.llama_model.config.hidden_size
        )
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)
        return inputs_llama, atts_llama

    def clip_encode_image(self, inputs):
        """Encode preprocessed images with CLIP and project them for LLaMA.

        The frozen visual encoder runs under ``torch.no_grad()``; the
        projection runs outside of it so ``llama_proj`` can receive
        gradients during training.

        :param tensor inputs: preprocessed image batch
        :raises NotImplementedError: for an unknown ``vision_feature_type``
        :return tensor: bsz x num_vision_token x llama hidden size
        """
        dtype = self.llama_model.dtype
        inputs = inputs.to(dtype)
        if self.vision_feature_type == 'global':
            with torch.no_grad():
                embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings.to(dtype)
            return self.llama_proj(image_embeds).unsqueeze(1)
        if self.vision_feature_type == 'local':
            with torch.no_grad():
                embeddings = self.visual_encoder.forward_patch_features(inputs)[:, :self.num_vision_token]
            image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(dtype)
            return self.llama_proj(image_embeds).reshape(
                -1, self.num_vision_token, self.llama_model.config.hidden_size
            )
        raise NotImplementedError("{} not Implemented".format(self.vision_feature_type))

    def load_and_transform_vision_data_clip(self, image_paths, device):
        """Load images and apply the CLIP preprocessing.

        Supports local files, 's3://' paths (when ``self.client`` is set)
        and 'http://' / 'https://' URLs.

        :param image_paths: iterable of image locations, or None
        :param device: device the preprocessed tensors are moved to
        :raises ValueError: if an image cannot be loaded from its path
        :return tensor: stacked preprocessed images, or None if ``image_paths`` is None
        """
        if image_paths is None:
            return None
        image_outputs = []
        for image_path in image_paths:
            if os.path.exists(image_path):
                image = Image.open(image_path)
            elif image_path.startswith('s3://') and self.client is not None:
                image = Image.open(io.BytesIO(self.client.get(image_path, update_cache=True))).convert("RGB")
            elif image_path.startswith(('http://', 'https://')):
                image = Image.open(requests.get(image_path, stream=True).raw)
            else:
                # Fail loudly: continuing would crash on an unbound name or
                # silently reuse the image of the previous iteration.
                raise ValueError('can not load image: {}'.format(image_path))
            image_outputs.append(self.visual_preprocess(image).to(device))
        return torch.stack(image_outputs, dim=0)

    def load_and_transform_pcl_data(self, pcl_paths, device):
        """Load point clouds from ``.npy`` files and sample a fixed number of points.

        NOTE(review): reads ``self.use_color``, ``self.use_height`` and
        ``self.num_points``, which ``__init__`` in this class does not
        assign - confirm where they are set.

        :param pcl_paths: iterable of paths readable by ``np.load``, or None
        :param device: device the stacked tensor is moved to
        :return tensor: stacked point clouds, or None if ``pcl_paths`` is None
        """
        if pcl_paths is None:
            return None
        pcl_output = []
        for pcl_path in pcl_paths:
            mesh_vertices = np.load(pcl_path)
            if not self.use_color:
                point_cloud = mesh_vertices[:, 0:3]
            else:
                point_cloud = mesh_vertices[:, 0:6]
                # Centre the colour columns and scale them by 256.
                point_cloud[:, 3:] = (point_cloud[:, 3:] - MEAN_COLOR_RGB) / 256.0

            if self.use_height:
                # Height is measured from the 0.99th percentile of column 2.
                floor_height = np.percentile(point_cloud[:, 2], 0.99)
                height = point_cloud[:, 2] - floor_height
                point_cloud = np.concatenate([point_cloud, np.expand_dims(height, 1)], 1)

            point_cloud, _ = random_sampling(
                point_cloud, self.num_points, return_choices=True
            )
            pcl_output.append(torch.from_numpy(point_cloud))
        return torch.stack(pcl_output, dim=0).to(device)

    def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, system_header, task_type):
        '''
        Assemble [bos, prompt start, vision tokens, conversation] for training.

        input_ids, target_ids, attention_mask: bsz x s2
        Returns (inputs_embeds, targets, attention_mask); every position
        before the conversation gets the target -100.
        '''
        input_ids = input_ids.to(self.device)
        target_ids = target_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        batch_size = img_embeds.shape[0]
        pad_id = self.llama_tokenizer.pad_token_id
        embed_tokens = self.llama_model.model.model.embed_tokens

        p_before = make_prompt_start(system_header=system_header, vision_type=self.vision_type, task_type=task_type)
        if isinstance(p_before, list):
            # One prompt per sample: tokenize each and pad to a common length.
            p_before_tokens = [
                self.llama_tokenizer(p, return_tensors="pt", add_special_tokens=False).input_ids[0].to(self.device)
                for p in p_before
            ]
            p_before_token_ids = rnn.pad_sequence(p_before_tokens, batch_first=True, padding_value=pad_id)
            # Cast to long so all concatenated masks share one dtype.
            p_before_attn_mask = p_before_token_ids.ne(pad_id).long()
        else:
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False
            ).to(self.device)
            p_before_token_ids = p_before_tokens.input_ids.expand(batch_size, -1)
            p_before_attn_mask = p_before_tokens.attention_mask.expand(batch_size, -1)

        p_before_embeds = embed_tokens(p_before_token_ids)
        p_after_embeds = embed_tokens(input_ids).expand(batch_size, -1, -1)
        bos = torch.ones(
            [batch_size, 1],
            dtype=p_before_token_ids.dtype,
            device=p_before_token_ids.device,
        ) * self.llama_tokenizer.bos_token_id
        bos_embeds = embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1)

        # No loss on bos, prompt start and vision tokens.
        prefix_len = 1 + p_before_embeds.size()[1] + self.num_vision_token
        empty_targets = torch.ones([batch_size, prefix_len], dtype=torch.long).to(self.device).fill_(-100)
        targets = torch.cat([empty_targets, target_ids], dim=1)
        assert inputs_embeds.size()[1] == targets.size()[1]

        atts_bos = torch.ones([batch_size, 1], dtype=torch.long).to(self.device)
        atts_img = torch.ones([batch_size, self.num_vision_token], dtype=torch.long).to(self.device)
        attention_mask = torch.cat([atts_bos, p_before_attn_mask, atts_img, attention_mask], dim=1)
        assert attention_mask.size() == targets.size()
        return inputs_embeds, targets, attention_mask

    def forward(self, inputs):
        """Model Forward in training

        :param dict inputs: batch with keys 'vision_type', 'task_type',
            'vision_paths' and 'output_texts'
        :raises ValueError: valueerror if not image or pcl
        :return tuple: loss & token acc (0.0 when no target token is supervised)
        """
        assert self.vision_type == inputs['vision_type']
        task_type = inputs['task_type']
        vision_paths = inputs['vision_paths']
        if self.vision_type == 'image':
            vision_embeds, _ = self.encode_image(vision_paths)
        elif self.vision_type == 'pcl':
            vision_embeds, _ = self.encode_pcl(vision_paths)
        else:
            raise ValueError('vision type [{}] not supported'.format(self.vision_type))

        output_texts = inputs['output_texts']
        input_ids, target_ids, attention_mask = process_batch_instance(
            self.llama_tokenizer, output_texts, self.max_tgt_len, self.vision_type
        )
        inputs_embeds, targets, attention_mask = self.prompt_wrap(
            vision_embeds, input_ids, target_ids, attention_mask, self.system_header, task_type
        )

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        # Token accuracy over supervised positions, using the same slicing
        # of predictions and labels as before.
        chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]
        labels = targets[:, 2:]
        valid_mask = (labels != -100).reshape(-1)
        correct = (chosen_tokens.reshape(-1) == labels.reshape(-1)) & valid_mask
        num_valid = valid_mask.sum().item()
        # Truncation can leave a batch without any supervised token.
        gen_acc = correct.sum().item() / num_valid if num_valid > 0 else 0.0
        return loss, gen_acc

    def extract_multimodal_feature(self, inputs):
        """Extract multimodal features from the input in Generation (Test)

        If 'images' (loaded image objects) is provided and non-empty, its
        embeddings are returned directly. Otherwise the embeddings of
        'image_paths' and 'pcl_paths' are concatenated on dim 0 and summed
        over that dimension.

        :param Dict inputs: input dict; modality: path
        :return tensor: feature embeddings fed to the language model
        """
        features = []
        if inputs.get('image_paths'):
            image_embeds, _ = self.encode_image(inputs['image_paths'])
            features.append(image_embeds)
        if inputs.get('images'):
            image_embeds, _ = self.my_encode_image(inputs['images'])
            return image_embeds
        if inputs.get('pcl_paths'):
            pcl_embeds, _ = self.encode_pcl(inputs['pcl_paths'])
            features.append(pcl_embeds)

        # NOTE(review): the sum runs over dim 0 of the concatenation, so
        # several path entries collapse into one feature - confirm intent.
        feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0)
        return feature_embeds

    def prepare_generation_embedding(self, inputs):
        """prepare for generation

        Reuses ``inputs['modality_embeds'][0]`` when exactly one cached
        entry exists; otherwise extracts the features and appends them to
        ``inputs['modality_embeds']``.

        :param Dict inputs: generation inputs with 'prompt' (list of str)
            and 'modality_embeds' (list)
        :return tensor: input embeddings [bos, prompt start, features, prompt]
        """
        eov = VISION_TAGS['eov'][self.vision_type]
        embed_tokens = self.llama_model.model.model.embed_tokens

        prompt_list = inputs['prompt']
        if len(inputs['modality_embeds']) == 1:
            feature_embeds = inputs['modality_embeds'][0]
        else:
            feature_embeds = self.extract_multimodal_feature(inputs)
            inputs['modality_embeds'].append(feature_embeds)

        batch_size = feature_embeds.shape[0]
        p_before = make_prompt_start(vision_type=self.vision_type)
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False
        ).to(self.device)
        p_before_embeds = embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)

        p_after_tokens_list = []
        for prompt in prompt_list:
            text = f'{eov} ' + prompt + '\n### Assistant:'
            p_after_tokens = self.llama_tokenizer(
                text, add_special_tokens=False, return_tensors='pt'
            ).to(self.device)
            p_after_tokens_list.append(p_after_tokens.input_ids.squeeze(0))

        # Prompts of different lengths are right-padded with the pad token.
        p_after_tokens = rnn.pad_sequence(
            p_after_tokens_list, batch_first=True, padding_value=self.llama_tokenizer.pad_token_id
        )
        p_after_embeds = embed_tokens(p_after_tokens)

        bos = torch.ones(
            [batch_size, 1],
            dtype=p_before_tokens.input_ids.dtype,
            device=p_before_tokens.input_ids.device,
        ) * self.llama_tokenizer.bos_token_id
        bos_embeds = embed_tokens(bos)

        inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_after_embeds], dim=1)
        return inputs_embeds

    def generate(self, inputs):
        '''
        Sample a text answer for every prompt.

        inputs = {
            'image_paths': optional, paths of images to encode
            'images': optional, already loaded images
            'pcl_paths': optional, paths of point clouds to encode
            'prompt': list of human input prompts,
            'max_tgt_len': generation length,
            'top_p': top_p,
            'temperature': temperature
            'modality_embeds': list caching the extracted feature embeddings
        }
        Returns the list of decoded output strings.
        '''
        input_embeds = self.prepare_generation_embedding(inputs)

        # Token id 2277 ends a sequence; presumably the id of the '###'
        # separator for this tokenizer - TODO confirm.
        stopping_criteria = StoppingCriteriaList([MyStoppingCriteria([[2277]], input_embeds)])
        outputs = self.llama_model.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=inputs['max_tgt_len'],
            top_p=inputs['top_p'],
            temperature=inputs['temperature'],
            do_sample=True,
            use_cache=True,
            stopping_criteria=stopping_criteria,
        )

        output_text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return output_text
|
|