| import copy |
| from collections import OrderedDict |
| import torch |
| import torch.nn as nn |
| from mmengine.config import Config, ConfigDict |
| from mmengine.model import BaseModel |
| from peft import get_peft_model, prepare_model_for_kbit_training |
|
|
| from xtuner.registry import BUILDER |
| from .modules import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA |
| from xtuner.model.modules import ProjectorModel, ProjectorConfig |
| from xtuner.model.modules import dispatch_modules |
| from .utils import (LoadWoInit, find_all_linear_names, |
| get_peft_model_state_dict, guess_load_checkpoint, |
| make_inputs_require_grad, |
| traverse_dict, |
| prepare_inputs_labels_for_multimodal_with_visual_prompts) |
| from .convnext_clip import OpenCLIPBackbone |
| from .omg_seg import OMGSegVisualEncoder |
|
|
| from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
| PROMPT_TEMPLATE) |
| from xtuner.tools.utils import get_stop_criteria, is_cn_string |
| from transformers import GenerationConfig |
| import torch.nn.functional as F |
| import numpy as np |
| from pycocotools import mask as _mask |
|
|
class OMG_LLaVA(BaseModel):
    """OMG-LLaVA: an LLM coupled with an OMG-Seg style visual encoder.

    The visual encoder's features are mapped into the LLM embedding space by
    ``self.projector``. Optional region (mask) prompts and point prompts are
    encoded by the visual encoder. They are spliced into the token sequence
    at ``<region>`` / ``<mark>`` tokens. When ``text2vision_projector`` is
    enabled, LLM hidden states at ``[SEG]`` tokens are projected back and
    decoded into masks by ``visual_encoder.forward_llm_seg``.
    """

    def __init__(self,
                 llm,
                 visual_encoder,
                 visual_select_layer=-2,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 require_omg_decoder=False,
                 pretrained_pth=None,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 projector_depth=2,
                 text2vision_projector=False,
                 tokenizer=None,
                 keep_omg_decoder_frozen=False,
                 add_seg_pretrain=False,
                 additional_cross_attn_layers=False,
                 pixel_shuffle_ratio=None,
                 train_vocabulary=False,
                 freeze_llm_with_lora=False,
                 freeze_visual_projector=False,
                 rm_prior_embedding=False,
                 rm_query=False,
                 clip_feat_channel=1536,
                 ):
        """Build the model.

        Args:
            llm: LLM config (built via ``BUILDER``) or an ``nn.Module``.
            visual_encoder: visual encoder config or module. Configs whose
                ``type`` is ``OpenCLIPBackbone`` or ``OMGSegVisualEncoder``
                are instantiated directly with the whole config as kwargs.
            visual_select_layer: hidden-state index used when the encoder
                returns an output object rather than a tensor/list/tuple.
            freeze_llm: disable gradients of the LLM.
            freeze_visual_encoder: disable gradients of the visual encoder.
            require_omg_decoder: call ``visual_encoder.init_new_decoder()``.
            pretrained_pth: checkpoint loaded non-strictly after LoRA setup.
            llm_lora: LoRA config for the LLM (None disables LoRA).
            visual_encoder_lora: LoRA config for the visual encoder.
            use_activation_checkpointing: enable gradient checkpointing.
            projector_depth: depth of the projector MLPs.
            text2vision_projector: build the LLM->vision projector used to
                decode ``[SEG]`` tokens into masks.
            tokenizer: tokenizer config dict with a callable ``type`` (or an
                already-built tokenizer). When given, special tokens are
                added and the LLM embeddings are resized.
            keep_omg_decoder_frozen: freeze ``transformer_decoder_llm``.
            add_seg_pretrain: train with the object-query reconstruction
                loss; the segmentation loss is then skipped.
            additional_cross_attn_layers: call
                ``visual_encoder.init_cross_attn_layer()``.
            pixel_shuffle_ratio: forwarded to the projector config.
            train_vocabulary: without LoRA, unfreeze the embedding and
                output-head parameters.
            freeze_llm_with_lora: after LoRA setup, freeze every LLM param.
            freeze_visual_projector: freeze ``self.projector``.
            rm_prior_embedding: if truthy, set on ``projector.model``.
            rm_query: if truthy, set on ``projector.model``.
            clip_feat_channel: channel count of the visual features fed to
                the projector.
        """
        super().__init__()
        self.freeze_llm_with_lora = freeze_llm_with_lora
        self.freeze_visual_projector = freeze_visual_projector
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        # Must be set before the tokenizer is built: ``_add_special_tokens``
        # reads ``use_llm_lora``.
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
            if isinstance(visual_encoder, dict):
                encoder_type = visual_encoder.get('type')
            else:
                encoder_type = getattr(visual_encoder, 'type', None)
            if encoder_type in (OpenCLIPBackbone, OMGSegVisualEncoder):
                self.visual_encoder = encoder_type(**visual_encoder)
            else:
                self.visual_encoder = self._build_from_cfg_or_module(
                    visual_encoder)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        # Visual features / object queries -> LLM embedding space.
        projector_config = ProjectorConfig_OMG_LLaVA(
            query_channels=256,
            feat_channels=clip_feat_channel,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth,
            pixel_shuffle_ratio=pixel_shuffle_ratio,
        )
        self.projector = ProjectorModel_OMG_LLaVA(projector_config).to(
            self.visual_encoder.dtype)

        # Optional LLM hidden state -> vision space, used for [SEG] decoding.
        self.text2vision_projector = text2vision_projector
        if text2vision_projector:
            projector_config = ProjectorConfig(
                visual_hidden_size=self.llm.config.hidden_size,
                llm_hidden_size=256 * 2,
                depth=projector_depth)
            self.projector_text2vision = ProjectorModel(projector_config).to(
                self.visual_encoder.dtype)

        if rm_query:
            self.projector.model.rm_query = rm_query
        if rm_prior_embedding:
            self.projector.model.rm_prior_embedding = rm_prior_embedding

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.visual_encoder.requires_grad_(False)

        self.use_activation_checkpointing = use_activation_checkpointing
        if use_activation_checkpointing:
            # Make module inputs require grad so checkpointed segments still
            # propagate gradients when the preceding weights are frozen.
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.visual_encoder, 'enable_input_require_grads'):
                self.visual_encoder.enable_input_require_grads()
            else:
                self.visual_encoder.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            self.projector.enable_input_require_grads()
            if text2vision_projector:
                self.projector_text2vision.enable_input_require_grads()
            self.gradient_checkpointing_enable()

        self.added_special_token = False
        if tokenizer is not None:
            if isinstance(tokenizer, dict):
                # Copy so the caller's config stays reusable.
                tokenizer_cfg = dict(tokenizer)
                tokenizer_type = tokenizer_cfg.pop('type')
                self.tokenizer = tokenizer_type(**tokenizer_cfg)
            else:
                self.tokenizer = tokenizer
            self._add_special_tokens()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
            if self.freeze_llm_with_lora:
                self.llm.requires_grad_(False)
        elif train_vocabulary:
            for name, param in self.named_parameters():
                if ('tok_' in name or 'embed_tokens' in name) or 'lm_head' in name:
                    print("Unfrozen {} !!!".format(name))
                    param.requires_grad_(True)
                if ('output.' in name or 'lm_head' in name) and 'llm' in name and 'lora' not in name:
                    print("Unfrozen {} !!!".format(name))
                    param.requires_grad_(True)

        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        # NOTE(review): weights are loaded before ``init_new_decoder`` /
        # ``init_cross_attn_layer`` run, so keys for modules created there are
        # ignored by this non-strict load. Confirm this is intended.
        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Load pretrained weight from {pretrained_pth}')

        self.visual_select_layer = visual_select_layer
        self._is_init = True

        self.require_omg_decoder = require_omg_decoder
        if require_omg_decoder:
            self.visual_encoder.init_new_decoder()
            if keep_omg_decoder_frozen:
                self.visual_encoder.panoptic_head.transformer_decoder_llm.requires_grad_(False)
                print("Frozen all the omg seg decoder !!!")

        self.additional_cross_attn_layers = additional_cross_attn_layers
        if self.additional_cross_attn_layers:
            self.visual_encoder.init_cross_attn_layer()

        if self.freeze_visual_projector:
            self.projector.requires_grad_(False)

        self.add_seg_pretrain = add_seg_pretrain
        self.init_prediction_config = False

    def _add_special_tokens(self):
        """Register the OMG-LLaVA special tokens and resize LLM embeddings.

        Adds ``[SEG]``, ``<p>``, ``</p>`` and ``<region>``. Then it resizes
        the LLM token embeddings to the tokenizer length and adds ``<mark>``.
        Each token id is stored on ``self``.
        """
        assert hasattr(self, "tokenizer")
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = segmentation_tokens + phrase_tokens + region_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
        self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
        self.region_token_idx = self.tokenizer("<region>", add_special_tokens=False).input_ids[0]

        self.llm.resize_token_embeddings(len(self.tokenizer))

        # NOTE(review): ``<mark>`` is added *after* the resize, so its id
        # presumably lies outside the LLM embedding table. This is kept as-is
        # because existing checkpoints depend on this vocabulary size. It
        # presumably relies on ``<mark>`` positions being replaced by point
        # features before any embedding lookup. Confirm.
        self.tokenizer.add_tokens(point_tokens, special_tokens=True)
        self.mark_token_idx = self.tokenizer("<mark>", add_special_tokens=False).input_ids[0]
        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
            self.llm.enable_input_require_grads()
        self.added_special_token = True
        print("[SEG]: {}, <p>: {}, </p>: {}, <region>: {}, <mark>: {}"
              .format(self.seg_token_idx, self.bop_token_idx,
                      self.eop_token_idx, self.region_token_idx, self.mark_token_idx))
        print('****************************Add special tokens ********************************************')
        return

    def _parse_lora_config(self, lora_config):
        """Build ``lora_config`` via ``BUILDER`` if it is a config; else return it."""
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        """Wrap the LLM with PEFT LoRA and re-enable embedding/head grads.

        When ``lora_config.target_modules`` is None, all linear layers found
        by ``find_all_linear_names`` are targeted.
        """
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)
        for name, param in self.named_parameters():
            if 'tok_' in name or 'lm_head' in name:
                print("Unfrozen {} !!!".format(name))
                param.requires_grad_(True)
            if 'output.' in name and 'llm' in name and 'lora' not in name:
                print("Unfrozen {} !!!".format(name))
                param.requires_grad_(True)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        """Wrap the visual encoder with PEFT LoRA.

        ``use_activation_checkpointing`` is accepted for signature parity
        with the LLM variant; it is not used here.
        """
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        """Alias of :meth:`activation_checkpointing_enable`."""
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        """Enable gradient checkpointing on every submodule that supports it."""
        self.llm.gradient_checkpointing_enable()
        if hasattr(self.visual_encoder, 'gradient_checkpointing_enable'):
            self.visual_encoder.gradient_checkpointing_enable()
        elif hasattr(self.visual_encoder, 'clip_model'):
            if self.visual_encoder.clip_model is not None:
                self.visual_encoder.clip_model.gradient_checkpointing_enable()
        if hasattr(self.projector, 'gradient_checkpointing_enable'):
            self.projector.gradient_checkpointing_enable()
        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_enable'):
            self.projector_text2vision.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        """Alias of :meth:`activation_checkpointing_disable`."""
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        """Disable gradient checkpointing (mirror of the enable method)."""
        self.llm.gradient_checkpointing_disable()
        if hasattr(self.visual_encoder, 'gradient_checkpointing_disable'):
            self.visual_encoder.gradient_checkpointing_disable()
        elif getattr(self.visual_encoder, 'clip_model', None) is not None:
            clip_model = self.visual_encoder.clip_model
            if hasattr(clip_model, 'gradient_checkpointing_disable'):
                clip_model.gradient_checkpointing_disable()
        if hasattr(self.projector, 'gradient_checkpointing_disable'):
            self.projector.gradient_checkpointing_disable()
        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_disable'):
            self.projector_text2vision.gradient_checkpointing_disable()

    def init_weights(self):
        """No-op: submodules are built (or loaded) in ``__init__``."""
        pass

    def state_dict(self, *args, **kwargs):
        """Return only the weights this model trains or adds.

        Selected entries: token embeddings and output heads; LoRA (or full)
        weights of the visual encoder and LLM, depending on the freeze/LoRA
        flags; both projectors; ``visual_encoder.adapter_proj`` when the
        encoder is frozen; CLIP LoRA weights; and the LLM-specific OMG decoder
        parts when ``require_omg_decoder`` is set.
        """
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()

        def _take(predicate):
            to_return.update(
                (k, v) for k, v in state_dict.items() if predicate(k))

        _take(lambda k: 'tok_' in k or 'embed_tokens' in k)
        _take(lambda k: ('output.' in k or 'lm_head' in k)
              and 'llm' in k and 'lora' not in k)

        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            _take(lambda k: 'visual_encoder.' in k)

        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            _take(lambda k: 'llm.' in k)

        _take(lambda k: 'projector.' in k)
        _take(lambda k: 'projector_text2vision' in k)

        if self.freeze_visual_encoder:
            _take(lambda k: 'visual_encoder.adapter_proj' in k)

        # ``clip_lora`` may be absent on encoders that expose ``clip_model``.
        if hasattr(self.visual_encoder, 'clip_model'):
            if getattr(self.visual_encoder, 'clip_lora', None) is not None:
                to_return.update(
                    get_peft_model_state_dict(self.visual_encoder.clip_model,
                                              state_dict=state_dict))

        if self.require_omg_decoder:
            decoder_keys = (
                'visual_encoder.panoptic_head.transformer_decoder_llm',
                'visual_encoder.panoptic_head.mask_embed_llm',
                'visual_encoder.panoptic_head.pixel_decoder_llm',
                'visual_encoder.panoptic_head.additional_cross_attn_layers',
                'visual_encoder.panoptic_head.additional_ffn',
                'visual_encoder.downsample_layer',
            )
            _take(lambda k: any(part in k for part in decoder_keys))

        return to_return

    def _build_from_cfg_or_module(self, cfg_or_mod):
        """Return ``cfg_or_mod`` if it is a module; else build it via ``BUILDER``."""
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError(
                f'Expected a config dict or nn.Module, got {type(cfg_or_mod)}')

    @staticmethod
    def _batch_indices(token_mask):
        """Return the batch (row) index of every True entry of a (B, L) mask."""
        rows = torch.arange(token_mask.shape[0], device=token_mask.device)
        return rows.unsqueeze(1).expand_as(token_mask)[token_mask]

    @staticmethod
    def _match_prompt_count(prompts, n_tokens):
        """Truncate ``prompts`` to ``n_tokens`` entries, or pad it to that size.

        Padding prepends copies of the first prompt. It works for prompt
        tensors of any rank (e.g. 3-D masks or 2-D point coordinates).
        """
        if len(prompts) > n_tokens:
            return prompts[:n_tokens]
        n_pad = n_tokens - len(prompts)
        pad = prompts[:1].repeat(n_pad, *([1] * (prompts.dim() - 1)))
        return torch.cat([pad, prompts])

    def forward(self, data, data_samples=None, mode='loss'):
        """Run the model in ``loss``, ``predict`` or ``tensor`` mode.

        ``data`` is modified in place: ``masks``, ``regions`` and ``points``
        are removed because they must not reach ``self.llm(**data)``. When
        ``pixel_values`` is present, images and visual prompts are encoded
        and merged into the LLM inputs.
        """
        masks = data.pop('masks', None)
        regions = data.pop('regions', None)
        points = data.pop('points', None)

        # Defaults for the text-only path (no pixel_values).
        region_success = True
        mark_success = True
        pred_gt_obj_query = None
        _zero = 0

        if 'pixel_values' in data:
            input_ids = data['input_ids']
            height, width = data['pixel_values'].shape[-2:]

            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)

            if self.add_seg_pretrain:
                pred_gt_obj_query = prepare_seg_pretrain_data(
                    visual_outputs,
                    [self.projector.model.query_proj, self.projector.model.model],
                    self.projector_text2vision.model
                )

            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
                pixel_values = self.projector(visual_outputs)
            else:
                pixel_values = self.projector(
                    visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

            if regions is not None:
                region_embeddings, region_success = self.get_region_embeddings(
                    regions, input_ids)
            else:
                region_embeddings = []
            # Dummy embeddings keep the region branch in the graph when no
            # real region embeddings were produced.
            if isinstance(region_embeddings, torch.Tensor):
                none_region_embeddings = region_embeddings
            else:
                none_region_embeddings = self.get_none_region_embeddings(
                    input_ids=input_ids)

            if points is not None:
                points_mark_embedding, mark_success = self.get_points_embeddings(
                    points, input_ids, width=width, height=height)
            else:
                points_mark_embedding = []
            if isinstance(points_mark_embedding, torch.Tensor):
                none_points_mark_embedding = points_mark_embedding
            else:
                none_points_mark_embedding = self.get_none_points_embeddings(
                    input_ids, width=width, height=height)

            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal_with_visual_prompts(
                llm=self.llm, region_id=self.region_token_idx,
                regions_feats=region_embeddings,
                mark_id=self.mark_token_idx,
                mark_feats=points_mark_embedding,
                **data)

            # Zero-valued term that depends on the visual-prompt branches.
            # Presumably this makes their parameters receive (zero) gradients
            # every step, e.g. for DDP.
            _zero = (none_points_mark_embedding.sum() * 0.0
                     + none_region_embeddings.sum() * 0.0)

        if mode == 'loss':
            return self.compute_loss(data, data_samples, masks=masks,
                                     pred_gt_obj_query=pred_gt_obj_query,
                                     region_success=region_success,
                                     mark_success=mark_success,
                                     _zero=_zero)
        if mode == 'predict':
            return self.predict(data, data_samples)
        if mode == 'tensor':
            return self._forward(data, data_samples)
        raise NotImplementedError

    def _forward(self, data, data_samples=None):
        """Return the raw LLM outputs for ``data``."""
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        """Return one ``{'logits': ...}`` dict per sample."""
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None, masks=None, pred_gt_obj_query=None,
                     region_success=True, mark_success=True, _zero=0):
        """Compute the LM, segmentation and (optional) projector losses.

        ``original_labels`` (if present, else ``labels``) locate the
        ``[SEG]`` tokens. The LM loss is multiplied by 0 when region or point
        prompts could not be matched to their tokens.

        Returns:
            dict with ``loss``, ``loss_dice`` (x0.1), ``loss_mask`` (x0.4)
            and ``loss_proj``.
        """
        if 'original_labels' in data:
            input_ids = data.pop('original_labels')
        else:
            input_ids = data['labels']
        outputs = self.llm(**data, output_hidden_states=True)

        loss_dice, loss_mask = self.compute_seg_loss(
            input_ids, outputs.hidden_states[-1], masks)

        if pred_gt_obj_query is not None:
            pred_obj_query, gt_obj_query = pred_gt_obj_query
            proj_loss = torch.mean((pred_obj_query - gt_obj_query) ** 2) * 10
        else:
            proj_loss = 0

        if region_success and mark_success:
            loss = outputs.loss
        else:
            loss = outputs.loss * 0
        loss = loss + _zero

        # ``outputs.loss * 0 +`` turns python-float zeros into tensors.
        loss_dict = {'loss': loss,
                     'loss_dice': outputs.loss * 0 + loss_dice * 0.1,
                     'loss_mask': outputs.loss * 0 + loss_mask * 0.4,
                     'loss_proj': outputs.loss * 0 + proj_loss}
        return loss_dict

    def __getattr__(self, name: str):
        """Fall back to attributes of the wrapped LLM."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == 'llm':
                # ``llm`` not registered yet: falling back would recurse.
                raise
            return getattr(self.llm, name)

    def get_region_embeddings(self, regions, input_ids):
        """Encode region prompts for the ``<region>`` tokens in ``input_ids``.

        Returns:
            ``(embeddings, success)``. ``embeddings`` is ``[]`` when
            ``regions`` is empty. If the region count differs from the
            ``<region>`` token count, regions are truncated or padded to
            match and ``success`` is False.
        """
        if regions is None or len(regions) == 0:
            return [], True
        region_token_mask = input_ids == self.region_token_idx
        batch_idxs = self._batch_indices(region_token_mask)
        success = len(regions) == len(batch_idxs)
        if not success:
            regions = self._match_prompt_count(regions, len(batch_idxs))

        regions_embeddings = self.visual_encoder.forward_region_sam(
            regions, batch_idxs
        )[:, 0]
        regions_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            regions_embeddings, batch_idxs)
        return regions_embeddings, success

    def get_none_region_embeddings(self, input_ids):
        """Encode one all-ones 50x50 region for sample 0 (graph placeholder)."""
        batch_idxs = torch.zeros(1, dtype=torch.long, device=input_ids.device)
        regions = torch.ones((1, 50, 50), dtype=torch.float32,
                             device=input_ids.device)
        regions_embeddings = self.visual_encoder.forward_region_sam(
            regions, batch_idxs
        )[:, 0]
        regions_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            regions_embeddings, batch_idxs)
        return regions_embeddings

    def get_points_embeddings(self, points, input_ids, width, height):
        """Encode point prompts for the ``<mark>`` tokens in ``input_ids``.

        Returns:
            ``(embeddings, success)``, with the same contract as
            :meth:`get_region_embeddings`.
        """
        if points is None or len(points) == 0:
            return [], True
        mark_token_mask = input_ids == self.mark_token_idx
        batch_idxs = self._batch_indices(mark_token_mask)
        success = len(points) == len(batch_idxs)
        if not success:
            points = self._match_prompt_count(points, len(batch_idxs))

        marks_embeddings = self.visual_encoder.forward_point_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]
        marks_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            marks_embeddings, batch_idxs)
        return marks_embeddings, success

    def get_none_points_embeddings(self, input_ids, width, height):
        """Encode one zero point for sample 0 (graph placeholder)."""
        batch_idxs = torch.zeros(1, dtype=torch.long, device=input_ids.device)
        marks_embeddings = self.visual_encoder.forward_point_sam(
            torch.zeros((1, 2)).to(input_ids), batch_idxs, width=width, height=height
        )[:, 0]
        marks_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            marks_embeddings, batch_idxs)
        return marks_embeddings

    def get_visual_prompts_projector_zero(self):
        """Return ``projector.model.visual_prompt_zero``."""
        return self.projector.model.visual_prompt_zero

    def _dummy_seg_loss(self, hidden_states, input_ids):
        """Run the seg head on the first hidden state and return zeroed losses.

        Keeps the seg head in the autograd graph when a batch has no usable
        masks.
        """
        batch_idxs = torch.zeros(1, dtype=torch.long, device=input_ids.device)
        dummy_states = self.projector_text2vision(hidden_states[0, :1])
        pred_masks_list = self.visual_encoder.forward_llm_seg(dummy_states, batch_idxs)
        dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, [None])
        return dice_loss * 0.0, mask_loss * 0.0

    def compute_seg_loss(self, input_ids, hidden_states, gt_masks):
        """Compute the mask losses for ``[SEG]`` tokens.

        Args:
            input_ids: (B, L) ids used to locate ``[SEG]`` tokens.
            hidden_states: last-layer LLM hidden states aligned with them.
            gt_masks: ground-truth masks, one per ``[SEG]`` token, or None.

        Returns:
            ``(dice_loss, mask_loss)``. Returns ``(0.0, 0.0)`` when the seg
            head is disabled or ``add_seg_pretrain`` is set. Returns zeroed
            losses from a dummy pass when there are no masks, or when the
            ``[SEG]`` count does not match the mask count.
        """
        if not self.text2vision_projector or self.add_seg_pretrain:
            return 0.0, 0.0
        if gt_masks is None or len(gt_masks) == 0:
            return self._dummy_seg_loss(hidden_states, input_ids)

        seg_tokens_mask = input_ids == self.seg_token_idx
        seg_hidden_states = hidden_states[seg_tokens_mask]
        batch_idxs = self._batch_indices(seg_tokens_mask)
        if len(seg_hidden_states) != len(gt_masks):
            print("Drop the batch because the number of [SEG] and masks not equal !!!")
            return self._dummy_seg_loss(hidden_states, input_ids)

        seg_hidden_states = self.projector_text2vision(seg_hidden_states)
        pred_masks_list = self.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
        dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks)
        return dice_loss, mask_loss

    def preparing_for_generation(self, metainfo, **kwargs):
        """Set up template, stop criteria and generation config for inference.

        ``metainfo`` may provide ``template`` (defaults to
        ``PROMPT_TEMPLATE['internlm2_chat']``) and ``generation_kwargs``.
        All submodules are cast to the visual encoder's dtype.
        """
        assert hasattr(self, 'tokenizer'), "The Model does not have the tokenizer!!!"
        self.bot_name = 'BOT'
        template = metainfo.get('template', PROMPT_TEMPLATE['internlm2_chat'])
        self.template = template
        stop_words = list(template.get('STOP_WORDS', []))
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)

        default_generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=(
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)
        self.init_prediction_config = True

        self.llm.to(self.visual_encoder.dtype)
        self.visual_encoder.to(self.visual_encoder.dtype)
        self.projector.to(self.visual_encoder.dtype)
        # ``projector_text2vision`` only exists when it was requested.
        if self.text2vision_projector:
            self.projector_text2vision.to(self.visual_encoder.dtype)
        return

    def predict_forward(
            self, pixel_values, text_prompts,
            ori_image_size=None,
            box_prompts=None, points_prompts=None, mask_prompts=None, **kwargs):
        """Generate answers (and optionally masks) for a single image.

        Args:
            pixel_values: one preprocessed image (a batch dim is added here).
            text_prompts: a prompt string or a list of them. Each formatted
                prompt must contain exactly one ``DEFAULT_IMAGE_TOKEN``.
            ori_image_size: ``(w, h)``. Masks are predicted only when this is
                given *and* ``kwargs`` contains ``masks``.
            box_prompts: accepted but unused.
            points_prompts: accepted but unused.
            mask_prompts: accepted but unused.
            **kwargs: may hold ground-truth ``masks`` (a tensor).

        Returns:
            ``{'prediction': ...}``, plus ``prediction_masks`` (RLE per
            prompt, or None) and ``gt_masks`` when masks were predicted. A
            single prediction is returned unwrapped.
        """
        assert self.init_prediction_config, "Please set prediction configs using self.preparing_for_generation()"
        ret_predictions = []
        ret_masks = []

        image = pixel_values.cuda().unsqueeze(0).to(self.visual_encoder.dtype)
        visual_outputs = self.visual_encoder(image, output_hidden_states=True)
        if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
            pixel_values = self.projector(visual_outputs)
        else:
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]
        for text_prompt in text_prompts:
            input_text = self.template['INSTRUCTION'].format(
                input=text_prompt, round=1, bot_name=self.bot_name)

            chunk_encode = []
            for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)):
                if idx == 0:
                    cur_encode = self.tokenizer.encode(chunk)
                else:
                    cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
                chunk_encode.append(cur_encode)
            assert len(chunk_encode) == 2
            # The image placeholder goes between the two text chunks.
            ids = [*chunk_encode[0], IMAGE_TOKEN_INDEX, *chunk_encode[1]]
            ids = torch.tensor(ids).cuda().unsqueeze(0)

            mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
                llm=self.llm, input_ids=ids, pixel_values=pixel_values,
                region_id=self.region_token_idx,
                regions_feats=[],
                mark_id=self.mark_token_idx,
                mark_feats=[],
            )
            generate_output = self.llm.generate(
                **mm_inputs,
                generation_config=self.gen_config,
                streamer=None,
                bos_token_id=self.tokenizer.bos_token_id,
                stopping_criteria=self.stop_criteria,
                output_hidden_states=True,
                return_dict_in_generate=True
            )
            predict = self.tokenizer.decode(
                generate_output.sequences[0], skip_special_tokens=True).strip()
            ret_predictions.append(predict)

            if ori_image_size is None or 'masks' not in kwargs:
                continue

            last_hidden_states = torch.cat(
                [item[-1][0] for item in generate_output.hidden_states], dim=0)
            seg_hidden_states = get_seg_hidden_states(
                last_hidden_states, generate_output.sequences[0][:-1],
                seg_id=self.seg_token_idx
            )
            if len(seg_hidden_states) == 0:
                print("Warning, no [SEG] tokens !!!")
                ret_masks.append(None)
                continue
            if len(seg_hidden_states) > 1:
                print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states)))
                seg_hidden_states = seg_hidden_states[:1]
            seg_hidden_states = self.projector_text2vision(seg_hidden_states)
            batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
                                     dtype=torch.int64).to(seg_hidden_states.device)
            pred_masks_list = self.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
            pred_masks = pred_masks_list[-1]

            w, h = ori_image_size
            side = max(w, h)
            masks = F.interpolate(pred_masks, size=(side, side),
                                  mode='bilinear', align_corners=False)[:, 0]
            # Crop the square prediction back to (h, w). This presumably
            # undoes centered square padding applied at preprocessing.
            if w > h:
                n_pad = w - h
                n_pad_1 = n_pad // 2
                n_pad_2 = n_pad - n_pad_1
                masks = masks[:, n_pad_1: w - n_pad_2]
            elif h > w:
                n_pad = h - w
                n_pad_1 = n_pad // 2
                n_pad_2 = n_pad - n_pad_1
                masks = masks[:, :, n_pad_1: h - n_pad_2]
            ret_masks.append((masks.sigmoid() > 0.5).int())

        if len(ret_predictions) == 1:
            ret_predictions = ret_predictions[0]
        if len(ret_masks) == 0:
            return {'prediction': ret_predictions}

        _ret_masks = [
            None if ret_mask is None else mask_to_rle(ret_mask.cpu().numpy())
            for ret_mask in ret_masks
        ]
        if 'masks' not in kwargs:
            gt_masks = None
        else:
            gt_masks = mask_to_rle(kwargs['masks'].cpu().numpy())

        return {
            'prediction': ret_predictions, 'prediction_masks': _ret_masks,
            'gt_masks': gt_masks,
        }
|
|
def prepare_seg_pretrain_data(visual_outputs,
                              query_in_proj, query_out_proj):
    """Build (prediction, target) pairs for the object-query reconstruction loss.

    For every sample, queries whose attention-mask row has a positive sum
    are kept as targets. Each kept query is passed through ``query_in_proj``
    and then ``query_out_proj`` to form the prediction.

    Args:
        visual_outputs: 3-tuple ``(clip_feature, query_feat, attention_mask)``;
            ``query_feat`` must be 3-D with the batch on dim 0.
        query_in_proj: a callable, or a list of callables applied in order.
        query_out_proj: callable applied after ``query_in_proj``.

    Returns:
        ``(pred_query_embed, gt_query_embed)``, each concatenated over the
        batch along dim 0.
    """
    _clip_feature, query_feat, attention_mask = visual_outputs
    batch_size, _num_queries, _ = query_feat.shape

    in_projs = query_in_proj if isinstance(query_in_proj, list) else [query_in_proj]

    predictions = []
    targets = []
    for sample in range(batch_size):
        keep = attention_mask[sample].sum(-1) > 0
        target = query_feat[sample][keep]
        targets.append(target)

        projected = target
        for proj in in_projs:
            projected = proj(projected)
        predictions.append(query_out_proj(projected))

    return torch.cat(predictions, dim=0), torch.cat(targets, dim=0)
|
|
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    """Select the hidden states at positions whose output token is ``seg_id``.

    The last ``len(output_ids)`` rows of ``hidden_states`` are aligned with
    ``output_ids``. Earlier rows (e.g. prompt positions) are ignored.

    Args:
        hidden_states: tensor indexed by position along dim 0.
        output_ids: 1-D tensor of token ids.
        seg_id: id of the ``[SEG]`` token.

    Returns:
        The selected rows. The result is empty (0 rows) when there are no
        output ids or no ``[SEG]`` tokens.
    """
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    if n_out == 0:
        # ``hidden_states[-0:]`` would select every row and the empty mask
        # would no longer match its length.
        return hidden_states[:0]
    return hidden_states[-n_out:][seg_mask]
|
|
|
|
def mask_to_rle(mask):
    """Encode each binary mask in ``mask`` with pycocotools RLE.

    Args:
        mask: iterable of 2-D arrays. Each one is cast to ``uint8`` and made
            Fortran-contiguous (column-major), the layout ``_mask.encode``
            expects.

    Returns:
        List of RLE objects, one per input mask, in input order.
    """
    return [
        _mask.encode(np.asfortranarray(single.astype(np.uint8)))
        for single in mask
    ]
|
|