| import logging |
|
|
| from typing import List, Optional, Tuple, Union |
|
|
| import numpy |
| import numpy as np |
| import torch |
| from PIL.Image import Image |
|
|
| from model.visual.sam2.modeling.sam2_base import SAM2Base |
|
|
| from model.visual.sam2.modeling.backbones.hieradet import Hiera |
| from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck |
| from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder |
| from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine |
|
|
| from model.visual.sam2.modeling.memory_attention import MemoryAttention |
| from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer |
| from model.visual.sam2.modeling.sam.transformer import RoPEAttention |
| from model.visual.sam2.modeling.memory_encoder import MemoryEncoder |
| from model.visual.sam2.modeling.memory_encoder import MaskDownSampler |
| from model.visual.sam2.modeling.memory_encoder import Fuser |
| from model.visual.sam2.modeling.memory_encoder import CXBlock |
|
|
| from model.visual.sam2.utils.transforms import SAM2Transforms |
| from model.visual.sam2.modeling.backbones.hieradet import do_pool |
| from model.visual.sam2.modeling.backbones.utils import ( |
| PatchEmbed, |
| window_partition, |
| window_unpartition, |
| ) |
|
|
|
|
class AVmodel(torch.nn.Module):
    """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head.

    Args:
        param: Hyper-parameter object. This class reads ``image_size``,
            ``sam_config_path`` and ``backbone_weight`` from it and passes the
            whole object to ``AuralFuser`` as ``hyp_param``.
        mask_threshold: Stored on the instance and forwarded to ``SAM2Transforms``.
        max_hole_area: Forwarded to ``SAM2Transforms``.
        max_sprinkle_area: Forwarded to ``SAM2Transforms``.
    """

    # Down-sampling factors of the three backbone feature levels, in the order
    # used for ``_bb_feat_sizes`` (stride 4, i.e. the largest map, first).
    _BB_FEAT_STRIDES = (4, 8, 16)

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0):
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold
        # Integer division avoids a float round-trip; for positive sizes
        # int(size) // s == int(size / s), so the values are unchanged.
        image_size = int(self.param.image_size)
        self._bb_feat_sizes = [(image_size // s, image_size // s) for s in self._BB_FEAT_STRIDES]

        # Project-local imports are done lazily inside the constructor;
        # presumably to avoid an import cycle or a heavy module-load cost — TODO confirm.
        from model.visual.sam2.build_sam import build_sam2_visual_predictor
        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
        )
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        from model.aural_fuser import AuralFuser
        self.aural_fuser = AuralFuser(hyp_param=self.param)

    def _prepare_backbone_features(self, backbone_out):
        """Select the last ``v_model.num_feature_levels`` levels and flatten them.

        Args:
            backbone_out: dict whose ``"backbone_fpn"`` and ``"vision_pos_enc"``
                entries are equally long lists of feature maps (assumed
                (B, C, H, W) tensors — TODO confirm).

        Returns:
            Tuple ``(backbone_out_copy, vision_feats, vision_pos_embeds, feat_sizes)``:
            a shallow copy of ``backbone_out``; the selected feature maps and
            positional encodings, each flattened over the spatial dims and
            permuted to (H*W, B, C) for 4-D input; and the ``(H, W)`` of each
            selected level (read from the positional encodings).
        """
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]), (
            "backbone_fpn and vision_pos_enc must have the same number of levels"
        )
        num_levels = self.v_model.num_feature_levels
        assert len(backbone_out["backbone_fpn"]) >= num_levels, (
            "backbone returned fewer feature levels than v_model.num_feature_levels"
        )

        feature_maps = backbone_out["backbone_fpn"][-num_levels:]
        pos_maps = backbone_out["vision_pos_enc"][-num_levels:]

        # Record spatial sizes before flattening discards H and W.
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in pos_maps]
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in pos_maps]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def forward_frame(self, frame_):
        """Resize ``frame_`` to a square ``param.image_size`` and run the SAM2 image encoder.

        Resizing is bilinear with anti-aliasing and ``align_corners=False``.
        ``frame_`` is assumed to be a (B, C, H, W) batch — TODO confirm.
        """
        size = self.param.image_size
        frame = torch.nn.functional.interpolate(
            frame_, (size, size), mode='bilinear', align_corners=False, antialias=True
        )
        return self.v_model.image_encoder(frame)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path.

        Args:
            frames: Video frames; ``frames.shape[0]`` is used as the frame count.
            spect: Audio input, passed to ``aural_fuser`` with the backbone features.
            prompts: Currently unused by this method.
            sam_process: Currently unused; reserved for the prompt path.

        Returns:
            ``(outputs, proj_feats)``: the result of
            ``v_model.forward_tracking_wo_prompt`` and the third element
            returned by ``aural_fuser``.
        """
        num_frames = frames.shape[0]
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(backbone_feats, spect)

        # The fuser's per-level residuals are handed to the tracker in reverse
        # order; presumably the two sides order FPN levels oppositely — TODO confirm.
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        # In training mode the middle frame is passed as ``cond_frame``; otherwise frame 0.
        cond_frame = num_frames // 2 if self.training else 0
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats, num_frames=num_frames, cond_frame=cond_frame
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the wrapped SAM2 visual model."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Switch ``v_model`` to eval mode and disable gradients on all its parameters.

        NOTE: ``torch.nn.Module.train()`` recurses into child modules, so a later
        ``self.train()`` puts ``v_model`` back into training mode. The
        ``requires_grad`` flags set here are not affected by that.
        """
        self.v_model.eval()
        for parameter in self.v_model.parameters():
            parameter.requires_grad = False
|
|