| from typing import List, Optional, Tuple, Union |
|
|
| import os |
| import torch |
| import numpy as np |
| import torch.nn as nn |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import torch.nn.functional as F |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM |
| from model.IXC.modeling_internlm2 import InternLM2Model |
| from model.sam2.build_sam import build_sam2_hf |
| from model.sam2.utils.transforms import SAM2Transforms |
| from transformers import TextStreamer |
| try: |
| from transformers.generation.streamers import BaseStreamer |
| except: |
| BaseStreamer = None |
|
|
|
|
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """
    Soft DICE loss between mask logits and binary targets.

    Args:
        inputs: Float tensor of raw logits; a sigmoid is applied here.
            Dimensions 1 and 2 are flattened together, so a
            (num_masks, H, W) layout is presumably expected - TODO confirm.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of every element (0 = negative, 1 = positive).
        num_masks: Normaliser; the summed loss is divided by
            ``num_masks + 1e-8``.
        scale: Divisor applied to the probabilities and to the targets
            before they are summed.
        eps: Smoothing constant added to both numerator and denominator.

    Returns:
        Scalar tensor: the per-row DICE losses summed and divided by
        ``num_masks``.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    labels = targets.flatten(1, 2)

    overlap = (probs / scale * labels).sum(-1)
    total = (probs / scale).sum(-1) + (labels / scale).sum(-1)

    per_mask = 1 - (2 * overlap + eps) / (total + eps)
    return per_mask.sum() / (num_masks + 1e-8)
|
|
|
|
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy on logits, averaged per mask and normalised.

    Args:
        inputs: Float tensor of raw logits. Dimensions 1 and 2 are flattened
            together, so a (num_masks, H, W) layout is presumably expected -
            TODO confirm.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of every element (0 = negative, 1 = positive).
        num_masks: Normaliser; the summed loss is divided by
            ``num_masks + 1e-8``.

    Returns:
        Scalar loss tensor.
    """
    per_element = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    # Mean over the flattened dimensions of each row, then sum over rows.
    per_mask = per_element.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)
|
|
|
|
class GeoPixelMetaModel:
    """
    Mixin that attaches the SAM2 grounding components to a backbone model.

    ``__init__`` forwards ``config`` to the next class in the MRO, so this
    class has to be combined with a base whose constructor accepts ``config``
    (``GeoPixelModel`` pairs it with ``InternLM2Model``).
    """

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Model configuration. ``train_mask_decoder`` and
                ``out_dim`` are written back onto it.
            **kwargs: Optional ``train_mask_decoder`` (default False),
                ``out_dim`` (default 256) and ``vision_pretrained``
                (default None, passed to ``build_sam2_hf``).
        """
        super().__init__(config)
        self.config = config

        # A value already stored on the config wins; kwargs are the fallback.
        decoder_fallback = kwargs.get("train_mask_decoder", False)
        self.config.train_mask_decoder = getattr(
            self.config, "train_mask_decoder", decoder_fallback
        )
        out_dim_fallback = kwargs.get("out_dim", 256)
        self.config.out_dim = getattr(self.config, "out_dim", out_dim_fallback)

        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_geopixel_modules(self.config)

    def initialize_geopixel_modules(self, config):
        """Build the visual model, its transform and the text projection head."""

        def set_trainable(module, flag):
            # Toggle gradient tracking for every parameter of ``module``.
            for weight in module.parameters():
                weight.requires_grad = flag

        self.visual_model = build_sam2_hf(self.vision_pretrained)

        self._transform = SAM2Transforms(
            resolution=self.visual_model.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
        )

        # Spatial sizes used to reshape the backbone feature levels.
        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

        # The visual model is frozen; only the mask decoder may be re-enabled.
        set_trainable(self.visual_model, False)
        if config.train_mask_decoder:
            mask_decoder = self.visual_model.sam_mask_decoder
            mask_decoder.train()
            set_trainable(mask_decoder, True)

        # Projection from the LLM hidden size to ``out_dim``.
        hidden_dim = config.hidden_size
        projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, config.out_dim),
            nn.Dropout(0.0),
        )
        self.text_hidden_fcs = nn.ModuleList([projection])
        self.text_hidden_fcs.train()
        set_trainable(self.text_hidden_fcs, True)
|
|
|
|
class GeoPixelModel(GeoPixelMetaModel, InternLM2Model):
    """InternLM2 backbone combined with the GeoPixel grounding modules."""

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Model configuration, forwarded together with ``kwargs``
                to ``GeoPixelMetaModel.__init__``.
            **kwargs: Extra options consumed by ``GeoPixelMetaModel``.
        """
        super().__init__(config, **kwargs)
        # Key/value caching is switched off on the config.
        self.config.use_cache = False
|
|
|
|
class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
    """
    Causal language model whose segmentation tokens drive a SAM2 mask decoder.

    Hidden states at segmentation-token positions are projected by
    ``self.model.text_hidden_fcs`` and fed as text prompts to the SAM2 prompt
    encoder / mask decoder held by ``self.model.visual_model``.
    """

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Model configuration.
            **kwargs: ``ce_loss_weight``, ``dice_loss_weight`` and
                ``bce_loss_weight`` (each defaults to None) and the required
                ``seg_token_idx`` are consumed here; everything else is
                forwarded to ``GeoPixelModel``.

        Raises:
            KeyError: if ``seg_token_idx`` is not supplied.
        """
        # These are popped first so they are not forwarded to GeoPixelModel.
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)
        self.model = GeoPixelModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def encode_g_img(self, image):
        """
        Calculate the grounding-image embeddings for one image.

        Arguments:
            image (str, np.ndarray or torch.Tensor): Path to an image file, or
                an in-memory image whose first two dimensions are taken as
                (height, width) - presumably an HWC array; TODO confirm
                against ``SAM2Transforms``.

        Returns:
            ``(features, orig_hw)`` where ``features`` is the dict produced by
            ``get_visual_embs`` and ``orig_hw`` is a one-element list holding
            the original (height, width). Returns None when ``image`` is None
            or is a path with an unsupported extension.

        Raises:
            TypeError: if ``image`` is none of the accepted types.
        """
        supported_exts = {
            '.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tif', '.tiff',
        }
        if image is None:
            return None
        if isinstance(image, str):
            _, ext = os.path.splitext(image)
            if ext.lower() not in supported_exts:
                print('Unknow input format', image)
                return None
            # Convert to RGB so palette / grayscale / RGBA files satisfy the
            # 3-channel check below; the context manager closes the file.
            with Image.open(image) as img:
                image = img.convert("RGB")
            w, h = image.size
            orig_hw = [(h, w)]
        elif isinstance(image, (np.ndarray, torch.Tensor)):
            orig_hw = [image.shape[:2]]
        else:
            raise TypeError(
                "image must be a file path, np.ndarray or torch.Tensor, "
                f"got {type(image).__name__}"
            )
        image = self.model._transform(image)
        image = image[None, ...].to(self.device)
        assert (
            len(image.shape) == 4 and image.shape[1] == 3
        ), f"image must be of size 1x3xHxW, got {image.shape}"
        features = self.get_visual_embs(image)
        return features, orig_hw

    def get_visual_embs(self, img_batch: torch.FloatTensor):
        """
        Run the visual backbone on a batch of images without gradients.

        Args:
            img_batch: Tensor of shape (B, 3, H, W).

        Returns:
            Dict with ``"image_embed"`` (the last feature level) and
            ``"high_res_feats"`` (the remaining levels), each reshaped to
            (B, C, *feat_size) using ``self.model._bb_feat_sizes``.
        """
        with torch.no_grad():
            torch.cuda.empty_cache()
            img_batch = img_batch.to(self.device)
            batch_size = img_batch.shape[0]
            assert (
                len(img_batch.shape) == 4 and img_batch.shape[1] == 3
            ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
            visual_model = self.model.visual_model
            backbone_out = visual_model.forward_image(img_batch)
            _, vision_feats, _, _ = visual_model._prepare_backbone_features(backbone_out)
            if visual_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + visual_model.no_mem_embed
            # Features and sizes are paired from the last level backwards,
            # then the result is flipped back to the original level order.
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(
                    vision_feats[::-1], self.model._bb_feat_sizes[::-1]
                )
            ][::-1]
            features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        return features

    def forward(self, **kwargs):
        """Use the parent forward when ``past_key_values`` is passed, else ``model_forward``."""
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def _decode_masks(self, pred_embeddings, image_g_features, ori_hw):
        """
        Decode masks for every sample from its segmentation-token embeddings.

        Args:
            pred_embeddings: One tensor per sample holding the projected
                hidden states at that sample's segmentation-token positions.
            image_g_features: Dict returned by ``get_visual_embs``.
            ori_hw: Per-sample original sizes passed to
                ``postprocess_masks``.

        Returns:
            A list with one entry per sample: the post-processed masks
            (channel 0 of the decoder output), or an empty list when the
            sample has no segmentation tokens.
        """
        visual_model = self.model.visual_model
        all_pred_masks = []
        for i, seg_embeds in enumerate(pred_embeddings):
            if seg_embeds.numel() == 0:
                # Keep a placeholder so indices stay aligned with the batch.
                all_pred_masks.append([])
                continue
            sparse_embeddings, dense_embeddings = visual_model.sam_prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=seg_embeds.unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(seg_embeds.dtype)
            high_res_features = [
                feat_level[i].unsqueeze(0)
                for feat_level in image_g_features["high_res_feats"]
            ]
            # Follow the dtype of the prompt embeddings instead of assuming
            # bfloat16, so other precisions do not produce a dtype mismatch.
            image_g_embeds = (
                image_g_features["image_embed"][i].unsqueeze(0).to(seg_embeds.dtype)
            )
            low_res_masks, _, _, _ = visual_model.sam_mask_decoder(
                image_embeddings=image_g_embeds,
                image_pe=visual_model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                # One image is shared by several prompts when more than one
                # segmentation token is present.
                repeat_image=seg_embeds.shape[0] > 1,
                multimask_output=False,
                high_res_features=high_res_features,
            )
            pred_masks = self.model._transform.postprocess_masks(
                low_res_masks,
                ori_hw[i],
            )
            all_pred_masks.append(pred_masks[:, 0])
        return all_pred_masks

    def model_forward(
        self,
        inference: bool = False,
        **kwargs,
    ):
        """
        Forward pass that adds mask prediction and mask losses for grounding data.

        Args:
            inference: When True, return predicted and ground-truth masks
                instead of losses.
            **kwargs: Arguments for the parent ``forward``. When
                ``kwargs['samples']['data_type'][0] == 'grounding'`` the keys
                ``image_g``, ``ori_hw`` and ``masks`` of ``samples`` are read
                as well (plus ``text_input`` and ``image`` in inference mode).

        Returns:
            * non-grounding input: the parent ``forward`` output;
            * ``inference=True``: ``{"pred_masks": ..., "gt_masks": ...}``;
            * otherwise a ``CausalLMOutputWithPast`` whose ``loss`` is the
              weighted sum of the LM loss and the BCE + DICE mask losses,
              with ``ce_loss``, ``mask_bce_loss``, ``mask_dice_loss`` and
              ``mask_loss`` attached as attributes.
        """
        samples = kwargs.get('samples', None)
        if not (samples and samples['data_type'][0] == 'grounding'):
            return super().forward(**kwargs)

        kwargs['output_hidden_states'] = True
        kwargs['use_cache'] = False

        torch.cuda.empty_cache()
        model_output = super().forward(**kwargs)

        if inference:
            assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1

        # The last entry of hidden_states is used in both modes; the output
        # object is kept alive because seg_token_mask is read from it below.
        output_hidden_states = model_output.hidden_states
        assert len(self.model.text_hidden_fcs) == 1
        last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states[-1])

        seg_token_mask = model_output.seg_token_mask
        pred_embeddings = [
            states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)
        ]
        image_g_batch = torch.cat(samples['image_g'][0], dim=0)
        image_g_features = self.get_visual_embs(image_g_batch)
        pred_masks = self._decode_masks(
            pred_embeddings, image_g_features, samples['ori_hw'][0]
        )
        gt_masks = samples['masks'][0]

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        ce_loss = model_output.loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0

        for batch_idx, cur_pred_masks in enumerate(pred_masks):
            if len(cur_pred_masks) == 0:
                # No segmentation token in this sample: no mask loss.
                continue
            cur_gt_masks = torch.stack(
                [
                    torch.from_numpy(gt_mask).to(
                        dtype=cur_pred_masks.dtype, device=cur_pred_masks.device
                    )
                    for gt_mask in gt_masks[batch_idx]
                ],
                dim=0,
            )
            assert (
                cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
            ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
                cur_gt_masks.shape, cur_pred_masks.shape
            )
            sample_masks = cur_gt_masks.shape[0]
            # Each helper returns a per-mask average; multiplying by the mask
            # count lets the final division average over all masks in the batch.
            mask_bce_loss += (
                sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=sample_masks)
                * sample_masks
            )
            mask_dice_loss += (
                dice_loss(cur_pred_masks, cur_gt_masks, num_masks=sample_masks)
                * sample_masks
            )
            num_masks += sample_masks

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = ce_loss + mask_loss
        outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=model_output.logits,
            past_key_values=model_output.past_key_values,
            hidden_states=output_hidden_states,
            attentions=model_output.attentions,
        )
        outputs.ce_loss = ce_loss
        outputs.mask_bce_loss = mask_bce_loss
        outputs.mask_dice_loss = mask_dice_loss
        outputs.mask_loss = mask_loss
        return outputs

    def evaluate(
        self,
        tokenizer,
        query: str,
        images: Optional[List] = None,
        hd_num: int = 9,
        history: Optional[List[Tuple[str, str]]] = None,
        max_new_tokens: int = 1024,
        stream: bool = False,
        **kwargs,
    ):
        """
        Generate a response for ``query`` and decode masks for its segmentation tokens.

        Args:
            tokenizer: Tokenizer used for decoding and for the EOS token id.
            query: User prompt.
            images: Images passed to ``interleav_wrap_chat``. Masks are only
                decoded when this holds exactly one file path.
            hd_num: Forwarded to ``interleav_wrap_chat``.
            history: Previous (query, response) turns.
            max_new_tokens: Generation budget.
            stream: When True, generated text is streamed through a
                ``TextStreamer``.
            **kwargs: Extra arguments forwarded to ``generate``.

        Returns:
            ``(response, all_pred_masks)``. ``all_pred_masks`` is empty when
            no masks were decoded (no single image path, or an image that
            ``encode_g_img`` could not load).
        """
        # None defaults avoid sharing one mutable list between calls.
        images = [] if images is None else images
        history = [] if history is None else history

        with torch.no_grad():
            inputs, im_mask, _ = self.interleav_wrap_chat(
                query, images, history=history, hd_num=hd_num
            )
            inputs = {
                k: v.to(self.device)
                for k, v in inputs.items() if torch.is_tensor(v)
            }
            eos_token_id = [tokenizer.eos_token_id]
            all_pred_masks = []

            if stream:
                streamer = TextStreamer(
                    tokenizer, skip_prompt=True, skip_special_tokens=True
                )
            else:
                streamer = None

            outputs = self.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                im_mask=im_mask,
                input_ids=None,
                streamer=streamer,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                top_k=0,
                eos_token_id=eos_token_id,
                repetition_penalty=1.0,
                infer_mode='base',
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            output_ids = outputs['sequences']
            response = tokenizer.decode(
                output_ids[0].cpu().tolist(), skip_special_tokens=True
            )
            response = response.replace("[UNUSED_TOKEN_145]", "")

            if len(images) == 1 and isinstance(images[0], str):
                # NOTE(review): presumably the project's generate() returns
                # full-sequence last-layer states here - confirm.
                output_hidden_states = outputs.hidden_states[-1]
                seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
                inputs_embeds_len = inputs['inputs_embeds'].size(1)
                # Prompt positions never hold segmentation tokens; the prefix
                # is created on the mask's own device rather than on CUDA.
                prompt_mask = torch.zeros(
                    (seg_token_mask.shape[0], inputs_embeds_len),
                    dtype=torch.bool,
                    device=seg_token_mask.device,
                )
                seg_token_mask = torch.cat([prompt_mask, seg_token_mask], dim=1)

                assert len(self.model.text_hidden_fcs) == 1
                last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states)
                pred_embeddings = [
                    states[masks]
                    for states, masks in zip(last_hidden_state, seg_token_mask)
                ]

                encoded = self.encode_g_img(images[0])
                if encoded is not None:
                    image_g_features, ori_hw = encoded
                    all_pred_masks = self._decode_masks(
                        pred_embeddings, image_g_features, ori_hw
                    )

        return response, all_pred_masks