| """End-to-end Ref-AVS: SAM2 visual backbone + AuralFuser fusion + tracking head. |
| |
| Orchestration follows ``avs.code/v1m.code/model/mymodel.py``. |
| """ |
| import torch |
|
|
| from model.visual.sam2.build_sam import build_sam2_visual_predictor |
| from model.visual.sam2.utils.transforms import SAM2Transforms |
| from model.aural_fuser import AuralFuser |
| from transformers import AutoTokenizer, AutoModel |
|
|
|
|
class AVmodel(torch.nn.Module):
    """SAM2 + audio/text fusion (``aural_fuser``) + SAM2 tracking decoder.

    Sub-modules held by the instance:

    * ``v_model``     -- SAM2 visual predictor built from ``param.sam_config_path``
      and ``param.backbone_weight``.
    * ``aural_fuser`` -- ``AuralFuser`` instance that consumes the backbone
      features, the spectrogram input and the text embeddings.
    * ``t_model`` / ``text_tokenizer`` -- ``distilbert/distilroberta-base``
      encoder and its tokenizer, used for the referring expressions.
    """

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0):
        """Build the visual backbone, the fusion module and the text encoder.

        Args:
            param: Configuration object. This class reads ``image_size``,
                ``sam_config_path``, ``backbone_weight`` and ``local_rank``
                from it; the whole object is also handed to ``AuralFuser``.
            mask_threshold: Stored on the instance and forwarded to
                ``SAM2Transforms``.
            max_hole_area: Forwarded to ``SAM2Transforms``.
            max_sprinkle_area: Forwarded to ``SAM2Transforms``.
        """
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold

        # Square feature-map sizes at 1/4, 1/8 and 1/16 of the input resolution.
        image_size = self.param.image_size
        self._bb_feat_sizes = [
            (int(image_size / stride), int(image_size / stride))
            for stride in (4, 8, 16)
        ]

        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
            hydra_overrides_extra=[f"++model.image_size={self.param.image_size}"],
        )
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        self.aural_fuser = AuralFuser(hyp_param=self.param)
        self.text_tokenizer = AutoTokenizer.from_pretrained('distilbert/distilroberta-base')
        self.t_model = AutoModel.from_pretrained('distilbert/distilroberta-base')

    def _encode_text(self, prompts):
        """RoBERTa embeddings for referring expressions (frozen at train time).

        Args:
            prompts: Sequence that is unpacked positionally into the tokenizer
                call. NOTE(review): this presumably expects a one-element
                wrapper around the batch of strings -- confirm against the
                data loader.

        Returns:
            ``last_hidden_state`` of the text encoder, computed without
            gradient tracking.
        """
        enc = self.text_tokenizer(
            *prompts,
            max_length=25,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Only these two entries are moved to the GPU selected by ``local_rank``.
        for key in ('input_ids', 'attention_mask'):
            enc[key] = enc[key].cuda(self.param.local_rank, non_blocking=True)
        with torch.no_grad():
            return self.t_model(**enc).last_hidden_state

    def forward_frame(self, frame_):
        """Single-frame SAM2 image encoder pass (same helper pattern as v1m).

        The input is bilinearly resized (antialiased) to
        ``param.image_size`` x ``param.image_size`` before being passed to
        ``v_model.image_encoder``.
        """
        frame = torch.nn.functional.interpolate(
            frame_, (self.param.image_size, self.param.image_size),
            antialias=True, align_corners=False, mode='bilinear',
        )
        return self.v_model.image_encoder(frame)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio+text into FPN, then run SAM2 tracking without box/mask prompts.

        Args:
            frames: Frame tensor; its first dimension is used as the number
                of frames.
            spect: Audio input passed unchanged to ``aural_fuser``.
            prompts: Referring expressions, see ``_encode_text``.
            sam_process: Not read by this method; kept so existing call sites
                remain valid.

        Returns:
            ``(outputs, proj_feats)``: the result of
            ``v_model.forward_tracking_wo_prompt`` and the third element
            returned by ``aural_fuser``.
        """
        text_feats = self._encode_text(prompts)
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(
            backbone_feats, spect, text_feats
        )

        # The tracking call receives both residual lists in reversed order.
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        num_frames = frames.shape[0]
        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats,
            num_frames=num_frames,
            # The middle frame is used as the conditioning frame.
            condition_frame=num_frames // 2,
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the SAM2 visual model."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Freeze SAM2 and text backbone; only ``aural_fuser`` is trained.

        Both backbones are switched to eval mode and every one of their
        parameters gets ``requires_grad = False``. The text encoder is always
        run under ``torch.no_grad()``, so leaving its parameters marked as
        trainable would only produce parameters that never receive a
        gradient.

        NOTE: ``torch.nn.Module.train()`` recurses into child modules, so a
        later ``model.train()`` puts these backbones back into train mode;
        call this method again afterwards if eval-mode behaviour is required.
        """
        for backbone in (self.v_model, self.t_model):
            backbone.eval()
            for parameter in backbone.parameters():
                parameter.requires_grad = False
|
|