| from typing import List |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM |
| from .segment_anything_2.sam2.build_sam import build_sam2, build_sam2_video_predictor |
| from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config |
| from .configuration_evf import EvfConfig |
| from .segment_anything_2.sam2.utils.misc import load_video_frames |
| from collections import OrderedDict |
|
|
|
|
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of shape (N, ...) holding mask logits; the
            sigmoid is applied here. Every dimension after the first belongs
            to the same mask, so (N, H, W), (N, 1, H, W) and (N, D) all work.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser for the summed per-mask loss (a tiny epsilon is
            added to avoid division by zero).
        scale: Both the intersection and the sums are divided by this value.
            It cancels in the ratio except relative to ``eps``; presumably it
            keeps the sums small for low-precision dtypes (TODO confirm).
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: sum of the per-mask dice losses divided by ``num_masks``.
    """
    inputs = inputs.sigmoid()
    # Flatten every non-batch dim so each row is exactly one mask. For 3-D
    # input this equals flatten(1, 2); for 4-D input flatten(1, 2) would leave
    # the last spatial dim separate and the loss would be summed per row.
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss
|
|
|
|
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy on mask logits, averaged per mask.

    Args:
        inputs: A float tensor of shape (N, ...) holding mask logits (the
            sigmoid is applied inside ``binary_cross_entropy_with_logits``).
            Every dimension after the first belongs to the same mask.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser for the summed per-mask loss (a tiny epsilon is
            added to avoid division by zero).

    Returns:
        Scalar loss tensor: per-mask mean BCE, summed over masks and divided
        by ``num_masks``.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # flatten(1) makes each row one whole mask for any input rank >= 2; it is
    # identical to flatten(1, 2) for (N, H, W) inputs.
    loss = loss.flatten(1).mean(1).sum() / (num_masks + 1e-8)
    return loss
|
|
|
|
class EvfSam2Model(PreTrainedModel):
    """EVF-SAM2 model for text-prompted video object segmentation.

    A BEiT-3 multimodal extractor (``mm_extractor``) encodes an image together
    with text tokens. Its first output token is projected by
    ``text_hidden_fcs`` and handed as a text prompt to a SAM2 video predictor
    (``visual_model``), which then propagates the mask through the video.

    Keyword arguments read by ``__init__``:
        vision_pretrained: checkpoint passed to ``build_sam2_video_predictor``.
        encoder_pretrained: path of a BEiT-3 checkpoint; its ``"model"`` entry
            is loaded non-strictly into ``mm_extractor``.
        dice_loss_weight, bce_loss_weight: stored as attributes only.
        train_mask_decoder: if true, unfreeze the SAM2 mask decoder.
        train_prompt_encoder: if true, unfreeze the prompt encoder's
            ``no_mask_embed``.
    """

    config_class = EvfConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)

        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
        # Loss weights are only stored; no method shown here reads them.
        self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
        self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
        self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
        self.initialize_evf_modules(config)
        # Per-level backbone feature-map sizes. Not read by the methods shown
        # here. NOTE(review): presumably consumed elsewhere; confirm.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def initialize_evf_modules(self, config):
        """Build the SAM2 video predictor, the BEiT-3 extractor and the text head.

        Args:
            config: model config; ``sam_scale``, ``mm_extractor_scale``,
                ``hidden_size`` and ``out_dim`` are read from it.

        Raises:
            NotImplementedError: ``config.sam_scale`` is not "large" or "tiny".
            AttributeError: ``config.mm_extractor_scale`` is missing or is not
                "base" or "large".
            AssertionError: ``config.hidden_size`` differs from the extractor's
                embedding dim.
        """
        # --- SAM2 video predictor (frozen except for the opt-in parts) ---
        if config.sam_scale == "large":
            self.visual_model = build_sam2_video_predictor(
                "sam2_hiera_l.yaml", self.vision_pretrained, device=None)
        elif config.sam_scale == "tiny":
            self.visual_model = build_sam2_video_predictor(
                "sam2_hiera_t.yaml", self.vision_pretrained, device=None)
        else:
            raise NotImplementedError(
                f"unsupported sam_scale {config.sam_scale!r}; expected 'large' or 'tiny'"
            )

        self.visual_model.requires_grad_(False)
        if self.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            self.visual_model.sam_mask_decoder.requires_grad_(True)
        if self.train_prompt_encoder:
            self.visual_model.sam_prompt_encoder.no_mask_embed.requires_grad_(True)

        # --- BEiT-3 multimodal extractor (fully trainable) ---
        # getattr lets a config that lacks the key reach the informative error below.
        mm_scale = getattr(config, "mm_extractor_scale", None)
        if mm_scale == "base":
            beit_config = _get_base_config()
        elif mm_scale == "large":
            beit_config = _get_large_config()
        else:
            raise AttributeError(
                "model config should contain key 'mm_extractor_scale', with value "
                f"'base' or 'large' (got {mm_scale!r})."
            )

        self.mm_extractor = BEiT3Wrapper(beit_config)
        if self.encoder_pretrained is not None:
            # map_location="cpu": a GPU-saved checkpoint must also load on a
            # CPU-only host; load_state_dict copies the weights into place anyway.
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted checkpoints.
            beit_state_dict = torch.load(
                self.encoder_pretrained, map_location="cpu")["model"]
            # strict=False: keys that do not match are ignored silently.
            self.mm_extractor.load_state_dict(beit_state_dict, strict=False)
        self.mm_extractor.requires_grad_(True)

        # --- projection of the extractor's first token to the SAM prompt dim ---
        in_dim = config.hidden_size
        assert in_dim == beit_config.encoder_embed_dim, \
            f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        self.text_hidden_fcs.requires_grad_(True)

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """Resize mask logits to ``orig_hw`` with bilinear interpolation.

        ``masks`` is cast to float first. ``F.interpolate`` in bilinear mode
        requires a 4-D (N, C, H, W) tensor; ``orig_hw`` is its target ``size``.
        """
        masks = masks.float()
        masks = F.interpolate(masks,
                              orig_hw,
                              mode="bilinear",
                              align_corners=False)
        return masks

    def inference(
        self,
        video_path,
        images_evf,
        input_ids,
        multimask_output=True,
    ):
        """Segment one object through a video from an image + text prompt.

        Args:
            video_path: passed to ``visual_model.init_state`` to load the video.
            images_evf: image tensor fed to BEiT-3 as ``visual_tokens``.
                NOTE(review): presumably the preprocessed first frame; confirm.
            input_ids: text token ids fed as ``textual_tokens``; every position
                is marked as non-padding.
            multimask_output: unused; kept for API compatibility.

        Returns:
            dict mapping frame index -> {obj_id: boolean numpy mask}, where each
            mask is ``mask_logits > 0``. The prompt is attached to frame 0 with
            object id 1.
        """
        predictor = self.visual_model
        inference_state = predictor.init_state(video_path=video_path)
        predictor.reset_state(inference_state)

        # Inference only: skip building an autograd graph through the
        # trainable extractor and projection head.
        with torch.no_grad():
            output = self.mm_extractor.beit3(
                visual_tokens=images_evf,
                textual_tokens=input_ids,
                text_padding_position=torch.zeros_like(input_ids))
            # The first encoder token is used as the prompt embedding.
            feat = output["encoder_out"][:, :1, ...]
            feat = self.text_hidden_fcs[0](feat)

        ann_frame_idx = 0
        ann_obj_id = 1
        # Return values are not needed: propagate_in_video re-emits frame 0.
        predictor.add_new_text(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=ann_obj_id,
            text=feat)

        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
                inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        return video_segments
|
|
|
|
def _register_evf_auto_classes():
    """Expose the EVF config and model through the transformers Auto* factories.

    Maps the model type string "evf" to ``EvfConfig``, then maps
    ``EvfConfig`` to ``EvfSam2Model`` for ``AutoModelForCausalLM``.
    """
    AutoConfig.register("evf", EvfConfig)
    AutoModelForCausalLM.register(EvfConfig, EvfSam2Model)


# Runs once at import time, exactly like the original module-level calls.
_register_evf_auto_classes()
|
|