import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from xtuner.registry import BUILDER
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.tools.utils import get_stop_criteria
from xtuner.model.utils import guess_load_checkpoint

from mmcv.ops import point_sample
from mmdet.models.utils import get_uncertain_point_coords_with_randomness
from mmengine.model import BaseModel
from mmengine.config import Config, ConfigDict

from einops import rearrange
from transformers import DynamicCache, GenerationConfig
from peft import get_peft_model, prepare_model_for_kbit_training

from projects.ST.dataset.utils import convert_image_to_patches
from projects.ST.dataset.collect_fns import create_single_prefix_mask

def find_all_linear_names(model):
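    """Collect the names of all nn.Linear sub-modules as LoRA target candidates,
    excluding the output heads (``lm_head`` / ``output_layer``)."""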
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    if 'output_layer' in lora_module_names:
        lora_module_names.remove('output_layer')
    return list(lora_module_names)

NON_VISION_TOKEN = -1
# ChatML-style single-turn user prompt used as the instruction template.
PROMPT_TMPL = '<|im_start|>user\n{input}<|im_end|>\n'

class Sa2VASTModel(BaseModel):
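    """Segmentation model built on a single decoder-only transformer.

    Raw image patches enter the LLM as ``<vpatch>`` tokens alongside the text.
    Hidden states at ``[SEG]`` positions and at the vision-patch positions are
    projected (``seg_token_projector`` / ``image_feature_projector``) and
    combined with an einsum to predict segmentation masks, which are trained
    with mask and dice losses on top of the language-modelling loss.
    """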
    IMG_CONTEXT_TOKEN = "<vpatch>"
    IMG_START_TOKEN = "<vision>"
    IMG_END_TOKEN = "</vision>"
    IMG_RSEP_TOKEN = "<vrow_sep>"
    CLS_TOKEN = "<|vis_cls|>"

    def __init__(self,
                 single_transformer,
                 tokenizer,
                 single_transformer_lora=None,
                 seg_hidden_states=256,
                 patch_size=32,
                 seg_pred_down_ratio=4,
                 loss_mask=None,
                 loss_dice=None,
                 torch_dtype=torch.bfloat16,
                 pretrained_pth=None,
                 special_tokens=None,
                 loss_sample_points=False,
                 num_points=12544,
                 template=None,
                 add_cls=False,
                 bs=1,
                 ):
        super().__init__()
        self.add_cls = add_cls
        self.bs = bs
        self.patch_size = patch_size
        self.seg_pred_down_ratio = seg_pred_down_ratio
        self.seg_hidden_states = seg_hidden_states
        if special_tokens is None:
            special_tokens = ['[SEG]']
        self.special_tokens = special_tokens
        self.single_transformer = BUILDER.build(single_transformer)
        self.llm = self.single_transformer

        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()

        # Projects [SEG] hidden states into the mask-embedding space.
        in_dim = self.single_transformer.config.hidden_size
        out_dim = seg_hidden_states
        self.seg_token_projector = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

        # Projects each vision-patch hidden state to a grid of
        # (patch_size / seg_pred_down_ratio) ** 2 mask features.
        out_dim = seg_hidden_states * (patch_size // seg_pred_down_ratio) ** 2
        self.image_feature_projector = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

        if single_transformer_lora is not None:
            self.single_transformer.requires_grad_(False)
            self.activation_checkpointing_enable()
            self.single_transformer.enable_input_require_grads()
            self._prepare_llm_for_lora(single_transformer_lora)
            # Keep the (possibly resized) embeddings and LM head trainable alongside LoRA.
            self.single_transformer.model.base_model.get_input_embeddings().requires_grad_(True)
            self.single_transformer.lm_head.requires_grad_(True)

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

        self.torch_dtype = torch_dtype

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Load pretrained weight from {pretrained_pth}')

        self.loss_sample_points = loss_sample_points
        self.num_points = num_points
        self.oversample_ratio = 3.0
        self.importance_sample_ratio = 0.75

        self.template = template
        if self.template is not None:
            self.template['INSTRUCTION'] = PROMPT_TMPL

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.single_transformer.model = prepare_model_for_kbit_training(
            self.single_transformer.model, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.single_transformer.model)
            lora_config.target_modules = modules
        self.single_transformer.model = get_peft_model(
            self.single_transformer.model, lora_config)

    def activation_checkpointing_disable(self):
        self.single_transformer.gradient_checkpointing_disable()

    def activation_checkpointing_enable(self):
        self.single_transformer.gradient_checkpointing_enable()

    def _add_special_tokens(self):
        # Expose the vision special tokens as tokenizer attributes.
        self.tokenizer.vis_beg_tok = "<vision>"
        self.tokenizer.vis_patch_tok = "<vpatch>"
        self.tokenizer.vis_rsep_tok = "<vrow_sep>"
        self.tokenizer.vis_frm_tok = "<vframe_sep>"
        self.tokenizer.vis_end_tok = "</vision>"
        self.tokenizer.vis_cls_tok = "<|vis_cls|>"

        # Register extra special tokens (e.g. [SEG]) and resize the embeddings if needed.
        special_tokens = self.special_tokens
        _num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        if _num_new_tokens > 0:
            self.single_transformer.resize_token_embeddings(len(self.tokenizer))
        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        self.vision_patch_idx = self.tokenizer("<vpatch>", add_special_tokens=False).input_ids[0]

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        return state_dict

    def _get_pesudo_data(self, device):
        # Dummy masks used when a batch carries no segmentation targets, so the
        # mask branch still runs (its losses are zero-weighted in forward()).
        gt_masks = torch.zeros((1, 256, 256), dtype=torch.uint8, device=device)
        gt_masks = [gt_masks] * self.bs
        return gt_masks

    def get_mask_prediction(self, seg_embeddings_list, image_seg_features):
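        """Dot-product each [SEG] embedding (q, c) with an image's mask-feature
        map (h, w, c) to get per-query mask logits of shape (q, h, w)."""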
        ret = []
        for seg_embeddings, image_seg_feature in zip(seg_embeddings_list, image_seg_features):
            pred_masks = torch.einsum("qc,hwc->qhw", seg_embeddings, image_seg_feature)
            ret.append(pred_masks)
        return ret

    def forward(self, data, data_samples=None, mode='loss'):
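        """Training forward pass.

        Runs the transformer over the packed text + vision-patch sequence, then
        predicts masks from the [SEG] token embeddings and the projected
        vision-patch features, returning mask, dice and language-model losses.
        """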
        gt_masks = data.pop('masks', None)
        patch_nums_per_images = data.pop('patch_nums_per_images', None)
        input_ids = data['input_ids']

        if 'vision_patches' in data.keys() and data['vision_patches'] is not None:
            data['vision_patches'] = data['vision_patches'].flatten(1).to(self.torch_dtype)

        if gt_masks is None:
            # No segmentation targets in this batch; fall back to dummy masks and
            # zero-weight the mask losses below.
            seg_valid = False
            gt_masks = self._get_pesudo_data(device=input_ids.device)
        else:
            seg_valid = True

        output = self.single_transformer(**data, return_dict=True, output_hidden_states=True)
        hidden_states = output.hidden_states
        # Only the last decoder layer's hidden states are used.
        hidden_states = hidden_states[-1]

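        # Rebuild a dense (H, W, C) mask-feature map per image: every <vpatch>
        # hidden state is projected to a (sub_pixels x sub_pixels) grid of
        # seg_hidden_states-dim features and placed back at its patch location.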
        image_token_mask = input_ids == self.vision_patch_idx
        vision_features = self.image_feature_projector(hidden_states[image_token_mask])
        patch_split_nums = [item[0] * item[1] for item in patch_nums_per_images]
        vision_features = torch.split(vision_features, patch_split_nums, dim=0)
        all_image_features = []
        for patch_num, image_features in zip(patch_nums_per_images, vision_features):
            sub_pixels = self.patch_size // self.seg_pred_down_ratio
            h_patches, w_patches = patch_num
            if h_patches * w_patches == 0:
                # Text-only sample, no image features.
                all_image_features.append(None)
            else:
                image_features = image_features.reshape(
                    h_patches, w_patches, self.seg_hidden_states, sub_pixels, sub_pixels)
                image_features = image_features.permute(0, 3, 1, 4, 2)
                image_features = image_features.flatten(0, 1).flatten(1, 2)
                all_image_features.append(image_features)

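        # Gather the hidden states at [SEG] positions; when there are no valid
        # targets, project a dummy token so the projector still takes part in
        # the forward graph.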
        seg_token_mask = input_ids == self.seg_token_idx
        if seg_valid:
            seg_token_features = self.seg_token_projector(hidden_states[seg_token_mask])
        else:
            seg_token_features = self.seg_token_projector(hidden_states[:, :1].flatten(0, 1))
        seg_token_counts = seg_token_mask.int().sum(-1)
        if not seg_valid:
            seg_token_counts += 1

        # Keep only samples that have both [SEG] embeddings and image features.
        seg_embeddings_list_ = torch.split(seg_token_features, seg_token_counts.tolist(), dim=0)
        seg_embeddings_list = []
        image_seg_features = []
        gt_masks_ = []
        for idx, item in enumerate(seg_embeddings_list_):
            if len(item) != 0 and all_image_features[idx] is not None:
                seg_embeddings_list.append(item)
                image_seg_features.append(all_image_features[idx])
                gt_masks_.append(gt_masks[idx])
        gt_masks = gt_masks_

        pred_masks = self.get_mask_prediction(seg_embeddings_list, image_seg_features)
        if not self.loss_sample_points:
            # Without point sampling, resize the GT masks to the prediction resolution.
            gt_masks = [F.interpolate(gt_mask.unsqueeze(0), size=pred_mask.shape[-2:], mode='nearest').squeeze(0)
                        for gt_mask, pred_mask in zip(gt_masks, pred_masks)]

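        # Accumulate mask/dice losses over the batch; mismatched or invalid
        # samples contribute with zero weight so the graph stays connected.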
        loss_mask, loss_dice = 0, 0
        n_masks = 0
        for pred_mask, gt_mask in zip(pred_masks, gt_masks):
            if len(pred_mask) != len(gt_mask):
                # Drop the extra predictions/targets and zero-weight this sample.
                print(f"Pred mask shape {pred_mask.shape} is not equal to gt_mask shape {gt_mask.shape} !!!")
                min_num = min(len(pred_mask), len(gt_mask))
                pred_mask = pred_mask[:min_num]
                gt_mask = gt_mask[:min_num]
                _seg_valid = False
            else:
                _seg_valid = True

            if self.loss_sample_points:
                sampled_pred_mask, sampled_gt_mask = self.sample_points(pred_mask, gt_mask)
                sam_loss_dice = self.loss_dice(
                    sampled_pred_mask,
                    sampled_gt_mask, avg_factor=(1 + 1e-4))
                sam_loss_mask = self.loss_mask(
                    sampled_pred_mask.reshape(-1),
                    sampled_gt_mask.reshape(-1),
                    avg_factor=(sampled_pred_mask.shape[1] + 1e-4))
            else:
                sam_loss_mask = self.loss_mask(pred_mask, gt_mask) * len(pred_mask)
                sam_loss_dice = self.loss_dice(pred_mask, gt_mask) * len(pred_mask)

            if _seg_valid and seg_valid:
                _scale = 1.0
                n_masks += len(pred_mask)
            else:
                _scale = 0.0

            loss_mask += sam_loss_mask * _scale
            loss_dice += sam_loss_dice * _scale

        # When there is no valid mask loss, give the LLM loss full weight;
        # otherwise down-weight it.
        if loss_mask == 0.0:
            _llm_loss_scale = 1.0
        else:
            _llm_loss_scale = 0.1

        # The `output.loss * 0.0` terms keep the LLM connected to the mask losses
        # in the graph without changing their values.
        loss_dict = {
            'loss_mask': loss_mask / (n_masks + 1e-4) + output.loss * 0.0,
            'loss_dice': loss_dice / (n_masks + 1e-4) + output.loss * 0.0,
            'llm_loss': output.loss * _llm_loss_scale,
        }
        return loss_dict

    def sample_points(self, mask_pred, gt_masks):
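        """Sample ``num_points`` point logits/labels per mask with mmdet's
        uncertainty-based point sampling (as in Mask2Former-style mask losses)
        instead of computing the losses on full masks."""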
        gt_masks = gt_masks.unsqueeze(1)
        gt_masks = gt_masks.to(mask_pred)
        mask_pred = mask_pred.unsqueeze(1)

        with torch.no_grad():
            # Pick point coordinates where the prediction is most uncertain.
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_pred.to(torch.float32), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
        # Shape (num_queries, num_points).
        mask_point_targets = point_sample(
            gt_masks.float(), points_coords).squeeze(1)
        # Shape (num_queries, num_points).
        mask_point_preds = point_sample(
            mask_pred.to(torch.float32), points_coords.to(torch.float32)).squeeze(1)
        return mask_point_preds.to(mask_pred.dtype), mask_point_targets.to(mask_pred.dtype)

    def preparing_for_generation(self, metainfo, **kwargs):
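        """One-time setup before inference: chat template, stopping criteria,
        generation config, and casting the LLM and projectors to ``torch_dtype``."""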
        assert hasattr(self, 'tokenizer'), 'The model does not have a tokenizer!'
        self.bot_name = 'BOT'
        if 'template' in metainfo.keys():
            template = metainfo['template']
        else:
            template = PROMPT_TEMPLATE['phi3_chat']
        if self.template is None:
            self.template = template
        stop_words = []
        stop_words += self.template.get('STOP_WORDS', [])
        stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)
        self.stop_criteria = stop_criteria

        # Greedy decoding by default; metainfo['generation_kwargs'] can override.
        default_generation_kwargs = dict(
            max_new_tokens=512,
            do_sample=False,
            temperature=0,
            num_beams=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)
        self.init_prediction_config = True

        self.single_transformer.to(self.torch_dtype)
        self.seg_token_projector.to(self.torch_dtype)
        self.image_feature_projector.to(self.torch_dtype)
        return

    def prepare_image_textual_seq_norowsep(self, h, w):
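        """Build the textual image block ``<vision>`` + ``<vpatch>`` * (h * w) +
        ``</vision>`` (plus ``<|vis_cls|>`` when ``add_cls``), returning the string,
        its token length, and per-token patch indices (``NON_VISION_TOKEN`` for
        non-patch tokens)."""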
        image_token_patch_indices = []
        seq = ""
        tok_len = 0

        seq += self.IMG_START_TOKEN
        tok_len += 1
        image_token_patch_indices.append(NON_VISION_TOKEN)

        seq += self.IMG_CONTEXT_TOKEN * (w * h)
        tok_len += (w * h)
        image_token_patch_indices += list(range(w * h))

        seq += self.IMG_END_TOKEN
        tok_len += 1
        image_token_patch_indices.append(NON_VISION_TOKEN)

        if self.add_cls:
            seq += self.CLS_TOKEN
            tok_len += 1
            image_token_patch_indices.append(NON_VISION_TOKEN)
        return seq, tok_len, image_token_patch_indices

    def predict_forward(self,
                        image=None,
                        text=None,
                        past_text=''):
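        """Single-sample inference.

        Converts the image to raw patches, builds the prompt, pre-fills the
        vision prefix, generates text, and decodes every [SEG] token in the
        output into a binary mask at the original image resolution.
        Requires :meth:`preparing_for_generation` to have been called first.
        """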
        assert self.tokenizer

        input_dict = {}
        # PIL size is (width, height); only available when an image is given.
        ori_image_size = image.size if image is not None else None

        if image is None:
            input_dict['vision_patches'] = None
            input_dict['patch_nums_per_images'] = (0, 0)
            image_token_str = ''
            image_token_patch_indices = []
        else:
            # Split the image into a non-overlapping grid of raw patches;
            # shape[:2] gives (h_patches, w_patches).
            image_patches = convert_image_to_patches(image, self.patch_size)
            h_patches, w_patches = image_patches.shape[:2]
            n_patches = h_patches * w_patches

            input_dict['vision_patches'] = image_patches.flatten(0, 1).flatten(1)
            input_dict['patch_nums_per_images'] = (h_patches, w_patches)
            image_token_str, image_token_len, image_token_patch_indices = \
                self.prepare_image_textual_seq_norowsep(
                    image_patches.shape[0], image_patches.shape[1]
                )

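        # Build the prompt: on the first turn (text contains '<image>') the image
        # token block is prepended; on later turns the previous context is reused.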
        ret_masks = []
        if '<image>' in text:
            assert past_text is None or len(past_text) == 0
            first_conv = True
        else:
            first_conv = False
        text = text.replace('<image>\n', '').replace('\n<image>', '').replace('<image>', '')
        input_text = ''
        input_text += self.template['INSTRUCTION'].format(
            input=text, round=1, bot_name=self.bot_name)
        if first_conv:
            input_text = image_token_str + input_text
        else:
            input_text = past_text + input_text

        ids = self.tokenizer.encode(input_text, add_special_tokens=False)
        vision_start_end = self.search_vision_tokens(ids)

        # Prefix-style attention mask over the vision-token span.
        attention_mask = create_single_prefix_mask(vision_start_end, len(ids)).unsqueeze(0).unsqueeze(0).cuda()

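        # 3D (t, h, w) position ids: text tokens advance all three axes together,
        # while vision patches keep their own row/column coordinates.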
        ids = torch.tensor(ids).cuda().unsqueeze(0)
        position_ids = generate_mm_pos_ids_singleit(
            ids[0].cpu().numpy().tolist(), self.vision_patch_idx,
            input_dict['patch_nums_per_images'][0], input_dict['patch_nums_per_images'][1]).unsqueeze(1).cuda()

        # Map every token to its patch index (NON_VISION_TOKEN for text tokens).
        vision_patch_indices = []
        vision_patch_indices += image_token_patch_indices
        vision_patch_indices += [NON_VISION_TOKEN] * (ids.shape[-1] - len(vision_patch_indices))
        vision_patch_indices = torch.tensor(vision_patch_indices).cuda().unsqueeze(0)

        padding_attention_mask = torch.ones_like(ids).cuda()

        mm_inputs = {
            'vision_patches': input_dict['vision_patches'].flatten(1).cuda().to(self.torch_dtype),
            'input_ids': ids,
            'attention_mask': padding_attention_mask,
            'position_ids': position_ids,
            'labels': None,
            'vision_patch_indices': vision_patch_indices,
        }

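        # Pre-fill the vision + prompt prefix once with the prefix attention mask
        # and reuse its KV cache for generation; the hidden states from this pass
        # are also kept for the vision-feature projection below.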
        image_tokens_len = vision_start_end[-1] + 1
        cached_inputs = dict(
            input_ids=ids[:, :image_tokens_len],
            position_ids=position_ids[:, :, :image_tokens_len],
            attention_mask=attention_mask[:, :, :image_tokens_len, :image_tokens_len],
            vision_patches=mm_inputs['vision_patches'],
            vision_patch_indices=vision_patch_indices[:, :image_tokens_len],
            use_cache=True
        )
        prefix_cache = DynamicCache()
        with torch.no_grad():
            prefix_cache = self.single_transformer.forward(**cached_inputs, past_key_values=prefix_cache,
                                                           return_dict=True, output_hidden_states=True)
        past_hidden_states = prefix_cache.hidden_states
        prefix_cache = prefix_cache.past_key_values
        past_key_values = copy.deepcopy(prefix_cache)

        generate_output = self.single_transformer.generate(
            **mm_inputs,
            generation_config=self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria,
            output_hidden_states=True,
            return_dict_in_generate=True,
            past_key_values=past_key_values,
        )
        predict = self.tokenizer.decode(
            generate_output.sequences[0], skip_special_tokens=False).strip()

        # Last-layer hidden states of the cached prefix.
        last_past_hidden_states = past_hidden_states[-1][0]

        # Last-layer hidden states for the remaining prompt and generated tokens.
        hidden_states = generate_output.hidden_states
        last_hidden_states = [item[-1][0] for item in hidden_states]
        last_hidden_states = torch.cat(last_hidden_states, dim=0)

        # Full sequence: prefix + prompt remainder + generated tokens.
        last_hidden_states = torch.cat([last_past_hidden_states, last_hidden_states], dim=0)

        # Rebuild the dense mask-feature map from the prompt's <vpatch> positions,
        # mirroring the training-time logic in forward().
        image_token_mask = ids[0] == self.vision_patch_idx
        vision_features = self.image_feature_projector(
            last_hidden_states[:len(ids[0])][image_token_mask])
        patch_split_nums = [item[0] * item[1] for item in [input_dict['patch_nums_per_images']]]
        vision_features = torch.split(vision_features, patch_split_nums, dim=0)
        all_image_features = []
        for patch_num, image_features in zip([input_dict['patch_nums_per_images']], vision_features):
            sub_pixels = self.patch_size // self.seg_pred_down_ratio
            h_patches, w_patches = patch_num
            if h_patches * w_patches == 0:
                all_image_features.append(None)
            else:
                image_features = image_features.reshape(
                    h_patches, w_patches, self.seg_hidden_states, sub_pixels, sub_pixels)
                image_features = image_features.permute(0, 3, 1, 4, 2)
                image_features = image_features.flatten(0, 1).flatten(1, 2)
                all_image_features.append(image_features)
        image_features = all_image_features[0]

        # Project the hidden states at [SEG] positions and turn them into masks.
        seg_hidden_states = get_seg_hidden_states(
            last_hidden_states, generate_output.sequences[0][:-1],
            seg_id=self.seg_token_idx
        )
        all_seg_hidden_states = self.seg_token_projector(seg_hidden_states)
        if all_seg_hidden_states.shape[0] == 0:
            ret_masks = None
        else:
            pred_masks = torch.einsum("qc,hwc->qhw", all_seg_hidden_states, image_features)
            w, h = ori_image_size
            # Upsample to the original image resolution and binarize.
            masks = F.interpolate(pred_masks.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=False)[0]
            masks = masks.sigmoid() > 0.5
            masks = masks.cpu()
            ret_masks.append(masks)

        return {'prediction': predict, 'prediction_masks': ret_masks, 'input_text': ''}

    def search_vision_tokens(self, input_ids):
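        """Return the [start, end) index range of the <vpatch> tokens in ``input_ids``
        (the span between <vision> and </vision>), or None if there is no image."""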
        image_start_idx = self.tokenizer(self.IMG_START_TOKEN, add_special_tokens=False).input_ids[0]
        image_end_idx = self.tokenizer(self.IMG_END_TOKEN, add_special_tokens=False).input_ids[0]
        if image_start_idx not in input_ids:
            return None
        else:
            start_idx = input_ids.index(image_start_idx)
            end_idx = input_ids.index(image_end_idx)
            return [start_idx + 1, end_idx]

def get_seg_hidden_states(hidden_states, output_ids, seg_id):
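    """Select the hidden states (aligned with the last ``len(output_ids)`` positions)
    whose corresponding output token is the [SEG] id."""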
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    return hidden_states[-n_out:][seg_mask]

def generate_mm_pos_ids_singleit(input_ids, vpatch_id, h, w):
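    """Build 3D (t, h, w) position ids for a single-image sequence: text tokens
    share the same index on all three axes, while the h*w vision patches get
    per-axis coordinates offset by the patch-block start position.
    Returns a (3, seq_len) LongTensor."""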
    input_ids_pt = torch.Tensor(input_ids).int()
    vpatch_pos = torch.argwhere(input_ids_pt == vpatch_id)
    vpatch_start_pos = vpatch_pos[0].item()
    # Number of "temporal" steps: all text tokens plus one step for the image block.
    nt = len(input_ids) - (h * w) + 1

    # (t, h, w) coordinates for the single image, offset by its start position.
    t_indices = torch.arange(1)
    h_indices = torch.arange(h)
    w_indices = torch.arange(w)
    v_pos_id = torch.stack(torch.meshgrid(t_indices, h_indices, w_indices, indexing='ij'), dim=0)
    v_pos_id = rearrange(v_pos_id, "d t h w -> (t h w) d")
    v_pos_id += vpatch_start_pos
    position_id = torch.cat(
        [
            torch.arange(vpatch_start_pos).unsqueeze(-1).repeat(1, 3),
            v_pos_id,
            torch.arange(nt - vpatch_start_pos - 1).unsqueeze(-1).repeat(1, 3) + v_pos_id.max() + 1,
        ],
        dim=0
    )
    assert len(input_ids) == position_id.size(0)
    position_id = rearrange(position_id, "slen d -> d slen").long()

    return position_id