# AuralSAM2 / ref-avs.code/model/mymodel.py
# Uploaded by yyliu01 via huggingface_hub (commit c6dfc69, verified).
"""End-to-end Ref-AVS: SAM2 visual backbone + AuralFuser fusion + tracking head.
Orchestration follows ``avs.code/v1m.code/model/mymodel.py``.
"""
import torch
from model.visual.sam2.build_sam import build_sam2_visual_predictor
from model.visual.sam2.utils.transforms import SAM2Transforms
from model.aural_fuser import AuralFuser
from transformers import AutoTokenizer, AutoModel
class AVmodel(torch.nn.Module):
    """SAM2 + audio/text fusion (``aural_fuser``) + SAM2 tracking decoder.

    The SAM2 visual predictor (``v_model``) provides backbone features,
    ``AuralFuser`` combines them with the audio input and the encoded
    referring expression, and SAM2's tracking path then runs without any
    box/point/mask prompts.
    """

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0):
        """Build the visual, fusion and text sub-models.

        Args:
            param: hyper-parameter namespace. This class reads ``image_size``,
                ``sam_config_path``, ``backbone_weight`` and ``local_rank``;
                the whole object is also passed to ``AuralFuser`` as
                ``hyp_param``.
            mask_threshold: stored on the instance and forwarded to
                ``SAM2Transforms``.
            max_hole_area: forwarded to ``SAM2Transforms``.
            max_sprinkle_area: forwarded to ``SAM2Transforms``.
        """
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold
        # (h, w) sizes at strides 4, 8 and 16 of the square input resolution.
        # NOTE(review): not read by any method of this class — confirm whether
        # external code still relies on it.
        self._bb_feat_sizes = [
            (int(self.param.image_size / stride), int(self.param.image_size / stride))
            for stride in (4, 8, 16)
        ]
        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
            # Build SAM2 for param.image_size instead of the config's default.
            hydra_overrides_extra=["++model.image_size={}".format(self.param.image_size)],
        )
        # NOTE(review): not read by any method of this class.
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        self.aural_fuser = AuralFuser(hyp_param=self.param)
        self.text_tokenizer = AutoTokenizer.from_pretrained('distilbert/distilroberta-base')
        self.t_model = AutoModel.from_pretrained('distilbert/distilroberta-base')

    def _encode_text(self, prompts):
        """Embed referring expressions with the DistilRoBERTa text model.

        Args:
            prompts: sequence unpacked positionally into the tokenizer
                (``self.text_tokenizer(*prompts, ...)``).
                NOTE(review): a second element would be taken as the
                tokenizer's ``text_pair``; presumably callers pass a
                1-element sequence wrapping the batch of strings — confirm
                against the dataloader.

        Returns:
            The text model's ``last_hidden_state`` for inputs padded/truncated
            to 25 tokens. It is computed under ``torch.no_grad()``, so no
            gradient reaches ``t_model``.
        """
        enc = self.text_tokenizer(
            *prompts,
            max_length=25,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Move the token tensors to the GPU indexed by param.local_rank.
        enc['input_ids'] = enc['input_ids'].cuda(self.param.local_rank, non_blocking=True)
        enc['attention_mask'] = enc['attention_mask'].cuda(self.param.local_rank, non_blocking=True)
        with torch.no_grad():
            return self.t_model(**enc).last_hidden_state

    def forward_frame(self, frame_):
        """Resize one frame batch to ``image_size`` and run SAM2's image encoder.

        Bilinear resize with antialiasing to ``(image_size, image_size)``. Not
        called by ``forward``; kept as a standalone helper.
        """
        frame = torch.nn.functional.interpolate(
            frame_, (self.param.image_size, self.param.image_size),
            antialias=True, align_corners=False, mode='bilinear',
        )
        return self.v_model.image_encoder(frame)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio + text into the backbone features, then track without prompts.

        Args:
            frames: video frames; ``frames.shape[0]`` is used as the number of
                frames.
            spect: audio input handed to ``AuralFuser`` (presumably a
                spectrogram — confirm).
            prompts: referring expressions, see ``_encode_text``.
            sam_process: unused; kept for interface compatibility.

        Returns:
            ``(outputs, proj_feats)``: the output of
            ``v_model.forward_tracking_wo_prompt`` and the third element
            returned by ``aural_fuser``.
        """
        text_feats = self._encode_text(prompts)
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(
            backbone_feats, spect, text_feats,
        )
        # Reverse the order of the fuser's residual lists before tracking.
        # NOTE(review): presumably to match the decoder's level ordering — confirm.
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        num_frames = frames.shape[0]
        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats,
            num_frames=num_frames,
            condition_frame=num_frames // 2,  # middle frame is the conditioning frame
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the SAM2 visual predictor."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Freeze SAM2 and the text backbone; only ``aural_fuser`` is trained.

        Puts ``v_model`` and ``t_model`` in eval mode and disables gradients on
        all of their parameters. A later ``self.train()`` call recurses into
        submodules and switches them back to training mode. ``requires_grad``
        stays ``False``.
        """
        self.v_model.eval()
        self.t_model.eval()
        self.v_model.requires_grad_(False)
        # _encode_text runs t_model under no_grad, so it can never receive
        # gradients. Also clear requires_grad so optimisers and DDP do not
        # treat its weights as trainable-but-unused parameters.
        self.t_model.requires_grad_(False)