diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e335221ed96726da38fe03a806adde6c5fef7784 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +ckpts/avs/v1s/nohup.out filter=lfs diff=lfs merge=lfs -text +ckpts/avs/v2/nohup.out filter=lfs diff=lfs merge=lfs -text +ckpts/ref-avs/nohup.out filter=lfs diff=lfs merge=lfs -text +docs/overview.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..25d04987294c80a6d5f74281e0ceab3068303646 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Yuyuan Liu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 32897cd3e640101ba184f8c4ccd896981de3804a..d0d3340f4bd411c31290ff3f13a3141640d89d02 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,26 @@ ---- -license: mit ---- +# AuralSAM2 +> **[CVPRF'26]** [AuralSAM2: Enabling SAM2 Hear +Through Pyramid Audio-Visual Feature Prompting](#) +> +> by Yuyuan Liu, Yuanhong Chen, Chong Wang, Junlin Han, Junde Wu, Can Peng, Jingkun Chen, Yu Tian and Gustavo Carneiro +> + + +## Installation +please install the dependencies and dataset based on this [***installation***](./docs/installation.md) document. + +## Getting start +please follow this [***instruction***](./docs/before_start.md) document to reproduce our results. + +## Citation +please consider citing our work in your publications if it helps your research. + +```bibtex +@article{liu2025auralsam2, + title={AuralSAM2: Enabling SAM2 Hear Through Pyramid Audio-Visual Feature Prompting}, + author={Liu, Yuyuan and Chen, Yuanhong and Wang, Chong and Han, Junlin and Wu, Junde and Peng, Can and Chen, Jingkun and Tian, Yu and Carneiro, Gustavo}, + journal={arXiv preprint arXiv:2506.01015}, + year={2025} +} +``` + diff --git a/avs.code/v1m.code/configs/__init__.py b/avs.code/v1m.code/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/avs.code/v1m.code/configs/auralfuser/architecture.yaml b/avs.code/v1m.code/configs/auralfuser/architecture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c3d06ca42335ce6bfc8064bbd5cfd44c8080a --- /dev/null +++ b/avs.code/v1m.code/configs/auralfuser/architecture.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +aural_fuser: + patch_cfgs: + - [4, 4] + - [2, 2] + - [1, 1] + f_depths: [3, 6, 12] + block_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false + one_d_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false diff --git a/avs.code/v1m.code/configs/config.py b/avs.code/v1m.code/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..32ac10a0c17bc93e745852e7c122c6ca195f2e60 --- /dev/null +++ b/avs.code/v1m.code/configs/config.py @@ -0,0 +1,85 @@ +import os +import numpy +from easydict import EasyDict + +# v1m.code package root (parent of this `configs/` directory) +_CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +_WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT)) + +C = EasyDict() +config = C +cfg = C + +C.seed = 666 + +C.audio = EasyDict() +C.audio.FREEZE_AUDIO_EXTRACTOR = True +C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth') +C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False +C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False +C.train_vggish = False + +"""Root Directory Config""" +C.repo_name = 'AV' +C.root_dir = _CODE_ROOT + +"""Data Dir and Weight Dir""" +C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'AVSBench') +C.data_name = 'v1m' + +C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt') +C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml') + +"""Network Config""" +C.fix_bias = True +C.bn_eps = 1e-5 +C.bn_momentum = 0.1 + +"""Image Config""" +C.num_classes = 2 + +C.image_mean = numpy.array([0.485, 0.456, 0.406]) +C.image_std = numpy.array([0.229, 0.224, 0.225]) + + +C.image_size = 1024 +C.image_embedding_size = int(C.image_size / 16) +C.avsbench_size = (224, 224) + +C.scale_list = [.5, .75, 1., 1.25, 1.5] +C.ignore_index = 255 + +"""Train Config""" +C.lr = 7.5e-5 +C.batch_size = 8 +C.energy_weight = .05 + +C.lr_power = 0.9 +C.momentum = 0.9 +C.weight_decay = 0.05 + +C.num_workers = 4 + +"""Display Config""" +C.record_info_iter = 20 +C.display_iter = 50 + +"""Wandb Config""" +# Paste your W&B API key here, or set the WANDB_API_KEY environment variable instead. +C.wandb_key = "" + +# Your project [work_space] name +C.proj_name = "AVS-final-report" + +C.experiment_name = "v1s-hiera-l" + + +# False = no wandb logging (see utils/tensorboard.py) +C.wandb_online = False + +"""Save Config""" +C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', C.experiment_name) + +import pathlib + +pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True) diff --git a/avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml b/avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e0f10732134149f6a994be063d11fd7591c430 --- /dev/null +++ b/avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,114 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False + diff --git a/avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml b/avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8478b3d4b8b16d8b22f6555cf7b1f00231d7fd59 --- /dev/null +++ b/avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml b/avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml b/avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59e605b73c9777b70942538252d27a55ae8a7e1a --- /dev/null +++ b/avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 224 # 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: false + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/avs.code/v1m.code/configs/training/sam2_training_config.yaml b/avs.code/v1m.code/configs/training/sam2_training_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55771e7232fe88c4ea445958956eca8174c2e872 --- /dev/null +++ b/avs.code/v1m.code/configs/training/sam2_training_config.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +# Video transforms + +train_transforms: + - _target_: dataloader.sam2_dataset.transforms.ComposeAPI + transforms: + - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI + sizes: 1024 # ${scratch.resolution} + square: true + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI + - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +loss: + all: + _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 # 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + gpu_num: 4. + +# Contrastive loss (ContrastLoss); loaded in main.py / inference.py → hyp_param.contrastive_learning +contrastive_learning: + temperature: 0.10 + ignore_idx: 255 + ood_idx: 254 + max_views: 512 + proj_dim: 512 + sample_limits: 128 + total_limits: 15240 diff --git a/avs.code/v1m.code/dataloader/audio/audio_augmentation.py b/avs.code/v1m.code/dataloader/audio/audio_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..850d1577ea2bca4f8ec209edc201fb54968be928 --- /dev/null +++ b/avs.code/v1m.code/dataloader/audio/audio_augmentation.py @@ -0,0 +1,23 @@ +import numpy + + +class Augmentation(object): + """Audio pre-step used by training/inference: int16 waveform -> float in [-1, 1]. + + The previous audiomentations-based transforms were commented out and never applied; + behavior is unchanged: only scaling by 1/32768. + """ + + def __init__(self, mono=True): + self.mono = mono + + def train_aug(self, x_, sr_): + x_ = x_ / 32768.0 + return x_ + + def test_process(self, x_): + x_ = x_ / 32768.0 + return x_ + + def __call__(self, x, sr, split): + return self.train_aug(x, sr) if split == "train" else self.test_process(x) diff --git a/avs.code/v1m.code/dataloader/audio/audio_dataset.py b/avs.code/v1m.code/dataloader/audio/audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8e8b276e8545aa55ef56295719a0ad2b167106 --- /dev/null +++ b/avs.code/v1m.code/dataloader/audio/audio_dataset.py @@ -0,0 +1,38 @@ +import torch +import numpy +import os +from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples +import soundfile + + +class Audio(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split): + # temporarily set no augmentation. + self.augmentation = augmentation + self.directory_path = directory_path + self.split = split + + def load_audio_wave(self, file_index, file_index_mix): + audio_path = os.path.join(file_index, 'audio.wav') + wav_data, sample_rate = soundfile.read(audio_path, dtype='int16') + assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype + + if file_index_mix is not None: + audio_path2 = os.path.join(file_index_mix, 'audio.wav') + wav_data2, _ = soundfile.read(audio_path2, dtype='int16') + mix_lambda = numpy.random.beta(10, 10) + min_length = min(wav_data.shape[0], wav_data2.shape[0]) + wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda) + + wav_data = self.augmentation(wav_data, sample_rate, self.split) + audio_log_mel = torch.cat([waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(), + waveform_to_examples(wav_data[:, 1], sample_rate, True).detach()], dim=1) + + # for the vgg preprocess, we will need 5 seconds audio log. + if audio_log_mel.shape[0] < 5: + audio_log_mel = torch.cat([audio_log_mel, + audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)]) + return audio_log_mel + + def __len__(self): + return len(self.audio_list) diff --git a/avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..9d58e81bc70a85138980128e033f271998794605 --- /dev/null +++ b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from dataloader.audio.preprocess_vgg import mel_features +from dataloader.audio.preprocess_vgg import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v1m.code/dataloader/dataset.py b/avs.code/v1m.code/dataloader/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..066f3639049e8840a67c60078ae7a8d6f38c1fa2 --- /dev/null +++ b/avs.code/v1m.code/dataloader/dataset.py @@ -0,0 +1,67 @@ +"""Fused audio-visual dataset for AVSBench-style indexing.""" +import os +import random +import PIL.Image +import numpy +import torch +from dataloader.visual.visual_dataset import Visual +from dataloader.audio.audio_dataset import Audio +import pandas + + +class AV(torch.utils.data.Dataset): + """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`.""" + + def __init__(self, split, augmentation, param, root_path='', data_name='find'): + self.visual_dataset = Visual(augmentation['visual'], os.path.join(root_path, data_name), split, param.image_size, param.image_embedding_size) + self.audio_dataset = Audio(augmentation['audio'], os.path.join(root_path, data_name), split) + self.augment = augmentation + self.split = split + self.file_path = self.organise_files(self.split, root_path, data_name, csv_name_='avss_index/metadata.csv') + + def __getitem__(self, index): + mixing_prob = 0. # we omit this option. + other_index = random.randint(1, self.__len__()) - 1 if random.random() < mixing_prob and self.split == 'train' else None + frame, label, prompts = self.visual_dataset.load_data(self.file_path[index]) + if other_index is not None: + other_frame, other_label, other_prompts = self.visual_dataset.load_data(self.file_path[other_index]) + frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts) + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], self.file_path[other_index]) + else: + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], None) + + assert other_index is None if self.split == 'test' else 1, print('no mix in validation.') + + return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': self.file_path[index], + 'prompts': prompts} + + def __len__(self): + return len(self.file_path) + + @staticmethod + def organise_files(split_, root_path_, data_name_, csv_name_): + """Read rows from `csv_name_` under `root_path_` matching split and dataset label.""" + total_files = pandas.read_csv(os.path.join(root_path_, csv_name_)) + files_info = total_files[(total_files["split"] == split_) & (total_files["label"] == data_name_)]['uid'] + + files_path = [os.path.join(root_path_, data_name_, files_name) for files_name in files_info] + del total_files, files_info + return files_path + + @staticmethod + def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2): + mix_frame = frame1.clone() + mix_label = label1.clone() + bbx1, bby1, bbx2, bby2 = 0, 0, mix_label.shape[1] - 1, mix_label.shape[2] - 1 + + for i in range(0, mix_frame.shape[0]): + label_canvas_foreground = label2[i, bbx1:bbx2, bby1:bby2] > 0. + mix_frame[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground] = ( + frame2[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground]) + mix_label[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground] = ( + label2[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground]) + + return mix_frame, mix_label, prompts1 + + + diff --git a/avs.code/v1m.code/dataloader/sam2_dataset/__init__.py b/avs.code/v1m.code/dataloader/sam2_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1m.code/dataloader/sam2_dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1m.code/dataloader/sam2_dataset/transforms.py b/avs.code/v1m.code/dataloader/sam2_dataset/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7731e59ba98a5465493e3a9c4b785eb4d4420ca2 --- /dev/null +++ b/avs.code/v1m.code/dataloader/sam2_dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage +# from docutils.nodes import label +import numpy +from torchvision.transforms import InterpolationMode + +# from utils.data_utils import VideoDatapoint + + +def hflip(frames, labels, index): + # print(index) + # print(len(frames), frames[index].size, type(frames[index])) + # print(len(labels), labels[index].size, type(labels[index])) + frames[index] = F.hflip(frames[index]) + labels[index] = F.hflip(labels[index]) + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.hflip(obj.segment) + + return frames, labels + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(frames, labels, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + raise NotImplementedError + # cur_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + # size = get_size(cur_size, size, max_size) + + # old_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + if v2: + frames[index].data = Fv2.resize( + frames[index].data, size, antialias=True + ) + else: + frames[index] = F.resize(frames[index], size) + labels[index] = F.resize(labels[index], size) + # new_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + # h, w = size + # frames[index].size = (h, w) + return frames, labels + + +def pad(frames, index, padding, v2=False): + old_h, old_w = frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + frames[index].data = F.pad( + frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + frames[index].data = F.pad( + frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + frames[index].size = (h, w) + + for obj in frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return frames + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(frames)): + frames, labels = hflip(frames, labels, i) + return frames, labels + for i in range(len(frames)): + if random.random() < self.p: + frames, labels = hflip(frames, labels, i) + return frames, labels + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, frames, labels): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(frames)): + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + for i in range(len(frames)): + size = random.choice(self.sizes) + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + if self.v2: + raise NotImplementedError + # frames[img_idx] = Fv2.to_tensor(frames[img_idx]) + else: + frames[img_idx] = F.to_tensor(frames[img_idx]) + labels[img_idx] = torch.tensor(numpy.array(labels[img_idx]), dtype=torch.float) + return frames, labels + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + # if self.v2: + # img.data = Fv2.convert_image_dtype(img.data, torch.float32) + # img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + # else: + frames[img_idx] = F.normalize(frames[img_idx], mean=self.mean, std=self.std) + + return frames, labels + +''' + + + + + + + + +''' +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, frames, labels, **kwargs): + for t in self.transforms: + frames, labels = t(frames, labels, **kwargs) + return frames, labels + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img_idx in range(len(frames)): + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + for img_idx in range(len(frames)): + if random.random() < self.p: + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return frames, labels + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + label_fill_value=0., + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.fill_label = label_fill_value + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.' + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, frames, labels, **kwargs): + for _tentative in range(self.num_tentatives): + res_img, res_labels = self.transform_frames(frames, labels) + # if res is not None: + return res_img, res_labels + + # raise NotImplementedError + # if self.log_warning: + # logging.warning( + # f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + # ) + # return frames + + def transform_frames(self, frames, labels): + _, height, width = F.get_dimensions(frames[0]) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(frames): + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + frames[img_idx] = F.affine( + img, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + labels[img_idx] = F.affine( + labels[img_idx], + *affine_params, + # default: interpolation='nearest', + fill=self.fill_label, + ) + return frames, labels + + +''' +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, frames, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint +''' \ No newline at end of file diff --git a/avs.code/v1m.code/dataloader/visual/visual_augmentation.py b/avs.code/v1m.code/dataloader/visual/visual_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40aed7c8b8c08d50a46db122e1213bd4878afd --- /dev/null +++ b/avs.code/v1m.code/dataloader/visual/visual_augmentation.py @@ -0,0 +1,140 @@ +import random + +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision.transforms.functional as F +import torchvision.transforms as transforms + + +class Augmentation(object): + def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255): + self.image_size = (image_height, image_width) + # self.image_norm = (image_mean, image_std) + # self.get_crop_pos = transforms.RandomCrop(self.image_size) + self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25) + self.gaussian_blurring = transforms.GaussianBlur((3, 3)) + self.scale_list = scale_list + + self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + self.to_tensor = transforms.ToTensor() + + self.ignore_index = ignore_index + + # self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + + # if setup == "avs" or setup == "avss" or setup == "avss_binary": + # # AVS + # self.scale_list = [.5, .75, 1.] + # self.color_jitter = None + # else: + # # COCO + # # self.scale_list = [.75, 1., 1.25, 1.5, 1.75, 2.] + # self.scale_list = [0.5,0.75,1.0,1.25,1.5,1.75,2.0] + + # def normalise(self, image): + # image = image / 255.0 + # image = image - self.image_norm[0] + # image = image / self.image_norm[1] + # return image + + def resize(self, image_, label_, size=None): + h_, w_ = self.image_size if size is None else size + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + def random_crop_with_padding(self, image_, label_): + w_, h_ = image_.size + if min(h_, w_) < min(self.image_size): + res_w_ = max(self.image_size[0] - w_, 0) + res_h_ = max(self.image_size[1] - h_, 0) + image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=(numpy.array(self.image_norm[0]) * 255.).tolist()) + # image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=self.ignore_index) # if error, define the padding value. + label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index) + + pos_ = self.get_crop_pos.get_params(image_, self.image_size) + image_ = F.crop(image_, *pos_) + label_ = F.crop(label_, *pos_) + + return image_, label_ + + # @staticmethod + def random_scales(self, image_, label_): + w_, h_ = image_.size + chosen_scale = random.choice(self.scale_list) + w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale) + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + @staticmethod + def random_flip_h(image_, label_): + chosen_flip = random.random() > 0.5 + image_ = F.hflip(image_) if chosen_flip else image_ + label_ = F.hflip(label_) if chosen_flip else label_ + return image_, label_ + + def augment_entire_clip(self, x_list, y_list): + degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item()) + shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()), + torch.empty(1).uniform_(float(-20.), float(20.)).item()] + dice = random.random() + for index, single_x in enumerate(x_list): + if dice <= 0.1: + single_x = F.rgb_to_grayscale(single_x, num_output_channels=3) + + single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.]) + single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.NEAREST, fill=[0.]) + x_list[index] = single_x + y_list[index] = single_y + + return x_list, y_list + + + + + def train_aug(self, x_, y_): + x_, y_ = self.random_flip_h(x_, y_) + # # x, y = self.random_scales(x, y) + x_, y_ = self.resize(x_, y_) + + if self.color_jitter is not None and random.random() < 0.5: + x_ = self.color_jitter(x_) + if self.gaussian_blurring is not None and random.random() < 0.5: + x_ = self.gaussian_blurring(x_) + + # x, y = self.random_crop_with_padding(x, y) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + # receive pseudo labels. + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def test_process(self, x_, y_): + # x = self.to_tensor(x) + # y = torch.tensor(numpy.asarray(y)).long() + + # following AVSbench setup, we fix image size (224, 224) + x_, y_ = self.resize(x_, y_) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def __call__(self, x, y, split): + return self.train_aug(x, y) if split == "train" \ + else self.test_process(x, y) + + + + + + + + + + + diff --git a/avs.code/v1m.code/dataloader/visual/visual_dataset.py b/avs.code/v1m.code/dataloader/visual/visual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..74d6965982acd11304d0a4bba31f09fc792fb50e --- /dev/null +++ b/avs.code/v1m.code/dataloader/visual/visual_dataset.py @@ -0,0 +1,127 @@ +import os +import re +import PIL.Image +import matplotlib.pyplot as plt +import numpy +import torch +import pandas +import torchvision + + +class Visual(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size): + self.augment = augmentation + self.directory_path = directory_path + self.split = split + self.image_size = image_size + self.embedding_size = image_embedding_size + + def load_data(self, file_prefix): + frame_path = os.path.join(file_prefix, 'frames') + frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)] + label_path = os.path.join(file_prefix, 'labels_rgb') + label_path = [os.path.join(label_path, i) for i in os.listdir(label_path)] + + # if self.split == 'train': + # label_path += [os.path.join(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'), i) for i in + # os.listdir(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'))] + + frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0]))) + label_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.png')[0]))) + + frame = [PIL.Image.open(i) for i in frame_path] + label = [PIL.Image.open(i).convert('L') for i in label_path] + + # if self.split == 'train': + # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label)) + + label_idx = torch.tensor(list([1] + [0] * 4), dtype=torch.bool) + # fulfill the empty page. + # we utilise pseudo-labels now. + # label_idx = torch.tensor(list([1] + [0] * (len(frame) - len(label))), dtype=torch.bool) + # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label)) + + # receive the prompts from the ground truth. + # prompts = {"point_coords": torch.nan, "point_labels": torch.nan, + # "masks": [None]*len(frame), "box_coords": [None]*len(frame)} + + prompts = {} + image_batch = [None]*len(frame) + label_batch = [None]*len(frame) + + if self.split == 'train': + # frame, label = self.augment.augment_entire_clip(frame, label) + frame, label = self.augment(frame, label) + + + for i in range(len(frame)): + if self.split == 'test': + curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split) + else: + curr_frame, curr_label = frame[i], label[i] + # if self.split == 'train' and i > 0: + # curr_label = curr_label / 255. + # curr_label[curr_label > 0.5] = 1 + # curr_label[curr_label < 0.5] = 0 + # # curr_label[(0.05 < curr_label) & (curr_label < 0.95)] = 255 + # # we temporarily make it to be hard mask; + # # curr_label = ((curr_label / 255.) - 0.5) * 2 + # # curr_label[curr_label >= 0.] = 1. + # # curr_label[curr_label < 0.] = 0. + # else: + curr_label[curr_label > 0.] = 1. + image_batch[i], label_batch[i] = curr_frame, curr_label + + # image_batch[i], label_batch[i] = self.augment(frame[i], label[i], split=self.split) + # note: we simply convert the code to binary mask in v1s, v1m; + # to some reason, we failed to load the label in `L' format and had to hardcoding here. + # label_batch[i][label_batch[i] > 0.] = 1. + + # prompts['box_coords'][i], prompts['masks'][i] = self.receive_other_prompts(label_batch[i]) + + # organise the prompts + # prompts.update({'masks': torch.stack(prompts['masks'], dim=0)}) + # prompts.update({'box_coords': torch.stack(prompts['box_coords'], dim=0)}) + # prompts.update({'point_labels': torch.stack(prompts['point_labels'], dim=0)}) + prompts.update({'label_index': label_idx}) + return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts + + def receive_other_prompts(self, y_): + # y_ = torch.zeros_like(y_) + if len(torch.unique(y_)) > 1: + # foreground point + points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0) + + # bbox prompt (left-top corner & right-bottom corner) + bbox_one = torch.min(points_foreground[:, 0]), torch.min(points_foreground[:, 1]) + bbox_fou = torch.max(points_foreground[:, 0]), torch.max(points_foreground[:, 1]) + bbox_coord = torch.tensor(bbox_one + bbox_fou, dtype=torch.float) + bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.squeeze().shape) + # mask prompt + low_mask = torchvision.transforms.functional.resize(y_.clone(), [self.embedding_size*4, self.embedding_size*4], + torchvision.transforms.InterpolationMode.NEAREST) + else: + # for the pure background situation. + bbox_coord = torch.zeros([4], dtype=torch.float).fill_(float('nan')) + low_mask = torch.zeros([1, self.embedding_size*4, self.embedding_size*4], dtype=torch.float).fill_(float('nan')) + + return bbox_coord, low_mask + + # we transfer the coords to SAM's input resolution (1024, 1024). + def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the sam2 model. + """ + h, w = orig_hw + coords = coords.clone().reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + coords = coords * self.image_size # unnormalize coords + return coords.reshape(4) + + + diff --git a/avs.code/v1m.code/inference.py b/avs.code/v1m.code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e71db5eb014fb4347f3d653fc5b9643f1da8d935 --- /dev/null +++ b/avs.code/v1m.code/inference.py @@ -0,0 +1,193 @@ +"""Distributed inference on the test set; runs the same three `process` modes as training validation.""" +import os +import pathlib +import torch +import numpy +import random +import argparse +from easydict import EasyDict + +# Avoid import failure when configs.config creates saved_dir without write permission. +_real_mkdir = pathlib.Path.mkdir + + +def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False): + try: + return _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok) + except PermissionError: + pass + + +pathlib.Path.mkdir = _safe_mkdir + + +def seed_it(seed): + random.seed(seed) + os.environ["PYTHONSEED"] = str(seed) + numpy.random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + torch.manual_seed(seed) + + +class _DummyTensorboard: + """Minimal Tensorboard stub so Trainer.valid runs without wandb logging.""" + + def upload_wandb_info(self, info_dict): + pass + + def upload_wandb_image(self, *args, **kwargs): + pass + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` + from hydra import compose + from omegaconf import OmegaConf + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + train_cfg = compose(config_name='training/sam2_training_config.yaml') + OmegaConf.resolve(train_cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda() + torch.cuda.set_device(hyp_param.local_rank) + ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu") + if not isinstance(ckpt_sd, dict): + raise TypeError("Checkpoint must be a state_dict dictionary.") + # Same as v1s/v2: full-model ckpt vs train-only aural_fuser ckpt (e.g. keys vgg.*, f_blocks.*). + if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd.keys()): + av_model.load_state_dict(ckpt_sd, strict=True) + else: + av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=False) + av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model) + av_model.eval() + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data import DataLoader, Subset + from torch.utils.data.distributed import DistributedSampler + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + audio_augmentation = AudioAugmentation(mono=True) + + dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.inference_data_name) + + max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0 + if max_batches > 0: + n_samples = min(max_batches * hyp_param.batch_size, len(dataset)) + dataset = Subset(dataset, range(n_samples)) + + sampler = DistributedSampler(dataset, shuffle=False) + test_dataloader = DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler, + num_workers=hyp_param.num_workers) + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + + metrics = { + "foreground_iou": ForegroundIoU(), + "foreground_f-score": ForegroundFScore(hyp_param.local_rank), + } + trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics) + + # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select + runs = [ + ("", "default (logits[:,0])"), + ("iou_select", "iou_select"), + ("iou_occ_select", "iou_occ_select"), + ] + results = [] + for process, label in runs: + fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process) + results.append((label, fiou, ffscore)) + torch.cuda.empty_cache() + + if hyp_param.local_rank <= 0: + print("\n========== inference (same three process flags as training valid) ==========") + for label, fiou, ffscore in results: + print(" {:32s} f_iou={} f_f-score={}".format(label, fiou, ffscore)) + print("=======================================================\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Inference: full test set + three process modes') + + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int, + help='Batch size (match training if needed)') + + parser.add_argument('--epochs', default=80, type=int, + help="unused") + + parser.add_argument('--lr', default=1e-5, type=float, + help="unused") + + parser.add_argument('--online', action="store_true", + help='unused') + + parser.add_argument( + '--inference_ckpt', type=str, default=None, + help='Trained AuralSAM2 checkpoint (.pth state_dict). ' + 'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). ' + 'Default if unset: avs.code/training_details/.../hiera_l.pth', + ) + parser.add_argument('--inference_data_name', type=str, default='v1m', + help='AVSBench subset folder label (v1s|v1m|v2); must match training test split') + parser.add_argument('--inference_max_batches', type=int, default=0, + help='0 = full test; >0 = first N batches only (debug)') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + + _repo = pathlib.Path(__file__).resolve().parent + # Repo root: .../AuralSAM2 (parent of avs.code) + _workspace = _repo.parent.parent + args.data_root_path = str(_workspace / 'AVSBench') + args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt') + args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth') + args.saved_dir = '/tmp/v1m_infer_ckpt' + pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True) + if args.inference_ckpt is None: + args.inference_ckpt = str( + _repo.parent / 'training_details' / 'v1m' / 'hiera_l' / 'hiera_l.pth' + ) + + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9901' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v1m.code/loss/training/__init__.py b/avs.code/v1m.code/loss/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8932ea6e0916b9f16cb514f1e64704440d554e --- /dev/null +++ b/avs.code/v1m.code/loss/training/__init__.py @@ -0,0 +1,2 @@ +"""Training loss modules.""" + diff --git a/avs.code/v1m.code/loss/training/contrastive_learning.py b/avs.code/v1m.code/loss/training/contrastive_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..41287d26dd05dfe367b41c96375770dc6c8afa94 --- /dev/null +++ b/avs.code/v1m.code/loss/training/contrastive_learning.py @@ -0,0 +1,201 @@ +from abc import ABC + +import torch +import torch.nn as nn + + +class ContrastLoss(nn.Module, ABC): + def __init__(self, hyp_param): + super().__init__() + self.param = hyp_param + _defaults = { + "temperature": 0.10, + "ignore_idx": 255, + "ood_idx": 254, + "max_views": 512, + "proj_dim": 512, + "sample_limits": 128, + "total_limits": 15240, + } + _raw = getattr(hyp_param, "contrastive_learning", None) or {} + _cfg = {**_defaults, **_raw} + self.temperature = _cfg["temperature"] + self.ignore_idx = _cfg["ignore_idx"] + self.ood_idx = _cfg["ood_idx"] + self.max_views = _cfg["max_views"] + self.proj_dim = _cfg["proj_dim"] + self.sample_limits = _cfg["sample_limits"] + self.total_limits = _cfg["total_limits"] + + def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx): + embedding_sample_list = [] + label_list = [] + embedding_sample_list_a = [] + label_list_a = [] + class_index_list = torch.unique(masks) + + if len(class_index_list) > 1: + for class_index in class_index_list[1:]: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3) + else: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3) + + sample_limits = self.sample_limits + embeddings = embeddings.permute(1, 0) + for class_index in class_index_list: + hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()] + easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()] + + hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] + selective_num_hard = min(sample_limits, hard_indices_num) + selective_num_easy = min(sample_limits, easy_indices_num) + + if (selective_num_hard + selective_num_easy) < sample_limits * 2: + if selective_num_hard > selective_num_easy: + selective_num_hard += sample_limits * 2 - selective_num_easy + else: + selective_num_easy += sample_limits * 2 - selective_num_hard + + hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] + embedding_sample_list.append(hard_indices[hard_chosen_indices]) + label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3) + + easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] + embedding_sample_list.append(easy_indices[easy_chosen_indices]) + label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3) + return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a + + def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions): + masks = masks.flatten(start_dim=1) + predictions = predictions.flatten(start_dim=1) + visual_embeddings = visual_embeddings.flatten(start_dim=-2) + + visual_embedding_sample_list = [] + visual_label_list = [] + audio_embedding_sample_list = [] + audio_label_list = [] + + for frame_idx in range(masks.shape[0]): + current_vision_feats = visual_embeddings[frame_idx] + current_masks = masks[frame_idx] + current_predictions = predictions[frame_idx] + current_audio_feats = audio_embeddings[frame_idx] + for layer_idx in range(3): + ( + selected_vision_embeddings, + selected_vision_labels, + selected_audio_embeddings, + selected_audio_labels, + ) = self.select_class_wise_samples( + current_vision_feats[layer_idx], + current_audio_feats[layer_idx], + current_predictions, + current_masks, + 0, + ) + visual_embedding_sample_list += selected_vision_embeddings + visual_label_list += selected_vision_labels + audio_embedding_sample_list += selected_audio_embeddings + audio_label_list += selected_audio_labels + + if len(visual_embedding_sample_list) == 0: + return 0.0 + + visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze() + visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1) + audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze() + audio_label_list = torch.cat(audio_label_list).unsqueeze(1) + + total_limits = self.total_limits + if visual_embedding_sample_list.shape[0] > total_limits: + rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits] + visual_embedding_sample_list = visual_embedding_sample_list[:rand_index] + visual_label_list = visual_label_list[:rand_index] + loss = self.info_nce( + visual_embedding_sample_list, + visual_label_list, + audio_embedding_sample_list, + audio_label_list, + ) + return loss + + def forward(self, embeddings, output_dicts, masks): + predictions = torch.cat([i["multistep_pred_masks"] for i in output_dicts]) + predictions = torch.nn.functional.interpolate( + predictions, + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="bilinear", + align_corners=False, + ).squeeze() + masks = torch.nn.functional.interpolate( + masks.unsqueeze(1), + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="nearest", + ).squeeze() + visual_embeddings, audio_embeddings = embeddings + visual_embeddings = torch.cat( + [ + torch.cat( + [ + visual_embeddings[0][i].unsqueeze(0), + visual_embeddings[1][i].unsqueeze(0), + visual_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + audio_embeddings = torch.cat( + [ + torch.cat( + [ + audio_embeddings[0][i].unsqueeze(0), + audio_embeddings[1][i].unsqueeze(0), + audio_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + return self.forward_audio_visual( + visual_embeddings, audio_embeddings.squeeze(), masks, predictions + ) + + @staticmethod + def manipulate_cover_mask(a_label, current_mask): + a_label = a_label + 1 + visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1)) + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0 + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0 + return current_mask + + def info_nce(self, anchors_, a_labels_, contras_, c_labels_): + c_labels_ = torch.cat([a_labels_, c_labels_]) + contras_ = torch.cat([anchors_, contras_]) + mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float() + + anchor_dot_contrast = torch.div( + torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), + self.temperature, + ) + + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + neg_mask = 1 - mask + + mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask) + mask = mask.fill_diagonal_(0.0) + + neg_logits = torch.exp(logits) * neg_mask + neg_logits = neg_logits.sum(1, keepdim=True) + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits + neg_logits) + + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any()) + return -mean_log_prob_pos.mean() + diff --git a/avs.code/v1m.code/loss/training/sam2_training_loss.py b/avs.code/v1m.code/loss/training/sam2_training_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ce1b02c0dbbf5d7e771b314a4a537145e28978 --- /dev/null +++ b/avs.code/v1m.code/loss/training/sam2_training_loss.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +CORE_LOSS_KEY = "core_loss" + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + inputs = inputs.sigmoid() + if loss_on_multimask: + assert inputs.dim() == 4 and targets.dim() == 4 + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + gpu_num=1, + ): + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.world_size = gpu_num + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + targets_batch.shape[1], device=targets_batch.device, dtype=torch.float + ) + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / self.world_size, min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 + + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + return reduced_loss + diff --git a/avs.code/v1m.code/main.py b/avs.code/v1m.code/main.py new file mode 100644 index 0000000000000000000000000000000000000000..afccee524b130beaeb05360bb9c9e5935e73d51e --- /dev/null +++ b/avs.code/v1m.code/main.py @@ -0,0 +1,166 @@ +"""DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss.""" +import os +import torch +import numpy +import random +import argparse +from easydict import EasyDict + + +def seed_it(seed): + """Fix RNGs and cuDNN for reproducible runs (rank offsets seed in DDP).""" + os.environ["PYTHONSEED"] = str(seed) + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + + torch.backends.cudnn.benchmark = False + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + # NCCL process group; world size = GPUs on this node + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + torch.cuda.set_device(hyp_param.local_rank) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` (initialize_config_module) + + from hydra import compose + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml) + transform_config_path = 'training/sam2_training_config.yaml' + + if 'hiera_t' in hyp_param.sam_config_path: + hyp_param.image_size = 224 + hyp_param.image_embedding_size = int(hyp_param.image_size / 16) + print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True) + + cfg = compose(config_name=transform_config_path) + OmegaConf.resolve(cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True) + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=True) + + # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone) + from utils.utils import manipulate_params + parameter_list = manipulate_params(hyp_param, av_model.module.aural_fuser) + optimiser = torch.optim.AdamW(parameter_list, betas=(0.9, 0.999)) + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data.distributed import DistributedSampler + + compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0] + + audio_augmentation = AudioAugmentation(mono=True) + train_dataset = AV(split='train', augmentation={"visual": compose_api, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.data_name) + + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + + audio_augmentation = AudioAugmentation(mono=True) + + random_sampler = DistributedSampler(train_dataset, shuffle=True) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hyp_param.batch_size, + sampler=random_sampler, + num_workers=hyp_param.num_workers, drop_last=True) + + test_dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.data_name) + + order_sampler = DistributedSampler(test_dataset, shuffle=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=order_sampler, + num_workers=hyp_param.num_workers) + + + criterion = instantiate(cfg.loss, _recursive_=True)['all'] + from utils.tensorboard import Tensorboard + tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + metrics = {"foreground_iou": ForegroundIoU(), "foreground_f-score": ForegroundFScore(0 if hyp_param.local_rank <= 0 else hyp_param.local_rank)} + + trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics) + + + curr_best = 0. # checkpoint when IoU (iou_select mode) improves + + for epoch in range(hyp_param.epochs): + av_model.train() + av_model.module.freeze_sam_parameters() + random_sampler.set_epoch(epoch) + trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser) + + torch.distributed.barrier() + torch.cuda.empty_cache() + + av_model.eval() + # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate + curr_results1, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index') + curr_results, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select') + curr_results3, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select') + if hyp_param.local_rank <= 0 and curr_results > curr_best: + curr_best = curr_results + torch.save(av_model.module.aural_fuser.state_dict(), os.path.join(hyp_param.saved_dir, str(curr_results) + ".pth")) + torch.distributed.barrier() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Training') + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int) + + parser.add_argument('--epochs', default=80, type=int, + help="total epochs that used for the training") + + parser.add_argument('--lr', default=1e-4, type=float, + help='Default HEAD Learning rate is same as others, ' + '*Note: in ddp training, lr will automatically times by n_gpu') + + parser.add_argument('--online', action="store_true", + help='switch on for visualization; switch off for debug') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9902' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v1m.code/model/audio/torchvggish/mel_features.py b/avs.code/v1m.code/model/audio/torchvggish/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v1m.code/model/audio/torchvggish/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v1m.code/model/audio/torchvggish/vggish.py b/avs.code/v1m.code/model/audio/torchvggish/vggish.py new file mode 100644 index 0000000000000000000000000000000000000000..f01c22867c713bfd8713eee5665120b92602761d --- /dev/null +++ b/avs.code/v1m.code/model/audio/torchvggish/vggish.py @@ -0,0 +1,193 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import hub + +from . import vggish_input, vggish_params + + +class VGG(nn.Module): + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.embeddings = nn.Sequential( + nn.Linear(512 * 4 * 6, 4096), + nn.ReLU(True), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Linear(4096, 128), + nn.ReLU(True)) + + def forward(self, x): + x = self.features(x) + + # Transpose the output from features to + # remain compatible with vggish embeddings + x = torch.transpose(x, 1, 3) + x = torch.transpose(x, 1, 2) + x = x.contiguous() + x = x.view(x.size(0), -1) + + return self.embeddings(x) + + +class Postprocessor(nn.Module): + """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a + numpy array in order to preserve the gradient. + + "The initial release of AudioSet included 128-D VGGish embeddings for each + segment of AudioSet. These released embeddings were produced by applying + a PCA transformation (technically, a whitening transform is included as well) + and 8-bit quantization to the raw embedding output from VGGish, in order to + stay compatible with the YouTube-8M project which provides visual embeddings + in the same format for a large set of YouTube videos. This class implements + the same PCA (with whitening) and quantization transformations." + """ + + def __init__(self): + """Constructs a postprocessor.""" + super(Postprocessor, self).__init__() + # Create empty matrix, for user's state_dict to load + self.pca_eigen_vectors = torch.empty( + (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), + dtype=torch.float, + ) + self.pca_means = torch.empty( + (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float + ) + + self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) + self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) + + def postprocess(self, embeddings_batch): + """Applies tensor postprocessing to a batch of embeddings. + + Args: + embeddings_batch: An tensor of shape [batch_size, embedding_size] + containing output from the embedding layer of VGGish. + + Returns: + A tensor of the same shape as the input, containing the PCA-transformed, + quantized, and clipped version of the input. + """ + assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( + embeddings_batch.shape, + ) + assert ( + embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE + ), "Bad batch shape: %r" % (embeddings_batch.shape,) + + # Apply PCA. + # - Embeddings come in as [batch_size, embedding_size]. + # - Transpose to [embedding_size, batch_size]. + # - Subtract pca_means column vector from each column. + # - Premultiply by PCA matrix of shape [output_dims, input_dims] + # where both are are equal to embedding_size in our case. + # - Transpose result back to [batch_size, embedding_size]. + pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() + + # Quantize by: + # - clipping to [min, max] range + clipped_embeddings = torch.clamp( + pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL + ) + # - convert to 8-bit in range [0.0, 255.0] + quantized_embeddings = torch.round( + (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) + * ( + 255.0 + / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) + ) + ) + return torch.squeeze(quantized_embeddings) + + def forward(self, x): + return self.postprocess(x) + + +def make_layers(): + layers = [] + in_channels = 1 + for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +def _vgg(): + return VGG(make_layers()) + + +# def _spectrogram(): +# config = dict( +# sr=16000, +# n_fft=400, +# n_mels=64, +# hop_length=160, +# window="hann", +# center=False, +# pad_mode="reflect", +# htk=True, +# fmin=125, +# fmax=7500, +# output_format='Magnitude', +# # device=device, +# ) +# return Spectrogram.MelSpectrogram(**config) + + +class VGGish(VGG): + def __init__(self, cfg, device=None): + super().__init__(make_layers()) + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH) + super().load_state_dict(state_dict) + print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}') + + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("device: ", device) + self.device = device + + self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL + self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA + if self.postprocess: + self.pproc = Postprocessor() + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH) + # TODO: Convert the state_dict to torch + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float + ) + state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float + ) + self.pproc.load_state_dict(state_dict) + self.to(self.device) + + def forward(self, x): + if self.preprocess: + print(">>> pre processing...") + x = self._preprocess(x) + x = x.to(self.device) + x = VGG.forward(self, x) + if self.postprocess: + print(">>> post processing...") + x = self._postprocess(x) + return x + + def _preprocess(self, x): + # if isinstance(x, np.ndarray): + # x = vggish_input.waveform_to_examples(x, fs) + if isinstance(x, str): + x = vggish_input.wavfile_to_examples(x) + else: + raise AttributeError + return x + + def _postprocess(self, x): + return self.pproc(x) diff --git a/avs.code/v1m.code/model/audio/torchvggish/vggish_input.py b/avs.code/v1m.code/model/audio/torchvggish/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..ede228b1fb630180f1f49244355d373fb3300f03 --- /dev/null +++ b/avs.code/v1m.code/model/audio/torchvggish/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from . import mel_features +from . import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v1m.code/model/audio/torchvggish/vggish_params.py b/avs.code/v1m.code/model/audio/torchvggish/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v1m.code/model/audio/torchvggish/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v1m.code/model/aural_fuser.py b/avs.code/v1m.code/model/aural_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..924810bfcf8bee5e285cab7d54e477daf254b85a --- /dev/null +++ b/avs.code/v1m.code/model/aural_fuser.py @@ -0,0 +1,567 @@ +import math + +import torch +import torch.nn as nn +from model.audio.torchvggish import vggish +from timm.models.layers import DropPath, trunc_normal_ + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d): + super().__init__() + self.proj = nn.Sequential( + conv_layer(dim_in, proj_dim, kernel_size=1), + norm_act(proj_dim), + conv_layer(proj_dim, proj_dim, kernel_size=1), + ) + + def forward(self, x): + return torch.nn.functional.normalize(self.proj(x), p=2, dim=1) + +class AuralFuser(torch.nn.Module): + """Fuses VGGish audio with SAM2 FPN maps via patch embeds, fusion blocks, and projection heads.""" + + def __init__(self, hyp_param): + self.hyp_param = hyp_param + super().__init__() + self.vgg = vggish.VGGish(self.hyp_param.audio) + if not getattr(self.hyp_param, "train_vggish", False): + for p in self.vgg.parameters(): + p.requires_grad = False + + self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, + temperature=10000) + + # Populated in main.py / inference.py via Hydra compose('auralfuser/architecture.yaml') → hyp_param.aural_fuser + if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None: + raise ValueError( + "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser." + ) + arch_cfg = self.hyp_param.aural_fuser + + _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]] + _f_depths = arch_cfg["f_depths"] + _block_kw = dict(arch_cfg["block_kw"]) + _block_kw["norm_layer"] = nn.LayerNorm + _one_d_kw = dict(arch_cfg["one_d_kw"]) + _one_d_kw["norm_layer"] = nn.LayerNorm + self.patch_embeds = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs + ) + + self.f_blocks = nn.ModuleList( + nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths + ) + + self.a_blocks = nn.ModuleList( + nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3) + ) + + self.fusion_modules = nn.ModuleList( + AudioVisualFusionModule(in_channels=256, mode='dot') for _ in range(3) + ) + self.smooth_convs = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2) + ) + + self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128) + + self.train_proj_a1 = ProjectionHead(dim_in=256, norm_act=nn.BatchNorm1d, conv_layer=nn.Conv1d, proj_dim=128) + + @staticmethod + def positionalencoding1d(d_model, length): + if d_model % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model)) + pe = torch.zeros(length, d_model) + position = torch.arange(0, length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * + -(math.log(10000.0) / d_model))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe + + def forward(self, feature_dicts, spect=None): + image_embed_shape = [self.hyp_param.image_embedding_size] * 2 + H, W = image_embed_shape[0], image_embed_shape[1] + d = torch.cat( + [ + self.vgg(spect[:, 0, ...].unsqueeze(1)), + self.vgg(spect[:, 1, ...].unsqueeze(1)), + ], + dim=-1, + ) + length = d.shape[-1] + fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device) + fpn = list(feature_dicts["backbone_fpn"]) + patch_embeds = list(self.patch_embeds) + f_blocks = list(self.f_blocks) + a_blocks = list(self.a_blocks) + tpavi = list(self.fusion_modules) + smooths = [None, self.smooth_convs[0], self.smooth_convs[1]] + + feats = [None, None, None] + d_outputs = [] + + for i in range(3): + x = fpn[i] + x = patch_embeds[i](x) + x_pos = self.position_encoding_func(x) + x = x.flatten(2).permute(0, 2, 1) + x_pos = x_pos.flatten(2).permute(0, 2, 1) + + if i == 0: + x = x + x_pos + d = d + fix_audio_pos + else: + x = x + feats[i - 1] + x = smooths[i]( + x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W) + ).flatten(2).permute(0, 2, 1) + x = x + x_pos + d = d + fix_audio_pos + + for blks in f_blocks[i]: + x = blks(x, H, W, x_pos) + for blks in a_blocks[i]: + d = blks(d, fix_audio_pos) + + x = x + x_pos + d = d + fix_audio_pos + x, d_out, _, _ = tpavi[i](x, H, W, x_pos, d, length) + d = d_out + feats[i] = x + d_outputs.append(d_out) + + a, b, c = feats + d1, d2, d3 = d_outputs + + feature_residual = [a, b, c] + audio_out = [d1, d2, d3] + + proj_feature_out = [ + [ + self.train_proj_v1(a.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(b.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(c.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + ], + [ + self.train_proj_a1(d1.unsqueeze(-1)), + self.train_proj_a1(d2.unsqueeze(-1)), + self.train_proj_a1(d3.unsqueeze(-1)), + ], + ] + + return feature_residual, audio_out, proj_feature_out + + +class AudioVisualFusionModule(nn.Module): + def __init__(self, in_channels, inter_channels=None, mode='dot', + dimension=3): + super().__init__() + assert mode == 'dot' + self.mode = mode + self.dimension = dimension + + self.in_channels = in_channels + self.inter_channels = in_channels // 2 + + self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1) + self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1) + + self.norm_layer = nn.LayerNorm(in_channels) + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.W_z = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + bn(self.in_channels) + ) + nn.init.constant_(self.W_z[1].weight, 0) + nn.init.constant_(self.W_z[1].bias, 0) + + self.W_z2 = nn.Sequential( + nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + nn.BatchNorm1d(self.in_channels) + ) + nn.init.constant_(self.W_z2[1].weight, 0) + nn.init.constant_(self.W_z2[1].bias, 0) + self.norm_layer2 = nn.LayerNorm(self.in_channels) + + self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + + def forward(self, frame, H_x, W_x, tmp1, audio, tmp2): + frame = frame.permute(0, 2, 1) + frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x) + frame = frame.unsqueeze(2) + audio = self.align_channel(audio.unsqueeze(-1)) + + batch_size, _ = frame.size(0), frame.size(1) + q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels) + k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels) + v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels) + q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels) + k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels) + v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels) + f = torch.matmul(q_frame, k_audio.mT) + f_normalise = f / f.size(1) + + frame_attn = torch.matmul(f_normalise, v_audio) + + frame_attn = frame_attn.permute(0, 2, 1).contiguous() + frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:]) + frame_attn = self.W_z(frame_attn) + frame = frame_attn + frame + + frame = frame.permute(0, 2, 3, 4, 1) + frame = self.norm_layer(frame) + frame = frame.permute(0, 4, 1, 2, 3) + frame = frame.squeeze().flatten(start_dim=2).permute(0, 2, 1) + + a = torch.matmul(q_audio, k_frame.mT) + a_normalise = a / a.size(-1) + + audio_attn = torch.matmul(a_normalise, v_frame) + audio_attn = audio_attn.permute(0, 2, 1).contiguous() + + audio_attn = audio_attn.view(batch_size, self.inter_channels).unsqueeze(-1) + audio_attn = self.W_z2(audio_attn) + + audio = audio_attn + audio + + audio = self.norm_layer2(audio.squeeze()).squeeze() + + return frame, audio, frame_attn, audio_attn + + +class OneDBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OneDAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = OneDMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, _pos): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class OneDAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = x.unsqueeze(0) + + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.squeeze() + return x + + +class OneDMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, _pos): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + if not self.linear: + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x diff --git a/avs.code/v1m.code/model/mymodel.py b/avs.code/v1m.code/model/mymodel.py new file mode 100644 index 0000000000000000000000000000000000000000..35194cd584a4786f713447829592b15c7a366095 --- /dev/null +++ b/avs.code/v1m.code/model/mymodel.py @@ -0,0 +1,102 @@ +import logging + +from typing import List, Optional, Tuple, Union + +import numpy +import numpy as np +import torch +from PIL.Image import Image + +from model.visual.sam2.modeling.sam2_base import SAM2Base + +from model.visual.sam2.modeling.backbones.hieradet import Hiera +from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck +from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + +from model.visual.sam2.modeling.memory_attention import MemoryAttention +from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer +from model.visual.sam2.modeling.sam.transformer import RoPEAttention +from model.visual.sam2.modeling.memory_encoder import MemoryEncoder +from model.visual.sam2.modeling.memory_encoder import MaskDownSampler +from model.visual.sam2.modeling.memory_encoder import Fuser +from model.visual.sam2.modeling.memory_encoder import CXBlock + +from model.visual.sam2.utils.transforms import SAM2Transforms +from model.visual.sam2.modeling.backbones.hieradet import do_pool +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + + +class AVmodel(torch.nn.Module): + """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head.""" + + def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, ): + super().__init__() + self.param = param + self.mask_threshold = mask_threshold + self._bb_feat_sizes = [(int(self.param.image_size / 4), int(self.param.image_size / 4)), + (int(self.param.image_size / 8), int(self.param.image_size / 8)), + (int(self.param.image_size / 16), int(self.param.image_size / 16))] + + from model.visual.sam2.build_sam import build_sam2_visual_predictor + self.v_model = build_sam2_visual_predictor(self.param.sam_config_path, self.param.backbone_weight, + apply_postprocessing=True, mode='train') + self._transforms = SAM2Transforms( + resolution=self.v_model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + from model.aural_fuser import AuralFuser + self.aural_fuser = AuralFuser(hyp_param=self.param) + + + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.v_model.num_feature_levels:] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.v_model.num_feature_levels:] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def forward_frame(self, frame_): + frame = torch.nn.functional.interpolate(frame_, (self.param.image_size, self.param.image_size), + antialias=True, align_corners=False, mode='bilinear') + return self.v_model.image_encoder(frame) + + def forward(self, frames, spect, prompts, sam_process=False): + """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path.""" + backbone_feats = self.v_model.forward_image(frames, pre_compute=False) + audio_residual_feats = self.aural_fuser(backbone_feats, spect) + visual_resfeats, audio_resfeats, proj_feats = audio_residual_feats + + map_res = visual_resfeats[::-1] + vec_res = audio_resfeats[::-1] + + av_feats = (map_res, vec_res) + backbone_feats = self.v_model.precompute_high_res_features(backbone_feats) + backbone_feats = self.v_model.dont_prepare_prompt_inputs(backbone_feats, num_frames=frames.shape[0], + cond_frame=int(frames.shape[0]/2) if self.training else 0) + outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats) + return outputs, proj_feats + + @property + def device(self) -> torch.device: + return self.v_model.device + + def freeze_sam_parameters(self): + self.v_model.eval() + for name, parameter in self.v_model.named_parameters(): + parameter.requires_grad = False diff --git a/avs.code/v1m.code/model/visual/sam2/__init__.py b/avs.code/v1m.code/model/visual/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46a1cecc55b6fd02a5ce6c66d9cc8a77343156db --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("configs", version_base="1.2") diff --git a/avs.code/v1m.code/model/visual/sam2/build_sam.py b/avs.code/v1m.code/model/visual/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..69f68c2e672d35d925aeb496cac918c1ee913dde --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/build_sam.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf +''' +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) +''' + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_visual_predictor( + config_file, + ckpt_path=None, + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + # visual + hydra_overrides = [] + # "++model._target_=model.visual.sam2.organised_sam2_train.SAM2Train", + # ] + # hydra_overrides = [ + # "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + # ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + + # dynamically fall back to multi-mask if the single mask is not stable + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + # "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +# def build_sam2_video_predictor_hf(model_id, **kwargs): +# config_name, ckpt_path = _hf_download(model_id) +# return build_sam2_video_predictor( +# config_file=config_name, ckpt_path=ckpt_path, **kwargs +# ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/__init__.py b/avs.code/v1m.code/model/visual/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb6633c9c752cbefe2fc6043c81fb79bc659465 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from model.visual.sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py b/avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..11f4ccb1904f022c18f8a02b9590a66bd57bb8f1 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from model.visual.sam2.modeling.sam.transformer import RoPEAttention + +from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py b/avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e1143cc0d5774ff96108203e404f678f14b0a23 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py b/avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..52ac22674d5d4fdd9e83b6bdf034bff56d04bc0d --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..078f21cc2ec41805eebec677e6e27771335deaa4 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + audio_res_features_=audio_res_features + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features_: Optional[List[torch.Tensor]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens, audio_res_features_) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..038cebcc072ae7c0f3f83061061be3edba04d0f8 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + # we only utilise sounding as prompt. + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + return sparse_embeddings, dense_embeddings + diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..31916550afeccb66f4427cee7ec4a7a2d66913a5 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from model.visual.sam2.modeling.sam2_utils import MLP +from model.visual.sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + audio_res: [], + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + visual_res, audio_res = audio_res + + # Prepare queries + queries = point_embedding + keys = image_embedding + # Apply transformer blocks and final layernorm + for i, layer in enumerate(self.layers): + keys = keys + visual_res[i] + queries[:, 2:6] = queries[:, 2:6] + audio_res[i] + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + queries[:, 2:6] = queries[:, 2:6] + audio_res[-1] + keys = keys + visual_res[-1] + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab890394064172b8719e8a06ee0a47d995fd585 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py @@ -0,0 +1,940 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder +from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder +from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer +from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + + #### this is for Version 2.0 + # self.hidden_dim = memory_attention.d_model + #### this is for Version 2.1 + # self.hidden_dim = image_encoder.neck.d_model + self.hidden_dim = 256 # well, it is always 256 anyway. + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + ### we fix the use_mask_input_as_output_without_sam to be turned off. + self.use_mask_input_as_output_without_sam = False + + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + audio_res=None + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + ''' + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + raise NotImplementedError + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + raise NotImplementedError + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ''' + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=None, + boxes=None, + masks=None, + ) + + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + audio_res_features=audio_res + ) + ''' + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + ''' + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # comment this line temporarily. + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + + # don't train occlusion at the moment, command temporarily. + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def precompute_high_res_features(self, backbone_out): + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def forward_image(self, img_batch: torch.Tensor, pre_compute=True): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + return backbone_out if not pre_compute else self.precompute_high_res_features(backbone_out) + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # for t_pos in range(1, min(self.num_maskmem, frame_idx)): + # out = output_dict["non_cond_frame_outputs"].get(t_pos, None) + # t_pos_and_prevs.append((t_pos, out)) + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + # default false. + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + # the Following lines will never be triggered. + raise NotImplementedError + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + raise NotImplementedError + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + # it will be used in sam2.1 + # raise NotImplementedError + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py b/avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19133558dd657bbcf67f851011d45bd4999cab0a --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py b/avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..607c3ad22ba7dcb7eb74c30e1283f68c4808450e --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py @@ -0,0 +1,811 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from model.visual.sam2.modeling.sam2_base import SAM2Base +from model.visual.sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from utils.misc import concat_points + +from utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + ''' + vision_feats + torch.Size([65536, 5, 32]) + torch.Size([16384, 5, 64]) + torch.Size([4096, 5, 256]) + ''' + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + @staticmethod + def dont_prepare_prompt_inputs(backbone_out, num_frames=5, cond_frame=0): + backbone_out["gt_masks_per_frame"] = {} + backbone_out["num_frames"] = num_frames + backbone_out["use_pt_input"] = False + # always start from the first frame. + backbone_out["init_cond_frames"] = [cond_frame] + backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != cond_frame] + # backbone_out["init_cond_frames"] = [] + # backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)] + + backbone_out["mask_inputs_per_frame"] = {} + backbone_out["point_inputs_per_frame"] = {} + backbone_out["frames_to_add_correction_pt"] = [] + return backbone_out + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False): + # img_feats_already_computed = True. + """Forward video tracking on each frame (and sample correction clicks).""" + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + + av_v_feats, av_a_feats = audio_res + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = stage_id + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats] # add unsqueeze to maintain single sample. + current_vision_pos_embeds = [x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds] # add unsqueeze to maintain single sample. + current_av_v_feats = [x[img_ids] for x in av_v_feats] + current_av_a_feats = [x[img_ids] for x in av_a_feats] + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step_wo_prompt( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, # backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=None, # backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=None, # backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=None, # frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + audio_res=(current_av_v_feats, current_av_a_feats), + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + + return all_frame_outputs + + def track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + audio_res=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + + current_out, sam_outputs, high_res_features, pix_feat = self._track_step_wo_prompt( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + ''' + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + ''' + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + 666., # point_inputs, + run_mem_encoder, + # we follow SAM2 predictor, if we have multiple masks output, we only utilise the first one to perform + # the memory rope attention. + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res=None + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # False + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # current_vision_feats[-1] = current_vision_feats[-1] + self.no_mem_embed + # pix_feat = current_vision_feats[-1].permute(1, 2, 0) + # pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + + # we do not apply any prompts except audio. + ''' + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + # if prev_sam_mask_logits is not None: + # assert point_inputs is not None and mask_inputs is None + # mask_inputs = prev_sam_mask_logits + + ## comment this line, as we don't use points as prompts. + # multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + ''' + + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=True, + audio_res=audio_res + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/avs.code/v1m.code/model/visual/sam2/utils/__init__.py b/avs.code/v1m.code/model/visual/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1m.code/model/visual/sam2/utils/misc.py b/avs.code/v1m.code/model/visual/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/avs.code/v1m.code/model/visual/sam2/utils/transforms.py b/avs.code/v1m.code/model/visual/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4fa6a3e4d2e2a0dde7f87e4991daff338467c4 --- /dev/null +++ b/avs.code/v1m.code/model/visual/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from model.visual.sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/avs.code/v1m.code/tools/remap_aural_ckpt_keys.py b/avs.code/v1m.code/tools/remap_aural_ckpt_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb8d6086a854b1b0ab011542eafd99f4bf8a3bf --- /dev/null +++ b/avs.code/v1m.code/tools/remap_aural_ckpt_keys.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Remap legacy checkpoint keys: rename audio_prompter.* to the current AuralFuser layout (aural_fuser.*), +and drop duplicate weights under training_layers / finetuning_layers. + +Usage: + python tools/remap_aural_ckpt_keys.py /path/to/model.pth [--in-place] [--no-backup] + +By default writes _remapped.pth; --in-place overwrites the input (after a .bak backup unless --no-backup). +""" +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import torch + +# Matches AuralFuser ModuleList names (old train_* indices start at 1; new indices are 0-based). +_REPLACEMENTS: list[tuple[str, str]] = [ + ("train_f_patch_embed1", "patch_embeds.0"), + ("train_f_patch_embed2", "patch_embeds.1"), + ("train_f_patch_embed3", "patch_embeds.2"), + ("train_f_a_block1", "fusion_modules.0"), + ("train_f_a_block2", "fusion_modules.1"), + ("train_f_a_block3", "fusion_modules.2"), + ("train_f_block1", "f_blocks.0"), + ("train_f_block2", "f_blocks.1"), + ("train_f_block3", "f_blocks.2"), + ("train_a_block1", "a_blocks.0"), + ("train_a_block2", "a_blocks.1"), + ("train_a_block3", "a_blocks.2"), + ("train_smooth1", "smooth_convs.0"), + ("train_smooth2", "smooth_convs.1"), +] + + +def remap_state_dict(sd: dict) -> dict: + out: dict = {} + dropped = 0 + for k, v in sd.items(): + if k.startswith("audio_prompter."): + if ".training_layers." in k or ".finetuning_layers." in k: + dropped += 1 + continue + nk = k.replace("audio_prompter.", "aural_fuser.", 1) + for old, new in _REPLACEMENTS: + nk = nk.replace(old, new) + out[nk] = v + else: + out[k] = v + if dropped: + print(f"Dropped duplicate keys: {dropped} (training_layers / finetuning_layers)") + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("ckpt", type=Path, help="Input .pth (full-model state_dict)") + ap.add_argument( + "-o", "--output", type=Path, default=None, + help="Output path; default _remapped.pth", + ) + ap.add_argument("--in-place", action="store_true", help="Overwrite input file") + ap.add_argument("--no-backup", action="store_true", help="Skip .bak when using --in-place") + args = ap.parse_args() + + ckpt_path: Path = args.ckpt.resolve() + if not ckpt_path.is_file(): + raise SystemExit(f"File not found: {ckpt_path}") + + print(f"Loading: {ckpt_path}") + sd = torch.load(ckpt_path, map_location="cpu") + if not isinstance(sd, dict): + raise SystemExit("Expected top-level checkpoint to be a state_dict dict") + + n_old_ap = sum(1 for k in sd if k.startswith("audio_prompter.")) + if n_old_ap == 0: + print("Warning: no audio_prompter.* keys found; checkpoint may already be remapped.") + + new_sd = remap_state_dict(sd) + n_af = sum(1 for k in new_sd if k.startswith("aural_fuser.")) + print(f"aural_fuser key count: {n_af}") + + if args.in_place: + out = ckpt_path + if not args.no_backup: + bak = ckpt_path.with_suffix(ckpt_path.suffix + ".bak") + print(f"Backup -> {bak}") + shutil.copy2(ckpt_path, bak) + else: + out = args.output or ckpt_path.with_name(ckpt_path.stem + "_remapped.pth") + + torch.save(new_sd, out) + print(f"Saved: {out} ({len(new_sd)} tensor keys)") + + +if __name__ == "__main__": + main() diff --git a/avs.code/v1m.code/trainer/train.py b/avs.code/v1m.code/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b6dea236f976d28ad72711ee50bbefa21522334c --- /dev/null +++ b/avs.code/v1m.code/trainer/train.py @@ -0,0 +1,157 @@ +"""Training and validation loop for the AV segmentation model.""" +import numpy +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class Trainer: + """Wraps train/valid steps with optional loss, metrics, and logging.""" + + def __init__(self, hyp_param, loss, tensorboard, metrics): + self.param = hyp_param + self.loss = loss + self.tensorboard = tensorboard + self.metrics = metrics + from loss.training.contrastive_learning import ContrastLoss + self.cl = ContrastLoss(self.param) + + @torch.no_grad() + def valid(self, epoch, dataloader, model, process=''): + """Evaluate foreground IoU / F-score. `process` selects SAM multimask decoding (see branch below).""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + iou_pool = [None] * self.param.gpus + fscore_pool = [None] * self.param.gpus + + data_iter = iter(dataloader) + for batch_index in tbar: + items = next(data_iter) + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, _ = model.module(frame, spect, prompt_dicts, sam_process=True) + logits = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) + ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) + occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) + # process: '' = first multimask; iou_select = argmax IoU head; iou_occ_select = + objectness gate + if process == 'iou_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + elif process == 'iou_occ_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + logits[occ_scores.squeeze() < 0, ...] = 0. + else: + logits = logits[:, 0, ...] + + masks = logits > 0. + foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou(masks.squeeze().long(), + label.squeeze().long(), + get_entire_list=True) + + foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score(logits.squeeze(), + label.squeeze(), + get_entire_list=True) + torch.distributed.all_gather_object(iou_pool, foreground_iou_rank) + torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank) + foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum( + [i['foreground_iou'][1] for i in iou_pool]) + foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum( + [i['foreground_f-score'][1] for i in fscore_pool]) + + if self.param.local_rank <= 0: + tbar.set_description('epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(epoch, + numpy.round( + foreground_iou.cpu().numpy(), + 5), + numpy.round( + foreground_f_score, + 5))) + torch.cuda.empty_cache() + + final_iou = foreground_iou + final_fscore = foreground_f_score + if self.param.local_rank <= 0 and self.tensorboard is not None: + self.tensorboard.upload_wandb_info({"valid.f_iou/{}".format(process): final_iou, + "valid.f_f-score/{}".format(process): final_fscore}) + + def _to_float(x): + if isinstance(x, torch.Tensor): + return float(x.detach().cpu().item()) + return float(x) + + return numpy.round(_to_float(final_iou), 5), numpy.round(_to_float(final_fscore), 5) + + def train(self, epoch, dataloader, model, optimiser): + """One epoch: SAM frozen, AuralFuser + heads trained with composite loss + contrastive term.""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + + data_iter = iter(dataloader) + for batch_index in tbar: + current_index = dataloader_length * epoch + batch_index + items = next(data_iter) + + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False) + + loss_dict = self.loss(outputs, label.unsqueeze(1)) + cl_loss = self.cl(proj_feats, outputs, label) + + optimiser.zero_grad() + (loss_dict['core_loss'] + cl_loss).backward() + optimiser.step() + + current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9 + for params_lr in optimiser.param_groups: + names = params_lr.get("name", []) + if names and any("vgg" in n for n in names): + params_lr['lr'] = current_lr * 0.1 + else: + params_lr['lr'] = current_lr + + if self.param.local_rank <= 0: + logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs]) + foreground_iou = self.metrics['foreground_iou'].calculate_iou((logits > 0)[:, 0, ...].long(), + label.long()) + + self.tensorboard.upload_wandb_info({"loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(), + "lr": optimiser.param_groups[0]['lr'], + "loss_dice": loss_dict['loss_dice'], + "loss_focal": loss_dict['loss_mask'], + "loss_contras": cl_loss.item()}) + tbar.set_description('epoch {} | loss {}, f_iou {}'.format(epoch, loss_dict['core_loss'].item(), + foreground_iou.item())) + ''' + if batch_index % 200 == 0: + pred_mask = (logits > 0)[:, 0, ...].long() + n_vis = min(4, frame.shape[0], pred_mask.shape[0], label.shape[0]) + self.tensorboard.upload_wandb_image( + frame[:n_vis], pred_mask[:n_vis], label[:n_vis].long() + ) + ''' + return diff --git a/avs.code/v1m.code/utils/data_utils.py b/avs.code/v1m.code/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7a98f8ec73e6e5dafd1e395b48a98575e5afb1 --- /dev/null +++ b/avs.code/v1m.code/utils/data_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage + + +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/avs.code/v1m.code/utils/foreground_fscore.py b/avs.code/v1m.code/utils/foreground_fscore.py new file mode 100644 index 0000000000000000000000000000000000000000..ea20b84d2304ca0bd9981fd1a3c254111e3d0ac4 --- /dev/null +++ b/avs.code/v1m.code/utils/foreground_fscore.py @@ -0,0 +1,90 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundFScore(AverageMeter): + def __init__(self, rank): + self.local_rank = rank + super(ForegroundFScore, self).__init__('foreground_f-score') + + def _eval_pr(self, y_pred, y, num, cuda_flag=True): + if cuda_flag: + prec, recall = torch.zeros(num).cuda(self.local_rank), torch.zeros(num).cuda(self.local_rank) + thlist = torch.linspace(0, 1 - 1e-10, num).cuda(self.local_rank) + else: + prec, recall = torch.zeros(num), torch.zeros(num) + thlist = torch.linspace(0, 1 - 1e-10, num) + for i in range(num): + y_temp = (y_pred >= thlist[i]).float() + tp = (y_temp * y).sum() + prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) + return prec, recall + + def calculate_f_score(self, pred, gt, pr_num=255, get_entire_list=False): + + r""" + param: + pred: size [N x H x W] + gt: size [N x H x W] + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + # print('=> eval [FMeasure]..') + pred = torch.sigmoid(pred) # =======================================[important] + N = pred.size(0) + beta2 = 0.3 + avg_f, img_num = 0.0, 0 + score = torch.zeros(pr_num) + # fLog = open(os.path.join(measure_path, 'FMeasure.txt'), 'w') + # print("{} videos in this batch".format(N)) + + for img_id in range(N): + # examples with totally black GTs are out of consideration + if torch.mean(gt[img_id].float()) == 0.0: + continue + prec, recall = self._eval_pr(pred[img_id], gt[img_id], pr_num) + f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) + f_score[f_score != f_score] = 0 # for Nan + avg_f += f_score + img_num += 1 + score = avg_f / img_num + # print('score: ', score) + # fLog.close() + self.add({'foreground_f-score': score.max().item()}) + return self.get('foreground_f-score') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundFScore, self).__init__('foreground_f-score') + + diff --git a/avs.code/v1m.code/utils/foreground_iou.py b/avs.code/v1m.code/utils/foreground_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..e01eeb081eee8ebfa1fcb6618d05b9d57c02f817 --- /dev/null +++ b/avs.code/v1m.code/utils/foreground_iou.py @@ -0,0 +1,69 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundIoU(AverageMeter): + def __init__(self): + super(ForegroundIoU, self).__init__('foreground_iou') + + def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False): + r""" + param (both hard mask): + pred: size [N x H x W], type: int + target: size [N x H x W], type: int + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.' + assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.' + + N = pred.size(0) + num_pixels = pred.size(-1) * pred.size(-2) + no_obj_flag = (target.sum(2).sum(1) == 0) + + inter = (pred * target).sum(2).sum(1) + union = torch.max(pred, target).sum(2).sum(1) + + inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) + inter[no_obj_flag] = inter_no_obj[no_obj_flag] + union[no_obj_flag] = num_pixels + + iou = torch.sum(inter / (union+eps)) / N + + self.add({'foreground_iou': iou}) + return self.get('foreground_iou') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundIoU, self).__init__('foreground_iou') + diff --git a/avs.code/v1m.code/utils/iou.py b/avs.code/v1m.code/utils/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..211488b780887a8efd84361bafc6b09bfad4c345 --- /dev/null +++ b/avs.code/v1m.code/utils/iou.py @@ -0,0 +1,76 @@ +import torch +import numpy + + +class BinaryMIoU(object): + def __init__(self, ignore_index): + self.num_classes = 2 + self.ignore_index = ignore_index + self.inter, self.union = 0, 0 + self.correct, self.label = 0, 0 + self.iou = numpy.array([0 for _ in range(self.num_classes)]) + self.acc = 0.0 + + def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_): + # calculates the overall miou and acc + self.correct = self.correct + curr_correct_ + self.label = self.label + curr_label_ + self.inter = self.inter + curr_inter_ + self.union = self.union + curr_union_ + self.acc = 1.0 * self.correct / (numpy.spacing(1) + self.label) + self.iou = 1.0 * self.inter / (numpy.spacing(1) + self.union) + return numpy.round(self.iou, 4), numpy.round(self.acc, 4) + # if class_list is None: + # return numpy.round(self.iou.mean().item(), 4), \ + # numpy.round(self.acc, 4) + # else: + # return numpy.round(self.iou[class_list].mean().item(), 4), \ + # numpy.round(self.acc, 4) + + @staticmethod + def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_): + curr_acc = 1.0 * curr_correct_ / (numpy.spacing(1) + curr_label_) + curr_iou = 1.0 * curr_inter_ / (numpy.spacing(1) + curr_union_) + return curr_iou, curr_acc + + def __call__(self, x, y): + curr_correct, curr_label, curr_inter, curr_union = self.calculate_current_sample(x, y) + return (self.get_metric_results(curr_correct, curr_label, curr_inter, curr_union), + self.get_current_image_results(curr_correct, curr_label, curr_inter, curr_union)) + + def calculate_current_sample(self, output, target): + # output => BxCxHxW (logits) + # target => Bx1xHxW + target[target == self.ignore_index] = -1 + correct, labeled = self.batch_pix_accuracy(output.data, target) + inter, union = self.batch_intersection_union(output.data, target, self.num_classes) + return [numpy.round(correct, 5), numpy.round(labeled, 5), numpy.round(inter, 5), numpy.round(union, 5)] + + @ staticmethod + def batch_pix_accuracy(predict, target): + # _, predict = torch.max(output, 1) + + predict = predict.int() + 1 + target = target.int() + 1 + + pixel_labeled = (target > 0).sum() + pixel_correct = ((predict == target) * (target > 0)).sum() + assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" + return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy() + + @ staticmethod + def batch_intersection_union(predict, target, num_class): + # _, predict = torch.max(output, 1) + predict = predict + 1 + target = target + 1 + + predict = predict * (target > 0).long() + intersection = predict * (predict == target).long() + + area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1) + area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1) + area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1) + area_union = area_pred + area_lab - area_inter + assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" + return area_inter.cpu().numpy(), area_union.cpu().numpy() + diff --git a/avs.code/v1m.code/utils/misc.py b/avs.code/v1m.code/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9d66c31a4b9209b81a5b615386d29f246135c --- /dev/null +++ b/avs.code/v1m.code/utils/misc.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} + diff --git a/avs.code/v1m.code/utils/tensorboard.py b/avs.code/v1m.code/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..3519131463bb3f279eb97b2b44974a402482af42 --- /dev/null +++ b/avs.code/v1m.code/utils/tensorboard.py @@ -0,0 +1,135 @@ +import os + +import PIL +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision +import wandb + +# from utils.visualize import show_img + + +color_map = {"background": (0, 0, 0), "longitudinal": (128, 0, 0), "pothole": (0, 128, 0), + "alligator": (128, 128, 0), "transverse": (128, 0, 128), "ignore": (255, 255, 255)} + + +class Tensorboard: + def __init__(self, config): + if config.get('wandb_online', False): + key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '') + if key: + os.environ['WANDB_API_KEY'] = key + wandb.login(key=key, relogin=False) + self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'], + config=config, settings=wandb.Settings(code_dir="")) + else: + os.environ.setdefault("WANDB_MODE", "disabled") + self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'], + config=config, mode="disabled", + settings=wandb.Settings(code_dir="")) + + self._log_images = bool(config.get('wandb_online', False)) + + self.restore_transform = torchvision.transforms.Compose([ + DeNormalize(config['image_mean'], config['image_std']), + torchvision.transforms.ToPILImage()]) + + def upload_wandb_info(self, info_dict): + for i, info in enumerate(info_dict): + self.tensor_board.log({info: info_dict[info]}) + return + + + def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4): + if not self._log_images: + return + + def _batched_rgb(t): + """[N,C,H,W] or [C,H,W] float tensor on CPU.""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + if t.dim() == 3: + return t.unsqueeze(0) + if t.dim() == 4: + return t + raise ValueError("frames must be [C,H,W] or [N,C,H,W], got shape {}".format(tuple(t.shape))) + + def _batched_mask(t): + """[N,H,W] or [N,1,H,W] or [H,W].""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + while t.dim() > 3: + t = t.squeeze(1) + if t.dim() == 2: + t = t.unsqueeze(0) + if t.dim() != 3: + raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape))) + return t + + frames = _batched_rgb(frames) + pseudo_label_from_pred = _batched_mask(pseudo_label_from_pred) + pseudo_label_from_sam = _batched_mask(pseudo_label_from_sam) + + n = min(frames.shape[0], pseudo_label_from_pred.shape[0], pseudo_label_from_sam.shape[0], img_number) + frames = frames[:n] + pseudo_label_from_pred = pseudo_label_from_pred[:n] + pseudo_label_from_sam = pseudo_label_from_sam[:n] + + pseudo_label_from_sam = pseudo_label_from_sam.clone() + pseudo_label_from_pred = pseudo_label_from_pred.clone() + pseudo_label_from_sam[pseudo_label_from_sam == 255.] = 0.5 + pseudo_label_from_pred[pseudo_label_from_pred == 255.] = 0.5 + + denorm = self.restore_transform.transforms[0] + image_list = [] + label_list = [] + logits_list = [] + for i in range(n): + fi = frames[i].clone() + if fi.shape[0] == 3: + denorm(fi) + fi.clamp_(0.0, 1.0) + image_list.append(wandb.Image(fi, caption="id {}".format(str(i)))) + # wandb.Image expects torch tensors as [C, H, W] (it permutes CHW→HWC) + ms = pseudo_label_from_sam[i].squeeze() + mp = pseudo_label_from_pred[i].squeeze() + if ms.dim() == 2: + ms = ms.unsqueeze(0) + if mp.dim() == 2: + mp = mp.unsqueeze(0) + label_list.append(wandb.Image(ms, caption="id {}".format(str(i)))) + logits_list.append(wandb.Image(mp, caption="id {}".format(str(i)))) + + self.tensor_board.log({"image": image_list, "label": label_list, "logits": logits_list}) + + def de_normalize(self, image): + return [self.restore_transform(i.detach().cpu()) if (isinstance(i, torch.Tensor) and len(i.shape) == 3) + else colorize_mask(i.detach().cpu().numpy(), self.palette) + for i in image] + + def finish(self): + self.tensor_board.finish() + + +class DeNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + return tensor + + +def colorize_mask(mask, palette): + zero_pad = 256 * 3 - len(palette) + for i in range(zero_pad): + palette.append(0) + # palette[-6:-3] = [183, 65, 14] + new_mask = PIL.Image.fromarray(mask.astype(numpy.uint8)).convert('P') + new_mask.putpalette(palette) + return new_mask diff --git a/avs.code/v1m.code/utils/utils.py b/avs.code/v1m.code/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e72f27a7e2be77cea271001230195ef79f685351 --- /dev/null +++ b/avs.code/v1m.code/utils/utils.py @@ -0,0 +1,119 @@ +"""Optimizer helpers: split learning rates for AuralFuser train_* vs VGG backbone.""" +import torch +import copy +from typing import List, Dict, Set, Any + + +def manipulate_params(cfg, model): + weight_decay_norm = 0 + weight_decay_embed = 0 + defaults = {} + defaults["lr"] = cfg.lr + defaults["weight_decay"] = cfg.weight_decay + + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + params_training: List[Dict[str, Any]] = [] + params_finetuning: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + + train_prefixes = ( + "patch_embeds", + "f_blocks", + "a_blocks", + "fusion_modules", + "smooth_convs", + "train_proj_v1", + "train_proj_a1", + ) + + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + hyperparams = copy.copy(defaults) + if 'vgg' in module_name or 'vgg' in module_param_name: + hyperparams['lr'] *= 0.1 + params_finetuning.append({"params": [value], "name": [module_name], **hyperparams}) + elif ( + 'train' in module_name + or 'train' in module_param_name + or module_name.startswith(train_prefixes) + ): + if ( + "relative_position_bias_table" in module_param_name + or "pos_embed" in module_param_name + ): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, norm_module_types): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, torch.nn.Embedding): + hyperparams["weight_decay"] = 0.0 + params_training.append({"params": [value], "name": [module_name], **hyperparams}) + else: + print('undefined layer type.') + raise NotImplementedError + final_list = params_training + params_finetuning + assert len([p for p in model.parameters() if p.requires_grad]) == len(final_list), 'checksum confirmed not pass.' + return final_list + + +def group_weight(weight_group, module, weight_decay_value, lr): + group_decay = [] + group_no_decay = [] + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + for m in module.modules(): + if isinstance(m, torch.nn.Linear): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, norm_module_types): + if m.weight is not None: + group_no_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, torch.nn.Parameter): + group_no_decay.append(m) + elif isinstance(m, torch.nn.Embedding): + group_no_decay.append(m) + else: + print('undefined layer type find.') + raise NotImplementedError + + assert len(list(module.parameters())) == len(group_decay) + len( + group_no_decay) + weight_group.append(dict(params=group_decay, weight_deacy=weight_decay_value, lr=lr)) + weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) + return weight_group \ No newline at end of file diff --git a/avs.code/v1s.code/configs/__init__.py b/avs.code/v1s.code/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/avs.code/v1s.code/configs/auralfuser/architecture.yaml b/avs.code/v1s.code/configs/auralfuser/architecture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c3d06ca42335ce6bfc8064bbd5cfd44c8080a --- /dev/null +++ b/avs.code/v1s.code/configs/auralfuser/architecture.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +aural_fuser: + patch_cfgs: + - [4, 4] + - [2, 2] + - [1, 1] + f_depths: [3, 6, 12] + block_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false + one_d_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false diff --git a/avs.code/v1s.code/configs/config.py b/avs.code/v1s.code/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3d75fa447314845c0671a494110cb2bedc2b8420 --- /dev/null +++ b/avs.code/v1s.code/configs/config.py @@ -0,0 +1,84 @@ +import os +import numpy +from easydict import EasyDict + +# v1m.code package root (parent of this `configs/` directory) +_CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +# workspace root (parent of avs.code) +_WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT)) + +C = EasyDict() +config = C +cfg = C + +C.seed = 666 + +C.audio = EasyDict() +C.audio.FREEZE_AUDIO_EXTRACTOR = True +C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth') +C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False +C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False +C.train_vggish = False + +"""Root Directory Config""" +C.repo_name = 'AV' +C.root_dir = _CODE_ROOT + +"""Data Dir and Weight Dir""" +C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'AVSBench') +C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt') +C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml') + +"""Network Config""" +C.fix_bias = True +C.bn_eps = 1e-5 +C.bn_momentum = 0.1 + +"""Image Config""" +C.num_classes = 2 + +C.image_mean = numpy.array([0.485, 0.456, 0.406]) +C.image_std = numpy.array([0.229, 0.224, 0.225]) + + +C.image_size = 1024 +C.image_embedding_size = int(C.image_size / 16) +C.avsbench_size = (224, 224) + +C.scale_list = [.5, .75, 1., 1.25, 1.5] +C.ignore_index = 255 + +"""Train Config""" +C.lr = 7.5e-5 +C.batch_size = 8 +C.energy_weight = .05 + +C.lr_power = 0.9 +C.momentum = 0.9 +C.weight_decay = 0.05 + +C.num_workers = 4 + +"""Display Config""" +C.record_info_iter = 20 +C.display_iter = 50 + +"""Wandb Config""" +# Paste your W&B API key here, or set the WANDB_API_KEY environment variable instead. +C.wandb_key = "" + +# Your project [work_space] name +C.proj_name = "AVS-final-report" + +C.experiment_name = "v1s-hiera-l" + + +# False = no wandb logging (see utils/tensorboard.py) +C.wandb_online = False + +"""Save Config""" +C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', C.experiment_name) + +import pathlib + +pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True) diff --git a/avs.code/v1s.code/configs/sam2/sam2_hiera_b+.yaml b/avs.code/v1s.code/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e0f10732134149f6a994be063d11fd7591c430 --- /dev/null +++ b/avs.code/v1s.code/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,114 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False + diff --git a/avs.code/v1s.code/configs/sam2/sam2_hiera_l.yaml b/avs.code/v1s.code/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8478b3d4b8b16d8b22f6555cf7b1f00231d7fd59 --- /dev/null +++ b/avs.code/v1s.code/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v1s.code/configs/sam2/sam2_hiera_s.yaml b/avs.code/v1s.code/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/avs.code/v1s.code/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v1s.code/configs/sam2/sam2_hiera_t.yaml b/avs.code/v1s.code/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59e605b73c9777b70942538252d27a55ae8a7e1a --- /dev/null +++ b/avs.code/v1s.code/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 224 # 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: false + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/avs.code/v1s.code/configs/training/sam2_training_config.yaml b/avs.code/v1s.code/configs/training/sam2_training_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55771e7232fe88c4ea445958956eca8174c2e872 --- /dev/null +++ b/avs.code/v1s.code/configs/training/sam2_training_config.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +# Video transforms + +train_transforms: + - _target_: dataloader.sam2_dataset.transforms.ComposeAPI + transforms: + - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI + sizes: 1024 # ${scratch.resolution} + square: true + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI + - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +loss: + all: + _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 # 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + gpu_num: 4. + +# Contrastive loss (ContrastLoss); loaded in main.py / inference.py → hyp_param.contrastive_learning +contrastive_learning: + temperature: 0.10 + ignore_idx: 255 + ood_idx: 254 + max_views: 512 + proj_dim: 512 + sample_limits: 128 + total_limits: 15240 diff --git a/avs.code/v1s.code/dataloader/audio/audio_augmentation.py b/avs.code/v1s.code/dataloader/audio/audio_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..850d1577ea2bca4f8ec209edc201fb54968be928 --- /dev/null +++ b/avs.code/v1s.code/dataloader/audio/audio_augmentation.py @@ -0,0 +1,23 @@ +import numpy + + +class Augmentation(object): + """Audio pre-step used by training/inference: int16 waveform -> float in [-1, 1]. + + The previous audiomentations-based transforms were commented out and never applied; + behavior is unchanged: only scaling by 1/32768. + """ + + def __init__(self, mono=True): + self.mono = mono + + def train_aug(self, x_, sr_): + x_ = x_ / 32768.0 + return x_ + + def test_process(self, x_): + x_ = x_ / 32768.0 + return x_ + + def __call__(self, x, sr, split): + return self.train_aug(x, sr) if split == "train" else self.test_process(x) diff --git a/avs.code/v1s.code/dataloader/audio/audio_dataset.py b/avs.code/v1s.code/dataloader/audio/audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8e8b276e8545aa55ef56295719a0ad2b167106 --- /dev/null +++ b/avs.code/v1s.code/dataloader/audio/audio_dataset.py @@ -0,0 +1,38 @@ +import torch +import numpy +import os +from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples +import soundfile + + +class Audio(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split): + # temporarily set no augmentation. + self.augmentation = augmentation + self.directory_path = directory_path + self.split = split + + def load_audio_wave(self, file_index, file_index_mix): + audio_path = os.path.join(file_index, 'audio.wav') + wav_data, sample_rate = soundfile.read(audio_path, dtype='int16') + assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype + + if file_index_mix is not None: + audio_path2 = os.path.join(file_index_mix, 'audio.wav') + wav_data2, _ = soundfile.read(audio_path2, dtype='int16') + mix_lambda = numpy.random.beta(10, 10) + min_length = min(wav_data.shape[0], wav_data2.shape[0]) + wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda) + + wav_data = self.augmentation(wav_data, sample_rate, self.split) + audio_log_mel = torch.cat([waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(), + waveform_to_examples(wav_data[:, 1], sample_rate, True).detach()], dim=1) + + # for the vgg preprocess, we will need 5 seconds audio log. + if audio_log_mel.shape[0] < 5: + audio_log_mel = torch.cat([audio_log_mel, + audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)]) + return audio_log_mel + + def __len__(self): + return len(self.audio_list) diff --git a/avs.code/v1s.code/dataloader/audio/preprocess_vgg/mel_features.py b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_input.py b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..9d58e81bc70a85138980128e033f271998794605 --- /dev/null +++ b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from dataloader.audio.preprocess_vgg import mel_features +from dataloader.audio.preprocess_vgg import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_params.py b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v1s.code/dataloader/audio/preprocess_vgg/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v1s.code/dataloader/dataset.py b/avs.code/v1s.code/dataloader/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..066f3639049e8840a67c60078ae7a8d6f38c1fa2 --- /dev/null +++ b/avs.code/v1s.code/dataloader/dataset.py @@ -0,0 +1,67 @@ +"""Fused audio-visual dataset for AVSBench-style indexing.""" +import os +import random +import PIL.Image +import numpy +import torch +from dataloader.visual.visual_dataset import Visual +from dataloader.audio.audio_dataset import Audio +import pandas + + +class AV(torch.utils.data.Dataset): + """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`.""" + + def __init__(self, split, augmentation, param, root_path='', data_name='find'): + self.visual_dataset = Visual(augmentation['visual'], os.path.join(root_path, data_name), split, param.image_size, param.image_embedding_size) + self.audio_dataset = Audio(augmentation['audio'], os.path.join(root_path, data_name), split) + self.augment = augmentation + self.split = split + self.file_path = self.organise_files(self.split, root_path, data_name, csv_name_='avss_index/metadata.csv') + + def __getitem__(self, index): + mixing_prob = 0. # we omit this option. + other_index = random.randint(1, self.__len__()) - 1 if random.random() < mixing_prob and self.split == 'train' else None + frame, label, prompts = self.visual_dataset.load_data(self.file_path[index]) + if other_index is not None: + other_frame, other_label, other_prompts = self.visual_dataset.load_data(self.file_path[other_index]) + frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts) + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], self.file_path[other_index]) + else: + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], None) + + assert other_index is None if self.split == 'test' else 1, print('no mix in validation.') + + return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': self.file_path[index], + 'prompts': prompts} + + def __len__(self): + return len(self.file_path) + + @staticmethod + def organise_files(split_, root_path_, data_name_, csv_name_): + """Read rows from `csv_name_` under `root_path_` matching split and dataset label.""" + total_files = pandas.read_csv(os.path.join(root_path_, csv_name_)) + files_info = total_files[(total_files["split"] == split_) & (total_files["label"] == data_name_)]['uid'] + + files_path = [os.path.join(root_path_, data_name_, files_name) for files_name in files_info] + del total_files, files_info + return files_path + + @staticmethod + def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2): + mix_frame = frame1.clone() + mix_label = label1.clone() + bbx1, bby1, bbx2, bby2 = 0, 0, mix_label.shape[1] - 1, mix_label.shape[2] - 1 + + for i in range(0, mix_frame.shape[0]): + label_canvas_foreground = label2[i, bbx1:bbx2, bby1:bby2] > 0. + mix_frame[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground] = ( + frame2[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground]) + mix_label[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground] = ( + label2[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground]) + + return mix_frame, mix_label, prompts1 + + + diff --git a/avs.code/v1s.code/dataloader/sam2_dataset/__init__.py b/avs.code/v1s.code/dataloader/sam2_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1s.code/dataloader/sam2_dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1s.code/dataloader/sam2_dataset/transforms.py b/avs.code/v1s.code/dataloader/sam2_dataset/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7731e59ba98a5465493e3a9c4b785eb4d4420ca2 --- /dev/null +++ b/avs.code/v1s.code/dataloader/sam2_dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage +# from docutils.nodes import label +import numpy +from torchvision.transforms import InterpolationMode + +# from utils.data_utils import VideoDatapoint + + +def hflip(frames, labels, index): + # print(index) + # print(len(frames), frames[index].size, type(frames[index])) + # print(len(labels), labels[index].size, type(labels[index])) + frames[index] = F.hflip(frames[index]) + labels[index] = F.hflip(labels[index]) + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.hflip(obj.segment) + + return frames, labels + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(frames, labels, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + raise NotImplementedError + # cur_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + # size = get_size(cur_size, size, max_size) + + # old_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + if v2: + frames[index].data = Fv2.resize( + frames[index].data, size, antialias=True + ) + else: + frames[index] = F.resize(frames[index], size) + labels[index] = F.resize(labels[index], size) + # new_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + # h, w = size + # frames[index].size = (h, w) + return frames, labels + + +def pad(frames, index, padding, v2=False): + old_h, old_w = frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + frames[index].data = F.pad( + frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + frames[index].data = F.pad( + frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + frames[index].size = (h, w) + + for obj in frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return frames + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(frames)): + frames, labels = hflip(frames, labels, i) + return frames, labels + for i in range(len(frames)): + if random.random() < self.p: + frames, labels = hflip(frames, labels, i) + return frames, labels + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, frames, labels): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(frames)): + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + for i in range(len(frames)): + size = random.choice(self.sizes) + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + if self.v2: + raise NotImplementedError + # frames[img_idx] = Fv2.to_tensor(frames[img_idx]) + else: + frames[img_idx] = F.to_tensor(frames[img_idx]) + labels[img_idx] = torch.tensor(numpy.array(labels[img_idx]), dtype=torch.float) + return frames, labels + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + # if self.v2: + # img.data = Fv2.convert_image_dtype(img.data, torch.float32) + # img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + # else: + frames[img_idx] = F.normalize(frames[img_idx], mean=self.mean, std=self.std) + + return frames, labels + +''' + + + + + + + + +''' +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, frames, labels, **kwargs): + for t in self.transforms: + frames, labels = t(frames, labels, **kwargs) + return frames, labels + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img_idx in range(len(frames)): + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + for img_idx in range(len(frames)): + if random.random() < self.p: + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return frames, labels + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + label_fill_value=0., + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.fill_label = label_fill_value + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.' + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, frames, labels, **kwargs): + for _tentative in range(self.num_tentatives): + res_img, res_labels = self.transform_frames(frames, labels) + # if res is not None: + return res_img, res_labels + + # raise NotImplementedError + # if self.log_warning: + # logging.warning( + # f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + # ) + # return frames + + def transform_frames(self, frames, labels): + _, height, width = F.get_dimensions(frames[0]) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(frames): + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + frames[img_idx] = F.affine( + img, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + labels[img_idx] = F.affine( + labels[img_idx], + *affine_params, + # default: interpolation='nearest', + fill=self.fill_label, + ) + return frames, labels + + +''' +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, frames, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint +''' \ No newline at end of file diff --git a/avs.code/v1s.code/dataloader/visual/visual_augmentation.py b/avs.code/v1s.code/dataloader/visual/visual_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40aed7c8b8c08d50a46db122e1213bd4878afd --- /dev/null +++ b/avs.code/v1s.code/dataloader/visual/visual_augmentation.py @@ -0,0 +1,140 @@ +import random + +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision.transforms.functional as F +import torchvision.transforms as transforms + + +class Augmentation(object): + def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255): + self.image_size = (image_height, image_width) + # self.image_norm = (image_mean, image_std) + # self.get_crop_pos = transforms.RandomCrop(self.image_size) + self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25) + self.gaussian_blurring = transforms.GaussianBlur((3, 3)) + self.scale_list = scale_list + + self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + self.to_tensor = transforms.ToTensor() + + self.ignore_index = ignore_index + + # self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + + # if setup == "avs" or setup == "avss" or setup == "avss_binary": + # # AVS + # self.scale_list = [.5, .75, 1.] + # self.color_jitter = None + # else: + # # COCO + # # self.scale_list = [.75, 1., 1.25, 1.5, 1.75, 2.] + # self.scale_list = [0.5,0.75,1.0,1.25,1.5,1.75,2.0] + + # def normalise(self, image): + # image = image / 255.0 + # image = image - self.image_norm[0] + # image = image / self.image_norm[1] + # return image + + def resize(self, image_, label_, size=None): + h_, w_ = self.image_size if size is None else size + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + def random_crop_with_padding(self, image_, label_): + w_, h_ = image_.size + if min(h_, w_) < min(self.image_size): + res_w_ = max(self.image_size[0] - w_, 0) + res_h_ = max(self.image_size[1] - h_, 0) + image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=(numpy.array(self.image_norm[0]) * 255.).tolist()) + # image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=self.ignore_index) # if error, define the padding value. + label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index) + + pos_ = self.get_crop_pos.get_params(image_, self.image_size) + image_ = F.crop(image_, *pos_) + label_ = F.crop(label_, *pos_) + + return image_, label_ + + # @staticmethod + def random_scales(self, image_, label_): + w_, h_ = image_.size + chosen_scale = random.choice(self.scale_list) + w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale) + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + @staticmethod + def random_flip_h(image_, label_): + chosen_flip = random.random() > 0.5 + image_ = F.hflip(image_) if chosen_flip else image_ + label_ = F.hflip(label_) if chosen_flip else label_ + return image_, label_ + + def augment_entire_clip(self, x_list, y_list): + degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item()) + shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()), + torch.empty(1).uniform_(float(-20.), float(20.)).item()] + dice = random.random() + for index, single_x in enumerate(x_list): + if dice <= 0.1: + single_x = F.rgb_to_grayscale(single_x, num_output_channels=3) + + single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.]) + single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.NEAREST, fill=[0.]) + x_list[index] = single_x + y_list[index] = single_y + + return x_list, y_list + + + + + def train_aug(self, x_, y_): + x_, y_ = self.random_flip_h(x_, y_) + # # x, y = self.random_scales(x, y) + x_, y_ = self.resize(x_, y_) + + if self.color_jitter is not None and random.random() < 0.5: + x_ = self.color_jitter(x_) + if self.gaussian_blurring is not None and random.random() < 0.5: + x_ = self.gaussian_blurring(x_) + + # x, y = self.random_crop_with_padding(x, y) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + # receive pseudo labels. + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def test_process(self, x_, y_): + # x = self.to_tensor(x) + # y = torch.tensor(numpy.asarray(y)).long() + + # following AVSbench setup, we fix image size (224, 224) + x_, y_ = self.resize(x_, y_) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def __call__(self, x, y, split): + return self.train_aug(x, y) if split == "train" \ + else self.test_process(x, y) + + + + + + + + + + + diff --git a/avs.code/v1s.code/dataloader/visual/visual_dataset.py b/avs.code/v1s.code/dataloader/visual/visual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..62673359e084483b51f41dc2204cbc9dedf288be --- /dev/null +++ b/avs.code/v1s.code/dataloader/visual/visual_dataset.py @@ -0,0 +1,138 @@ +import os +import re +import PIL.Image +import matplotlib.pyplot as plt +import numpy +import torch +import pandas +import torchvision + + +class Visual(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size): + self.augment = augmentation + self.directory_path = directory_path + self.split = split + self.image_size = image_size + self.embedding_size = image_embedding_size + + def load_data(self, file_prefix): + frame_path = os.path.join(file_prefix, 'frames') + frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)] + label_path = os.path.join(file_prefix, 'labels_rgb') + label_path = [os.path.join(label_path, i) for i in os.listdir(label_path)] + + # if self.split == 'train': + # label_path += [os.path.join(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'), i) for i in + # os.listdir(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'))] + + frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0]))) + label_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.png')[0]))) + + frame = [PIL.Image.open(i) for i in frame_path] + label = [PIL.Image.open(i).convert('L') for i in label_path] + + # Keep full clip length. If labels are fewer than frames, pad missing labels + # with ignore-index masks so those positions are skipped in loss. + if len(label) < len(frame): + label += [PIL.Image.new('L', frame[0].size, color=255)] * (len(frame) - len(label)) + elif len(label) > len(frame): + label = label[:len(frame)] + + # if self.split == 'train': + # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label)) + + label_idx = torch.zeros(len(frame), dtype=torch.bool) + if len(frame) > 0: + # Keep prior behavior: only the first frame is marked as labeled. + label_idx[0] = True + # fulfill the empty page. + # we utilise pseudo-labels now. + # label_idx = torch.tensor(list([1] + [0] * (len(frame) - len(label))), dtype=torch.bool) + # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label)) + + # receive the prompts from the ground truth. + # prompts = {"point_coords": torch.nan, "point_labels": torch.nan, + # "masks": [None]*len(frame), "box_coords": [None]*len(frame)} + + prompts = {} + image_batch = [None]*len(frame) + label_batch = [None]*len(frame) + + if self.split == 'train': + # frame, label = self.augment.augment_entire_clip(frame, label) + frame, label = self.augment(frame, label) + + + for i in range(len(frame)): + if self.split == 'test': + curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split) + else: + curr_frame, curr_label = frame[i], label[i] + # if self.split == 'train' and i > 0: + # curr_label = curr_label / 255. + # curr_label[curr_label > 0.5] = 1 + # curr_label[curr_label < 0.5] = 0 + # # curr_label[(0.05 < curr_label) & (curr_label < 0.95)] = 255 + # # we temporarily make it to be hard mask; + # # curr_label = ((curr_label / 255.) - 0.5) * 2 + # # curr_label[curr_label >= 0.] = 1. + # # curr_label[curr_label < 0.] = 0. + # else: + # Keep ignore-index (255) untouched; binarize only valid foreground labels. + curr_label[(curr_label > 0.) & (curr_label < 255.)] = 1. + image_batch[i], label_batch[i] = curr_frame, curr_label + + # image_batch[i], label_batch[i] = self.augment(frame[i], label[i], split=self.split) + # note: we simply convert the code to binary mask in v1s, v1m; + # to some reason, we failed to load the label in `L' format and had to hardcoding here. + # label_batch[i][label_batch[i] > 0.] = 1. + + # prompts['box_coords'][i], prompts['masks'][i] = self.receive_other_prompts(label_batch[i]) + + # organise the prompts + # prompts.update({'masks': torch.stack(prompts['masks'], dim=0)}) + # prompts.update({'box_coords': torch.stack(prompts['box_coords'], dim=0)}) + # prompts.update({'point_labels': torch.stack(prompts['point_labels'], dim=0)}) + prompts.update({'label_index': label_idx}) + return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts + + def receive_other_prompts(self, y_): + # y_ = torch.zeros_like(y_) + if len(torch.unique(y_)) > 1: + # foreground point + points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0) + + # bbox prompt (left-top corner & right-bottom corner) + bbox_one = torch.min(points_foreground[:, 0]), torch.min(points_foreground[:, 1]) + bbox_fou = torch.max(points_foreground[:, 0]), torch.max(points_foreground[:, 1]) + bbox_coord = torch.tensor(bbox_one + bbox_fou, dtype=torch.float) + bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.squeeze().shape) + # mask prompt + low_mask = torchvision.transforms.functional.resize(y_.clone(), [self.embedding_size*4, self.embedding_size*4], + torchvision.transforms.InterpolationMode.NEAREST) + else: + # for the pure background situation. + bbox_coord = torch.zeros([4], dtype=torch.float).fill_(float('nan')) + low_mask = torch.zeros([1, self.embedding_size*4, self.embedding_size*4], dtype=torch.float).fill_(float('nan')) + + return bbox_coord, low_mask + + # we transfer the coords to SAM's input resolution (1024, 1024). + def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the sam2 model. + """ + h, w = orig_hw + coords = coords.clone().reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + coords = coords * self.image_size # unnormalize coords + return coords.reshape(4) + + + diff --git a/avs.code/v1s.code/inference.py b/avs.code/v1s.code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..daa177eee39a941151e3fc5dace4296506fdec7a --- /dev/null +++ b/avs.code/v1s.code/inference.py @@ -0,0 +1,195 @@ +"""Distributed inference on the test set; runs the same three `process` modes as training validation.""" +import os +import pathlib +import torch +import numpy +import random +import argparse +from easydict import EasyDict + +# Avoid import failure when configs.config creates saved_dir without write permission. +_real_mkdir = pathlib.Path.mkdir + + +def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False): + try: + return _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok) + except PermissionError: + pass + + +pathlib.Path.mkdir = _safe_mkdir + + +def seed_it(seed): + random.seed(seed) + os.environ["PYTHONSEED"] = str(seed) + numpy.random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + torch.manual_seed(seed) + + +class _DummyTensorboard: + """Minimal Tensorboard stub so Trainer.valid runs without wandb logging.""" + + def upload_wandb_info(self, info_dict): + pass + + def upload_wandb_image(self, *args, **kwargs): + pass + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` + from hydra import compose + from omegaconf import OmegaConf + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + train_cfg = compose(config_name='training/sam2_training_config.yaml') + OmegaConf.resolve(train_cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda() + torch.cuda.set_device(hyp_param.local_rank) + ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu") + if not isinstance(ckpt_sd, dict): + raise TypeError("Checkpoint must be a state_dict dictionary.") + # Support both formats: + # 1) full-model checkpoint (keys like `v_model.*`, `aural_fuser.*`) + # 2) train-only checkpoint for aural_fuser (keys without `aural_fuser.` prefix) + if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd.keys()): + av_model.load_state_dict(ckpt_sd, strict=True) + else: + av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=False) + av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model) + av_model.eval() + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data import DataLoader, Subset + from torch.utils.data.distributed import DistributedSampler + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + audio_augmentation = AudioAugmentation(mono=True) + + dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.inference_data_name) + + max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0 + if max_batches > 0: + n_samples = min(max_batches * hyp_param.batch_size, len(dataset)) + dataset = Subset(dataset, range(n_samples)) + + sampler = DistributedSampler(dataset, shuffle=False) + test_dataloader = DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler, + num_workers=hyp_param.num_workers) + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + + metrics = { + "foreground_iou": ForegroundIoU(), + "foreground_f-score": ForegroundFScore(hyp_param.local_rank), + } + trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics) + + # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select + runs = [ + ("", "default (logits[:,0])"), + ("iou_select", "iou_select"), + ("iou_occ_select", "iou_occ_select"), + ] + results = [] + for process, label in runs: + fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process) + results.append((label, fiou, ffscore)) + torch.cuda.empty_cache() + + if hyp_param.local_rank <= 0: + print("\n========== inference (same three process flags as training valid) ==========") + for label, fiou, ffscore in results: + print(" {:32s} f_iou={} f_f-score={}".format(label, fiou, ffscore)) + print("=======================================================\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Inference: full test set + three process modes') + + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int, + help='Batch size (match training if needed)') + + parser.add_argument('--epochs', default=80, type=int, + help="unused") + + parser.add_argument('--lr', default=1e-5, type=float, + help="unused") + + parser.add_argument('--online', action="store_true", + help='unused') + + parser.add_argument( + '--inference_ckpt', type=str, default=None, + help='Trained AuralSAM2 checkpoint (.pth state_dict: full model or aural_fuser-only). ' + 'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). ' + 'Default if unset: avs.code/training_details/.../hiera_l.pth', + ) + parser.add_argument('--inference_data_name', type=str, default='v1s', + help='AVSBench subset folder label (v1s|v1m|v2); must match training test split') + parser.add_argument('--inference_max_batches', type=int, default=0, + help='0 = full test; >0 = first N batches only (debug)') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + + _repo = pathlib.Path(__file__).resolve().parent + # Repo root: .../AuralSAM2 (parent of avs.code) + _workspace = _repo.parent.parent + args.data_root_path = str(_workspace / 'AVSBench') + args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt') + args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth') + args.saved_dir = '/tmp/v1s_infer_ckpt' + pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True) + if args.inference_ckpt is None: + args.inference_ckpt = str( + _repo.parent / 'training_details' / 'v1s' / 'hiera_l' / 'hiera_l.pth' + ) + + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9901' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v1s.code/loss/training/__init__.py b/avs.code/v1s.code/loss/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8932ea6e0916b9f16cb514f1e64704440d554e --- /dev/null +++ b/avs.code/v1s.code/loss/training/__init__.py @@ -0,0 +1,2 @@ +"""Training loss modules.""" + diff --git a/avs.code/v1s.code/loss/training/contrastive_learning.py b/avs.code/v1s.code/loss/training/contrastive_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..82cde07ec5815910688daa4622d61feaacb453c6 --- /dev/null +++ b/avs.code/v1s.code/loss/training/contrastive_learning.py @@ -0,0 +1,205 @@ +from abc import ABC + +import torch +import torch.nn as nn + + +class ContrastLoss(nn.Module, ABC): + def __init__(self, hyp_param): + super().__init__() + self.param = hyp_param + _defaults = { + "temperature": 0.10, + "ignore_idx": 255, + "ood_idx": 254, + "max_views": 512, + "proj_dim": 512, + "sample_limits": 128, + "total_limits": 15240, + } + _raw = getattr(hyp_param, "contrastive_learning", None) or {} + _cfg = {**_defaults, **_raw} + self.temperature = _cfg["temperature"] + self.ignore_idx = _cfg["ignore_idx"] + self.ood_idx = _cfg["ood_idx"] + self.max_views = _cfg["max_views"] + self.proj_dim = _cfg["proj_dim"] + self.sample_limits = _cfg["sample_limits"] + self.total_limits = _cfg["total_limits"] + + def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx): + embedding_sample_list = [] + label_list = [] + embedding_sample_list_a = [] + label_list_a = [] + class_index_list = torch.unique(masks) + + if len(class_index_list) > 1: + for class_index in class_index_list[1:]: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3) + else: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3) + + sample_limits = self.sample_limits + embeddings = embeddings.permute(1, 0) + for class_index in class_index_list: + hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()] + easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()] + + hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] + selective_num_hard = min(sample_limits, hard_indices_num) + selective_num_easy = min(sample_limits, easy_indices_num) + + if (selective_num_hard + selective_num_easy) < sample_limits * 2: + if selective_num_hard > selective_num_easy: + selective_num_hard += sample_limits * 2 - selective_num_easy + else: + selective_num_easy += sample_limits * 2 - selective_num_hard + + hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] + embedding_sample_list.append(hard_indices[hard_chosen_indices]) + label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3) + + easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] + embedding_sample_list.append(easy_indices[easy_chosen_indices]) + label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3) + return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a + + def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions): + masks = masks.flatten(start_dim=1) + predictions = predictions.flatten(start_dim=1) + visual_embeddings = visual_embeddings.flatten(start_dim=-2) + + visual_embedding_sample_list = [] + visual_label_list = [] + audio_embedding_sample_list = [] + audio_label_list = [] + + for frame_idx in range(masks.shape[0]): + current_vision_feats = visual_embeddings[frame_idx] + current_masks = masks[frame_idx] + current_predictions = predictions[frame_idx] + current_audio_feats = audio_embeddings[frame_idx] + for layer_idx in range(3): + ( + selected_vision_embeddings, + selected_vision_labels, + selected_audio_embeddings, + selected_audio_labels, + ) = self.select_class_wise_samples( + current_vision_feats[layer_idx], + current_audio_feats[layer_idx], + current_predictions, + current_masks, + 0, + ) + visual_embedding_sample_list += selected_vision_embeddings + visual_label_list += selected_vision_labels + audio_embedding_sample_list += selected_audio_embeddings + audio_label_list += selected_audio_labels + + if len(visual_embedding_sample_list) == 0: + return 0.0 + + visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze() + if visual_embedding_sample_list.dim() == 1: + visual_embedding_sample_list = visual_embedding_sample_list.unsqueeze(0) + visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1) + audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze() + if audio_embedding_sample_list.dim() == 1: + audio_embedding_sample_list = audio_embedding_sample_list.unsqueeze(0) + audio_label_list = torch.cat(audio_label_list).unsqueeze(1) + + total_limits = self.total_limits + if visual_embedding_sample_list.shape[0] > total_limits: + rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits] + visual_embedding_sample_list = visual_embedding_sample_list[:rand_index] + visual_label_list = visual_label_list[:rand_index] + loss = self.info_nce( + visual_embedding_sample_list, + visual_label_list, + audio_embedding_sample_list, + audio_label_list, + ) + return loss + + def forward(self, embeddings, output_dicts, masks): + predictions = torch.cat([i["multistep_pred_masks"] for i in output_dicts]) + predictions = torch.nn.functional.interpolate( + predictions, + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="bilinear", + align_corners=False, + ).squeeze(1) + masks = torch.nn.functional.interpolate( + masks.unsqueeze(1), + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="nearest", + ).squeeze(1) + visual_embeddings, audio_embeddings = embeddings + visual_embeddings = torch.cat( + [ + torch.cat( + [ + visual_embeddings[0][i].unsqueeze(0), + visual_embeddings[1][i].unsqueeze(0), + visual_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + audio_embeddings = torch.cat( + [ + torch.cat( + [ + audio_embeddings[0][i].unsqueeze(0), + audio_embeddings[1][i].unsqueeze(0), + audio_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + return self.forward_audio_visual( + visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions + ) + + @staticmethod + def manipulate_cover_mask(a_label, current_mask): + a_label = a_label + 1 + visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1)) + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0 + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0 + return current_mask + + def info_nce(self, anchors_, a_labels_, contras_, c_labels_): + c_labels_ = torch.cat([a_labels_, c_labels_]) + contras_ = torch.cat([anchors_, contras_]) + mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float() + + anchor_dot_contrast = torch.div( + torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), + self.temperature, + ) + + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + neg_mask = 1 - mask + + mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask) + mask = mask.fill_diagonal_(0.0) + + neg_logits = torch.exp(logits) * neg_mask + neg_logits = neg_logits.sum(1, keepdim=True) + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits + neg_logits) + + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any()) + return -mean_log_prob_pos.mean() + diff --git a/avs.code/v1s.code/loss/training/sam2_training_loss.py b/avs.code/v1s.code/loss/training/sam2_training_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ce1b02c0dbbf5d7e771b314a4a537145e28978 --- /dev/null +++ b/avs.code/v1s.code/loss/training/sam2_training_loss.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +CORE_LOSS_KEY = "core_loss" + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + inputs = inputs.sigmoid() + if loss_on_multimask: + assert inputs.dim() == 4 and targets.dim() == 4 + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + gpu_num=1, + ): + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.world_size = gpu_num + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + targets_batch.shape[1], device=targets_batch.device, dtype=torch.float + ) + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / self.world_size, min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 + + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + return reduced_loss + diff --git a/avs.code/v1s.code/main.py b/avs.code/v1s.code/main.py new file mode 100644 index 0000000000000000000000000000000000000000..bca501accb6ccfbc0e394e231899d3e2e7a40eb5 --- /dev/null +++ b/avs.code/v1s.code/main.py @@ -0,0 +1,166 @@ +"""DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss.""" +import os +import torch +import numpy +import random +import argparse +from easydict import EasyDict + + +def seed_it(seed): + """Fix RNGs and cuDNN for reproducible runs (rank offsets seed in DDP).""" + os.environ["PYTHONSEED"] = str(seed) + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + + torch.backends.cudnn.benchmark = False + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + # NCCL process group; world size = GPUs on this node + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + torch.cuda.set_device(hyp_param.local_rank) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` (initialize_config_module) + + from hydra import compose + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml) + transform_config_path = 'training/sam2_training_config.yaml' + + if 'hiera_t' in hyp_param.sam_config_path: + hyp_param.image_size = 224 + hyp_param.image_embedding_size = int(hyp_param.image_size / 16) + print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True) + + cfg = compose(config_name=transform_config_path) + OmegaConf.resolve(cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True) + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=True) + + # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone) + from utils.utils import manipulate_params + parameter_list = manipulate_params(hyp_param, av_model.module.aural_fuser) + optimiser = torch.optim.AdamW(parameter_list, betas=(0.9, 0.999)) + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data.distributed import DistributedSampler + + compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0] + + audio_augmentation = AudioAugmentation(mono=True) + train_dataset = AV(split='train', augmentation={"visual": compose_api, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name='v1s') + + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + + audio_augmentation = AudioAugmentation(mono=True) + + random_sampler = DistributedSampler(train_dataset, shuffle=True) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hyp_param.batch_size, + sampler=random_sampler, + num_workers=hyp_param.num_workers, drop_last=True) + + test_dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path, data_name='v1s') + + order_sampler = DistributedSampler(test_dataset, shuffle=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=order_sampler, + num_workers=hyp_param.num_workers) + + + criterion = instantiate(cfg.loss, _recursive_=True)['all'] + from utils.tensorboard import Tensorboard + tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + metrics = {"foreground_iou": ForegroundIoU(), "foreground_f-score": ForegroundFScore(0 if hyp_param.local_rank <= 0 else hyp_param.local_rank)} + + trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics) + + + curr_best = 0. # checkpoint when IoU (iou_select mode) improves + + for epoch in range(hyp_param.epochs): + av_model.train() + av_model.module.freeze_sam_parameters() + random_sampler.set_epoch(epoch) + trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser) + + torch.distributed.barrier() + torch.cuda.empty_cache() + + av_model.eval() + # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate + curr_results1, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index') + curr_results, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select') + curr_results3, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select') + if hyp_param.local_rank <= 0 and curr_results > curr_best: + curr_best = curr_results + torch.save(av_model.module.aural_fuser.state_dict(), os.path.join(hyp_param.saved_dir, str(curr_results) + ".pth")) + torch.distributed.barrier() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Training') + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int) + + parser.add_argument('--epochs', default=80, type=int, + help="total epochs that used for the training") + + parser.add_argument('--lr', default=1e-4, type=float, + help='Default HEAD Learning rate is same as others, ' + '*Note: in ddp training, lr will automatically times by n_gpu') + + parser.add_argument('--online', action="store_true", + help='switch on for visualization; switch off for debug') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9902' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v1s.code/model/audio/torchvggish/mel_features.py b/avs.code/v1s.code/model/audio/torchvggish/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v1s.code/model/audio/torchvggish/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v1s.code/model/audio/torchvggish/vggish.py b/avs.code/v1s.code/model/audio/torchvggish/vggish.py new file mode 100644 index 0000000000000000000000000000000000000000..f01c22867c713bfd8713eee5665120b92602761d --- /dev/null +++ b/avs.code/v1s.code/model/audio/torchvggish/vggish.py @@ -0,0 +1,193 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import hub + +from . import vggish_input, vggish_params + + +class VGG(nn.Module): + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.embeddings = nn.Sequential( + nn.Linear(512 * 4 * 6, 4096), + nn.ReLU(True), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Linear(4096, 128), + nn.ReLU(True)) + + def forward(self, x): + x = self.features(x) + + # Transpose the output from features to + # remain compatible with vggish embeddings + x = torch.transpose(x, 1, 3) + x = torch.transpose(x, 1, 2) + x = x.contiguous() + x = x.view(x.size(0), -1) + + return self.embeddings(x) + + +class Postprocessor(nn.Module): + """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a + numpy array in order to preserve the gradient. + + "The initial release of AudioSet included 128-D VGGish embeddings for each + segment of AudioSet. These released embeddings were produced by applying + a PCA transformation (technically, a whitening transform is included as well) + and 8-bit quantization to the raw embedding output from VGGish, in order to + stay compatible with the YouTube-8M project which provides visual embeddings + in the same format for a large set of YouTube videos. This class implements + the same PCA (with whitening) and quantization transformations." + """ + + def __init__(self): + """Constructs a postprocessor.""" + super(Postprocessor, self).__init__() + # Create empty matrix, for user's state_dict to load + self.pca_eigen_vectors = torch.empty( + (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), + dtype=torch.float, + ) + self.pca_means = torch.empty( + (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float + ) + + self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) + self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) + + def postprocess(self, embeddings_batch): + """Applies tensor postprocessing to a batch of embeddings. + + Args: + embeddings_batch: An tensor of shape [batch_size, embedding_size] + containing output from the embedding layer of VGGish. + + Returns: + A tensor of the same shape as the input, containing the PCA-transformed, + quantized, and clipped version of the input. + """ + assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( + embeddings_batch.shape, + ) + assert ( + embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE + ), "Bad batch shape: %r" % (embeddings_batch.shape,) + + # Apply PCA. + # - Embeddings come in as [batch_size, embedding_size]. + # - Transpose to [embedding_size, batch_size]. + # - Subtract pca_means column vector from each column. + # - Premultiply by PCA matrix of shape [output_dims, input_dims] + # where both are are equal to embedding_size in our case. + # - Transpose result back to [batch_size, embedding_size]. + pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() + + # Quantize by: + # - clipping to [min, max] range + clipped_embeddings = torch.clamp( + pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL + ) + # - convert to 8-bit in range [0.0, 255.0] + quantized_embeddings = torch.round( + (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) + * ( + 255.0 + / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) + ) + ) + return torch.squeeze(quantized_embeddings) + + def forward(self, x): + return self.postprocess(x) + + +def make_layers(): + layers = [] + in_channels = 1 + for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +def _vgg(): + return VGG(make_layers()) + + +# def _spectrogram(): +# config = dict( +# sr=16000, +# n_fft=400, +# n_mels=64, +# hop_length=160, +# window="hann", +# center=False, +# pad_mode="reflect", +# htk=True, +# fmin=125, +# fmax=7500, +# output_format='Magnitude', +# # device=device, +# ) +# return Spectrogram.MelSpectrogram(**config) + + +class VGGish(VGG): + def __init__(self, cfg, device=None): + super().__init__(make_layers()) + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH) + super().load_state_dict(state_dict) + print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}') + + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("device: ", device) + self.device = device + + self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL + self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA + if self.postprocess: + self.pproc = Postprocessor() + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH) + # TODO: Convert the state_dict to torch + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float + ) + state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float + ) + self.pproc.load_state_dict(state_dict) + self.to(self.device) + + def forward(self, x): + if self.preprocess: + print(">>> pre processing...") + x = self._preprocess(x) + x = x.to(self.device) + x = VGG.forward(self, x) + if self.postprocess: + print(">>> post processing...") + x = self._postprocess(x) + return x + + def _preprocess(self, x): + # if isinstance(x, np.ndarray): + # x = vggish_input.waveform_to_examples(x, fs) + if isinstance(x, str): + x = vggish_input.wavfile_to_examples(x) + else: + raise AttributeError + return x + + def _postprocess(self, x): + return self.pproc(x) diff --git a/avs.code/v1s.code/model/audio/torchvggish/vggish_input.py b/avs.code/v1s.code/model/audio/torchvggish/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..ede228b1fb630180f1f49244355d373fb3300f03 --- /dev/null +++ b/avs.code/v1s.code/model/audio/torchvggish/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from . import mel_features +from . import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v1s.code/model/audio/torchvggish/vggish_params.py b/avs.code/v1s.code/model/audio/torchvggish/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v1s.code/model/audio/torchvggish/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v1s.code/model/aural_fuser.py b/avs.code/v1s.code/model/aural_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..924810bfcf8bee5e285cab7d54e477daf254b85a --- /dev/null +++ b/avs.code/v1s.code/model/aural_fuser.py @@ -0,0 +1,567 @@ +import math + +import torch +import torch.nn as nn +from model.audio.torchvggish import vggish +from timm.models.layers import DropPath, trunc_normal_ + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d): + super().__init__() + self.proj = nn.Sequential( + conv_layer(dim_in, proj_dim, kernel_size=1), + norm_act(proj_dim), + conv_layer(proj_dim, proj_dim, kernel_size=1), + ) + + def forward(self, x): + return torch.nn.functional.normalize(self.proj(x), p=2, dim=1) + +class AuralFuser(torch.nn.Module): + """Fuses VGGish audio with SAM2 FPN maps via patch embeds, fusion blocks, and projection heads.""" + + def __init__(self, hyp_param): + self.hyp_param = hyp_param + super().__init__() + self.vgg = vggish.VGGish(self.hyp_param.audio) + if not getattr(self.hyp_param, "train_vggish", False): + for p in self.vgg.parameters(): + p.requires_grad = False + + self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, + temperature=10000) + + # Populated in main.py / inference.py via Hydra compose('auralfuser/architecture.yaml') → hyp_param.aural_fuser + if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None: + raise ValueError( + "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser." + ) + arch_cfg = self.hyp_param.aural_fuser + + _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]] + _f_depths = arch_cfg["f_depths"] + _block_kw = dict(arch_cfg["block_kw"]) + _block_kw["norm_layer"] = nn.LayerNorm + _one_d_kw = dict(arch_cfg["one_d_kw"]) + _one_d_kw["norm_layer"] = nn.LayerNorm + self.patch_embeds = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs + ) + + self.f_blocks = nn.ModuleList( + nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths + ) + + self.a_blocks = nn.ModuleList( + nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3) + ) + + self.fusion_modules = nn.ModuleList( + AudioVisualFusionModule(in_channels=256, mode='dot') for _ in range(3) + ) + self.smooth_convs = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2) + ) + + self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128) + + self.train_proj_a1 = ProjectionHead(dim_in=256, norm_act=nn.BatchNorm1d, conv_layer=nn.Conv1d, proj_dim=128) + + @staticmethod + def positionalencoding1d(d_model, length): + if d_model % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model)) + pe = torch.zeros(length, d_model) + position = torch.arange(0, length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * + -(math.log(10000.0) / d_model))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe + + def forward(self, feature_dicts, spect=None): + image_embed_shape = [self.hyp_param.image_embedding_size] * 2 + H, W = image_embed_shape[0], image_embed_shape[1] + d = torch.cat( + [ + self.vgg(spect[:, 0, ...].unsqueeze(1)), + self.vgg(spect[:, 1, ...].unsqueeze(1)), + ], + dim=-1, + ) + length = d.shape[-1] + fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device) + fpn = list(feature_dicts["backbone_fpn"]) + patch_embeds = list(self.patch_embeds) + f_blocks = list(self.f_blocks) + a_blocks = list(self.a_blocks) + tpavi = list(self.fusion_modules) + smooths = [None, self.smooth_convs[0], self.smooth_convs[1]] + + feats = [None, None, None] + d_outputs = [] + + for i in range(3): + x = fpn[i] + x = patch_embeds[i](x) + x_pos = self.position_encoding_func(x) + x = x.flatten(2).permute(0, 2, 1) + x_pos = x_pos.flatten(2).permute(0, 2, 1) + + if i == 0: + x = x + x_pos + d = d + fix_audio_pos + else: + x = x + feats[i - 1] + x = smooths[i]( + x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W) + ).flatten(2).permute(0, 2, 1) + x = x + x_pos + d = d + fix_audio_pos + + for blks in f_blocks[i]: + x = blks(x, H, W, x_pos) + for blks in a_blocks[i]: + d = blks(d, fix_audio_pos) + + x = x + x_pos + d = d + fix_audio_pos + x, d_out, _, _ = tpavi[i](x, H, W, x_pos, d, length) + d = d_out + feats[i] = x + d_outputs.append(d_out) + + a, b, c = feats + d1, d2, d3 = d_outputs + + feature_residual = [a, b, c] + audio_out = [d1, d2, d3] + + proj_feature_out = [ + [ + self.train_proj_v1(a.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(b.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(c.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + ], + [ + self.train_proj_a1(d1.unsqueeze(-1)), + self.train_proj_a1(d2.unsqueeze(-1)), + self.train_proj_a1(d3.unsqueeze(-1)), + ], + ] + + return feature_residual, audio_out, proj_feature_out + + +class AudioVisualFusionModule(nn.Module): + def __init__(self, in_channels, inter_channels=None, mode='dot', + dimension=3): + super().__init__() + assert mode == 'dot' + self.mode = mode + self.dimension = dimension + + self.in_channels = in_channels + self.inter_channels = in_channels // 2 + + self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1) + self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1) + + self.norm_layer = nn.LayerNorm(in_channels) + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.W_z = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + bn(self.in_channels) + ) + nn.init.constant_(self.W_z[1].weight, 0) + nn.init.constant_(self.W_z[1].bias, 0) + + self.W_z2 = nn.Sequential( + nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + nn.BatchNorm1d(self.in_channels) + ) + nn.init.constant_(self.W_z2[1].weight, 0) + nn.init.constant_(self.W_z2[1].bias, 0) + self.norm_layer2 = nn.LayerNorm(self.in_channels) + + self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + + def forward(self, frame, H_x, W_x, tmp1, audio, tmp2): + frame = frame.permute(0, 2, 1) + frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x) + frame = frame.unsqueeze(2) + audio = self.align_channel(audio.unsqueeze(-1)) + + batch_size, _ = frame.size(0), frame.size(1) + q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels) + k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels) + v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels) + q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels) + k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels) + v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels) + f = torch.matmul(q_frame, k_audio.mT) + f_normalise = f / f.size(1) + + frame_attn = torch.matmul(f_normalise, v_audio) + + frame_attn = frame_attn.permute(0, 2, 1).contiguous() + frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:]) + frame_attn = self.W_z(frame_attn) + frame = frame_attn + frame + + frame = frame.permute(0, 2, 3, 4, 1) + frame = self.norm_layer(frame) + frame = frame.permute(0, 4, 1, 2, 3) + frame = frame.squeeze().flatten(start_dim=2).permute(0, 2, 1) + + a = torch.matmul(q_audio, k_frame.mT) + a_normalise = a / a.size(-1) + + audio_attn = torch.matmul(a_normalise, v_frame) + audio_attn = audio_attn.permute(0, 2, 1).contiguous() + + audio_attn = audio_attn.view(batch_size, self.inter_channels).unsqueeze(-1) + audio_attn = self.W_z2(audio_attn) + + audio = audio_attn + audio + + audio = self.norm_layer2(audio.squeeze()).squeeze() + + return frame, audio, frame_attn, audio_attn + + +class OneDBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OneDAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = OneDMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, _pos): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class OneDAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = x.unsqueeze(0) + + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.squeeze() + return x + + +class OneDMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, _pos): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + if not self.linear: + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x diff --git a/avs.code/v1s.code/model/mymodel.py b/avs.code/v1s.code/model/mymodel.py new file mode 100644 index 0000000000000000000000000000000000000000..35194cd584a4786f713447829592b15c7a366095 --- /dev/null +++ b/avs.code/v1s.code/model/mymodel.py @@ -0,0 +1,102 @@ +import logging + +from typing import List, Optional, Tuple, Union + +import numpy +import numpy as np +import torch +from PIL.Image import Image + +from model.visual.sam2.modeling.sam2_base import SAM2Base + +from model.visual.sam2.modeling.backbones.hieradet import Hiera +from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck +from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + +from model.visual.sam2.modeling.memory_attention import MemoryAttention +from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer +from model.visual.sam2.modeling.sam.transformer import RoPEAttention +from model.visual.sam2.modeling.memory_encoder import MemoryEncoder +from model.visual.sam2.modeling.memory_encoder import MaskDownSampler +from model.visual.sam2.modeling.memory_encoder import Fuser +from model.visual.sam2.modeling.memory_encoder import CXBlock + +from model.visual.sam2.utils.transforms import SAM2Transforms +from model.visual.sam2.modeling.backbones.hieradet import do_pool +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + + +class AVmodel(torch.nn.Module): + """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head.""" + + def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, ): + super().__init__() + self.param = param + self.mask_threshold = mask_threshold + self._bb_feat_sizes = [(int(self.param.image_size / 4), int(self.param.image_size / 4)), + (int(self.param.image_size / 8), int(self.param.image_size / 8)), + (int(self.param.image_size / 16), int(self.param.image_size / 16))] + + from model.visual.sam2.build_sam import build_sam2_visual_predictor + self.v_model = build_sam2_visual_predictor(self.param.sam_config_path, self.param.backbone_weight, + apply_postprocessing=True, mode='train') + self._transforms = SAM2Transforms( + resolution=self.v_model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + from model.aural_fuser import AuralFuser + self.aural_fuser = AuralFuser(hyp_param=self.param) + + + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.v_model.num_feature_levels:] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.v_model.num_feature_levels:] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def forward_frame(self, frame_): + frame = torch.nn.functional.interpolate(frame_, (self.param.image_size, self.param.image_size), + antialias=True, align_corners=False, mode='bilinear') + return self.v_model.image_encoder(frame) + + def forward(self, frames, spect, prompts, sam_process=False): + """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path.""" + backbone_feats = self.v_model.forward_image(frames, pre_compute=False) + audio_residual_feats = self.aural_fuser(backbone_feats, spect) + visual_resfeats, audio_resfeats, proj_feats = audio_residual_feats + + map_res = visual_resfeats[::-1] + vec_res = audio_resfeats[::-1] + + av_feats = (map_res, vec_res) + backbone_feats = self.v_model.precompute_high_res_features(backbone_feats) + backbone_feats = self.v_model.dont_prepare_prompt_inputs(backbone_feats, num_frames=frames.shape[0], + cond_frame=int(frames.shape[0]/2) if self.training else 0) + outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats) + return outputs, proj_feats + + @property + def device(self) -> torch.device: + return self.v_model.device + + def freeze_sam_parameters(self): + self.v_model.eval() + for name, parameter in self.v_model.named_parameters(): + parameter.requires_grad = False diff --git a/avs.code/v1s.code/model/visual/sam2/__init__.py b/avs.code/v1s.code/model/visual/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46a1cecc55b6fd02a5ce6c66d9cc8a77343156db --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("configs", version_base="1.2") diff --git a/avs.code/v1s.code/model/visual/sam2/build_sam.py b/avs.code/v1s.code/model/visual/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..69f68c2e672d35d925aeb496cac918c1ee913dde --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/build_sam.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf +''' +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) +''' + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_visual_predictor( + config_file, + ckpt_path=None, + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + # visual + hydra_overrides = [] + # "++model._target_=model.visual.sam2.organised_sam2_train.SAM2Train", + # ] + # hydra_overrides = [ + # "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + # ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + + # dynamically fall back to multi-mask if the single mask is not stable + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + # "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +# def build_sam2_video_predictor_hf(model_id, **kwargs): +# config_name, ckpt_path = _hf_download(model_id) +# return build_sam2_video_predictor( +# config_file=config_name, ckpt_path=ckpt_path, **kwargs +# ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/__init__.py b/avs.code/v1s.code/model/visual/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/backbones/__init__.py b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/backbones/hieradet.py b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb6633c9c752cbefe2fc6043c81fb79bc659465 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from model.visual.sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/backbones/image_encoder.py b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/backbones/utils.py b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/memory_attention.py b/avs.code/v1s.code/model/visual/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..11f4ccb1904f022c18f8a02b9590a66bd57bb8f1 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from model.visual.sam2.modeling.sam.transformer import RoPEAttention + +from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/memory_encoder.py b/avs.code/v1s.code/model/visual/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e1143cc0d5774ff96108203e404f678f14b0a23 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/position_encoding.py b/avs.code/v1s.code/model/visual/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..52ac22674d5d4fdd9e83b6bdf034bff56d04bc0d --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam/__init__.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam/mask_decoder.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..078f21cc2ec41805eebec677e6e27771335deaa4 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + audio_res_features_=audio_res_features + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features_: Optional[List[torch.Tensor]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens, audio_res_features_) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam/prompt_encoder.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..038cebcc072ae7c0f3f83061061be3edba04d0f8 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + # we only utilise sounding as prompt. + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + return sparse_embeddings, dense_embeddings + diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam/transformer.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..31916550afeccb66f4427cee7ec4a7a2d66913a5 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam/transformer.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from model.visual.sam2.modeling.sam2_utils import MLP +from model.visual.sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + audio_res: [], + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + visual_res, audio_res = audio_res + + # Prepare queries + queries = point_embedding + keys = image_embedding + # Apply transformer blocks and final layernorm + for i, layer in enumerate(self.layers): + keys = keys + visual_res[i] + queries[:, 2:6] = queries[:, 2:6] + audio_res[i] + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + queries[:, 2:6] = queries[:, 2:6] + audio_res[-1] + keys = keys + visual_res[-1] + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam2_base.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab890394064172b8719e8a06ee0a47d995fd585 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam2_base.py @@ -0,0 +1,940 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder +from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder +from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer +from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + + #### this is for Version 2.0 + # self.hidden_dim = memory_attention.d_model + #### this is for Version 2.1 + # self.hidden_dim = image_encoder.neck.d_model + self.hidden_dim = 256 # well, it is always 256 anyway. + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + ### we fix the use_mask_input_as_output_without_sam to be turned off. + self.use_mask_input_as_output_without_sam = False + + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + audio_res=None + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + ''' + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + raise NotImplementedError + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + raise NotImplementedError + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ''' + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=None, + boxes=None, + masks=None, + ) + + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + audio_res_features=audio_res + ) + ''' + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + ''' + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # comment this line temporarily. + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + + # don't train occlusion at the moment, command temporarily. + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def precompute_high_res_features(self, backbone_out): + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def forward_image(self, img_batch: torch.Tensor, pre_compute=True): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + return backbone_out if not pre_compute else self.precompute_high_res_features(backbone_out) + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # for t_pos in range(1, min(self.num_maskmem, frame_idx)): + # out = output_dict["non_cond_frame_outputs"].get(t_pos, None) + # t_pos_and_prevs.append((t_pos, out)) + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + # default false. + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + # the Following lines will never be triggered. + raise NotImplementedError + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + raise NotImplementedError + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + # it will be used in sam2.1 + # raise NotImplementedError + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/avs.code/v1s.code/model/visual/sam2/modeling/sam2_utils.py b/avs.code/v1s.code/model/visual/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19133558dd657bbcf67f851011d45bd4999cab0a --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/avs.code/v1s.code/model/visual/sam2/organised_sam2_train.py b/avs.code/v1s.code/model/visual/sam2/organised_sam2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..607c3ad22ba7dcb7eb74c30e1283f68c4808450e --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/organised_sam2_train.py @@ -0,0 +1,811 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from model.visual.sam2.modeling.sam2_base import SAM2Base +from model.visual.sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from utils.misc import concat_points + +from utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + ''' + vision_feats + torch.Size([65536, 5, 32]) + torch.Size([16384, 5, 64]) + torch.Size([4096, 5, 256]) + ''' + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + @staticmethod + def dont_prepare_prompt_inputs(backbone_out, num_frames=5, cond_frame=0): + backbone_out["gt_masks_per_frame"] = {} + backbone_out["num_frames"] = num_frames + backbone_out["use_pt_input"] = False + # always start from the first frame. + backbone_out["init_cond_frames"] = [cond_frame] + backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != cond_frame] + # backbone_out["init_cond_frames"] = [] + # backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)] + + backbone_out["mask_inputs_per_frame"] = {} + backbone_out["point_inputs_per_frame"] = {} + backbone_out["frames_to_add_correction_pt"] = [] + return backbone_out + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False): + # img_feats_already_computed = True. + """Forward video tracking on each frame (and sample correction clicks).""" + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + + av_v_feats, av_a_feats = audio_res + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = stage_id + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats] # add unsqueeze to maintain single sample. + current_vision_pos_embeds = [x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds] # add unsqueeze to maintain single sample. + current_av_v_feats = [x[img_ids] for x in av_v_feats] + current_av_a_feats = [x[img_ids] for x in av_a_feats] + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step_wo_prompt( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, # backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=None, # backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=None, # backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=None, # frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + audio_res=(current_av_v_feats, current_av_a_feats), + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + + return all_frame_outputs + + def track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + audio_res=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + + current_out, sam_outputs, high_res_features, pix_feat = self._track_step_wo_prompt( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + ''' + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + ''' + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + 666., # point_inputs, + run_mem_encoder, + # we follow SAM2 predictor, if we have multiple masks output, we only utilise the first one to perform + # the memory rope attention. + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res=None + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # False + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # current_vision_feats[-1] = current_vision_feats[-1] + self.no_mem_embed + # pix_feat = current_vision_feats[-1].permute(1, 2, 0) + # pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + + # we do not apply any prompts except audio. + ''' + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + # if prev_sam_mask_logits is not None: + # assert point_inputs is not None and mask_inputs is None + # mask_inputs = prev_sam_mask_logits + + ## comment this line, as we don't use points as prompts. + # multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + ''' + + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=True, + audio_res=audio_res + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/avs.code/v1s.code/model/visual/sam2/utils/__init__.py b/avs.code/v1s.code/model/visual/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v1s.code/model/visual/sam2/utils/misc.py b/avs.code/v1s.code/model/visual/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/avs.code/v1s.code/model/visual/sam2/utils/transforms.py b/avs.code/v1s.code/model/visual/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4fa6a3e4d2e2a0dde7f87e4991daff338467c4 --- /dev/null +++ b/avs.code/v1s.code/model/visual/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from model.visual.sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/avs.code/v1s.code/tools/remap_aural_ckpt_keys.py b/avs.code/v1s.code/tools/remap_aural_ckpt_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb8d6086a854b1b0ab011542eafd99f4bf8a3bf --- /dev/null +++ b/avs.code/v1s.code/tools/remap_aural_ckpt_keys.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Remap legacy checkpoint keys: rename audio_prompter.* to the current AuralFuser layout (aural_fuser.*), +and drop duplicate weights under training_layers / finetuning_layers. + +Usage: + python tools/remap_aural_ckpt_keys.py /path/to/model.pth [--in-place] [--no-backup] + +By default writes _remapped.pth; --in-place overwrites the input (after a .bak backup unless --no-backup). +""" +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import torch + +# Matches AuralFuser ModuleList names (old train_* indices start at 1; new indices are 0-based). +_REPLACEMENTS: list[tuple[str, str]] = [ + ("train_f_patch_embed1", "patch_embeds.0"), + ("train_f_patch_embed2", "patch_embeds.1"), + ("train_f_patch_embed3", "patch_embeds.2"), + ("train_f_a_block1", "fusion_modules.0"), + ("train_f_a_block2", "fusion_modules.1"), + ("train_f_a_block3", "fusion_modules.2"), + ("train_f_block1", "f_blocks.0"), + ("train_f_block2", "f_blocks.1"), + ("train_f_block3", "f_blocks.2"), + ("train_a_block1", "a_blocks.0"), + ("train_a_block2", "a_blocks.1"), + ("train_a_block3", "a_blocks.2"), + ("train_smooth1", "smooth_convs.0"), + ("train_smooth2", "smooth_convs.1"), +] + + +def remap_state_dict(sd: dict) -> dict: + out: dict = {} + dropped = 0 + for k, v in sd.items(): + if k.startswith("audio_prompter."): + if ".training_layers." in k or ".finetuning_layers." in k: + dropped += 1 + continue + nk = k.replace("audio_prompter.", "aural_fuser.", 1) + for old, new in _REPLACEMENTS: + nk = nk.replace(old, new) + out[nk] = v + else: + out[k] = v + if dropped: + print(f"Dropped duplicate keys: {dropped} (training_layers / finetuning_layers)") + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("ckpt", type=Path, help="Input .pth (full-model state_dict)") + ap.add_argument( + "-o", "--output", type=Path, default=None, + help="Output path; default _remapped.pth", + ) + ap.add_argument("--in-place", action="store_true", help="Overwrite input file") + ap.add_argument("--no-backup", action="store_true", help="Skip .bak when using --in-place") + args = ap.parse_args() + + ckpt_path: Path = args.ckpt.resolve() + if not ckpt_path.is_file(): + raise SystemExit(f"File not found: {ckpt_path}") + + print(f"Loading: {ckpt_path}") + sd = torch.load(ckpt_path, map_location="cpu") + if not isinstance(sd, dict): + raise SystemExit("Expected top-level checkpoint to be a state_dict dict") + + n_old_ap = sum(1 for k in sd if k.startswith("audio_prompter.")) + if n_old_ap == 0: + print("Warning: no audio_prompter.* keys found; checkpoint may already be remapped.") + + new_sd = remap_state_dict(sd) + n_af = sum(1 for k in new_sd if k.startswith("aural_fuser.")) + print(f"aural_fuser key count: {n_af}") + + if args.in_place: + out = ckpt_path + if not args.no_backup: + bak = ckpt_path.with_suffix(ckpt_path.suffix + ".bak") + print(f"Backup -> {bak}") + shutil.copy2(ckpt_path, bak) + else: + out = args.output or ckpt_path.with_name(ckpt_path.stem + "_remapped.pth") + + torch.save(new_sd, out) + print(f"Saved: {out} ({len(new_sd)} tensor keys)") + + +if __name__ == "__main__": + main() diff --git a/avs.code/v1s.code/trainer/train.py b/avs.code/v1s.code/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6622eb2e1dd1a57e725121cd259b8c138329df05 --- /dev/null +++ b/avs.code/v1s.code/trainer/train.py @@ -0,0 +1,179 @@ +"""Training and validation loop for the AV segmentation model.""" +import numpy +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class Trainer: + """Wraps train/valid steps with optional loss, metrics, and logging.""" + + def __init__(self, hyp_param, loss, tensorboard, metrics): + self.param = hyp_param + self.loss = loss + self.tensorboard = tensorboard + self.metrics = metrics + from loss.training.contrastive_learning import ContrastLoss + self.cl = ContrastLoss(self.param) + + @torch.no_grad() + def valid(self, epoch, dataloader, model, process=''): + """Evaluate foreground IoU / F-score. `process` selects SAM multimask decoding (see branch below).""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + iou_pool = [None] * self.param.gpus + fscore_pool = [None] * self.param.gpus + + data_iter = iter(dataloader) + for batch_index in tbar: + items = next(data_iter) + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, _ = model.module(frame, spect, prompt_dicts, sam_process=True) + logits = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) + ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) + occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) + # process: '' = first multimask; iou_select = argmax IoU head; iou_occ_select = + objectness gate + if process == 'iou_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + elif process == 'iou_occ_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + logits[occ_scores.squeeze() < 0, ...] = 0. + else: + logits = logits[:, 0, ...] + + masks = logits > 0. + foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou(masks.squeeze().long(), + label.squeeze().long(), + get_entire_list=True) + + foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score(logits.squeeze(), + label.squeeze(), + get_entire_list=True) + torch.distributed.all_gather_object(iou_pool, foreground_iou_rank) + torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank) + foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum( + [i['foreground_iou'][1] for i in iou_pool]) + foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum( + [i['foreground_f-score'][1] for i in fscore_pool]) + + if self.param.local_rank <= 0: + tbar.set_description('epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(epoch, + numpy.round( + foreground_iou.cpu().numpy(), + 5), + numpy.round( + foreground_f_score, + 5))) + torch.cuda.empty_cache() + + final_iou = foreground_iou + final_fscore = foreground_f_score + if self.param.local_rank <= 0 and self.tensorboard is not None: + self.tensorboard.upload_wandb_info({"valid.f_iou/{}".format(process): final_iou, + "valid.f_f-score/{}".format(process): final_fscore}) + + def _to_float(x): + if isinstance(x, torch.Tensor): + return float(x.detach().cpu().item()) + return float(x) + + return numpy.round(_to_float(final_iou), 5), numpy.round(_to_float(final_fscore), 5) + + def train(self, epoch, dataloader, model, optimiser): + """One epoch: SAM frozen, AuralFuser + heads trained with composite loss + contrastive term.""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + + data_iter = iter(dataloader) + for batch_index in tbar: + current_index = dataloader_length * epoch + batch_index + items = next(data_iter) + + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False) + + # Use label_index to pick one supervised frame (legacy v1s behavior). + label_index = prompt_dicts.get('label_index', None) + if label_index is not None: + if not isinstance(label_index, torch.Tensor): + label_index = torch.as_tensor(label_index) + label_index = label_index.flatten().to(device=label.device, dtype=torch.bool) + if label_index.any(): + frame_idx = int(torch.where(label_index)[0][0].item()) + else: + frame_idx = 0 + else: + frame_idx = 0 + + outputs_sel = outputs[frame_idx:frame_idx + 1] + label_sel = label[frame_idx:frame_idx + 1] + vision_feats, audio_feats = proj_feats + # Keep the same nested-list structure as legacy v1s code. + proj_feats_sel = ( + [[vision_feats[i][frame_idx]] for i in range(3)], + [[audio_feats[i][frame_idx]] for i in range(3)], + ) + + loss_dict = self.loss(outputs_sel, label_sel.unsqueeze(1)) + cl_loss = self.cl(proj_feats_sel, outputs_sel, label_sel) + + optimiser.zero_grad() + (loss_dict['core_loss'] + cl_loss).backward() + optimiser.step() + + current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9 + for params_lr in optimiser.param_groups: + names = params_lr.get("name", []) + if names and any("vgg" in n for n in names): + params_lr['lr'] = current_lr * 0.1 + else: + params_lr['lr'] = current_lr + + if self.param.local_rank <= 0: + logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs_sel]) + foreground_iou = self.metrics['foreground_iou'].calculate_iou((logits > 0)[:, 0, ...].long(), + label_sel.long()) + + self.tensorboard.upload_wandb_info({"loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(), + "lr": optimiser.param_groups[0]['lr'], + "loss_dice": loss_dict['loss_dice'], + "loss_focal": loss_dict['loss_mask'], + "loss_contras": cl_loss.item()}) + tbar.set_description('epoch {} | loss {}, f_iou {}'.format(epoch, loss_dict['core_loss'].item(), + foreground_iou.item())) + ''' + if batch_index % 200 == 0: + pred_mask = (logits > 0)[:, 0, ...].long() + n_vis = min(4, frame.shape[0], pred_mask.shape[0], label.shape[0]) + self.tensorboard.upload_wandb_image( + frame[:n_vis], pred_mask[:n_vis], label[:n_vis].long() + ) + ''' + return diff --git a/avs.code/v1s.code/utils/data_utils.py b/avs.code/v1s.code/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7a98f8ec73e6e5dafd1e395b48a98575e5afb1 --- /dev/null +++ b/avs.code/v1s.code/utils/data_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage + + +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/avs.code/v1s.code/utils/foreground_fscore.py b/avs.code/v1s.code/utils/foreground_fscore.py new file mode 100644 index 0000000000000000000000000000000000000000..ea20b84d2304ca0bd9981fd1a3c254111e3d0ac4 --- /dev/null +++ b/avs.code/v1s.code/utils/foreground_fscore.py @@ -0,0 +1,90 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundFScore(AverageMeter): + def __init__(self, rank): + self.local_rank = rank + super(ForegroundFScore, self).__init__('foreground_f-score') + + def _eval_pr(self, y_pred, y, num, cuda_flag=True): + if cuda_flag: + prec, recall = torch.zeros(num).cuda(self.local_rank), torch.zeros(num).cuda(self.local_rank) + thlist = torch.linspace(0, 1 - 1e-10, num).cuda(self.local_rank) + else: + prec, recall = torch.zeros(num), torch.zeros(num) + thlist = torch.linspace(0, 1 - 1e-10, num) + for i in range(num): + y_temp = (y_pred >= thlist[i]).float() + tp = (y_temp * y).sum() + prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) + return prec, recall + + def calculate_f_score(self, pred, gt, pr_num=255, get_entire_list=False): + + r""" + param: + pred: size [N x H x W] + gt: size [N x H x W] + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + # print('=> eval [FMeasure]..') + pred = torch.sigmoid(pred) # =======================================[important] + N = pred.size(0) + beta2 = 0.3 + avg_f, img_num = 0.0, 0 + score = torch.zeros(pr_num) + # fLog = open(os.path.join(measure_path, 'FMeasure.txt'), 'w') + # print("{} videos in this batch".format(N)) + + for img_id in range(N): + # examples with totally black GTs are out of consideration + if torch.mean(gt[img_id].float()) == 0.0: + continue + prec, recall = self._eval_pr(pred[img_id], gt[img_id], pr_num) + f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) + f_score[f_score != f_score] = 0 # for Nan + avg_f += f_score + img_num += 1 + score = avg_f / img_num + # print('score: ', score) + # fLog.close() + self.add({'foreground_f-score': score.max().item()}) + return self.get('foreground_f-score') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundFScore, self).__init__('foreground_f-score') + + diff --git a/avs.code/v1s.code/utils/foreground_iou.py b/avs.code/v1s.code/utils/foreground_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..e01eeb081eee8ebfa1fcb6618d05b9d57c02f817 --- /dev/null +++ b/avs.code/v1s.code/utils/foreground_iou.py @@ -0,0 +1,69 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundIoU(AverageMeter): + def __init__(self): + super(ForegroundIoU, self).__init__('foreground_iou') + + def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False): + r""" + param (both hard mask): + pred: size [N x H x W], type: int + target: size [N x H x W], type: int + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.' + assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.' + + N = pred.size(0) + num_pixels = pred.size(-1) * pred.size(-2) + no_obj_flag = (target.sum(2).sum(1) == 0) + + inter = (pred * target).sum(2).sum(1) + union = torch.max(pred, target).sum(2).sum(1) + + inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) + inter[no_obj_flag] = inter_no_obj[no_obj_flag] + union[no_obj_flag] = num_pixels + + iou = torch.sum(inter / (union+eps)) / N + + self.add({'foreground_iou': iou}) + return self.get('foreground_iou') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundIoU, self).__init__('foreground_iou') + diff --git a/avs.code/v1s.code/utils/iou.py b/avs.code/v1s.code/utils/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..211488b780887a8efd84361bafc6b09bfad4c345 --- /dev/null +++ b/avs.code/v1s.code/utils/iou.py @@ -0,0 +1,76 @@ +import torch +import numpy + + +class BinaryMIoU(object): + def __init__(self, ignore_index): + self.num_classes = 2 + self.ignore_index = ignore_index + self.inter, self.union = 0, 0 + self.correct, self.label = 0, 0 + self.iou = numpy.array([0 for _ in range(self.num_classes)]) + self.acc = 0.0 + + def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_): + # calculates the overall miou and acc + self.correct = self.correct + curr_correct_ + self.label = self.label + curr_label_ + self.inter = self.inter + curr_inter_ + self.union = self.union + curr_union_ + self.acc = 1.0 * self.correct / (numpy.spacing(1) + self.label) + self.iou = 1.0 * self.inter / (numpy.spacing(1) + self.union) + return numpy.round(self.iou, 4), numpy.round(self.acc, 4) + # if class_list is None: + # return numpy.round(self.iou.mean().item(), 4), \ + # numpy.round(self.acc, 4) + # else: + # return numpy.round(self.iou[class_list].mean().item(), 4), \ + # numpy.round(self.acc, 4) + + @staticmethod + def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_): + curr_acc = 1.0 * curr_correct_ / (numpy.spacing(1) + curr_label_) + curr_iou = 1.0 * curr_inter_ / (numpy.spacing(1) + curr_union_) + return curr_iou, curr_acc + + def __call__(self, x, y): + curr_correct, curr_label, curr_inter, curr_union = self.calculate_current_sample(x, y) + return (self.get_metric_results(curr_correct, curr_label, curr_inter, curr_union), + self.get_current_image_results(curr_correct, curr_label, curr_inter, curr_union)) + + def calculate_current_sample(self, output, target): + # output => BxCxHxW (logits) + # target => Bx1xHxW + target[target == self.ignore_index] = -1 + correct, labeled = self.batch_pix_accuracy(output.data, target) + inter, union = self.batch_intersection_union(output.data, target, self.num_classes) + return [numpy.round(correct, 5), numpy.round(labeled, 5), numpy.round(inter, 5), numpy.round(union, 5)] + + @ staticmethod + def batch_pix_accuracy(predict, target): + # _, predict = torch.max(output, 1) + + predict = predict.int() + 1 + target = target.int() + 1 + + pixel_labeled = (target > 0).sum() + pixel_correct = ((predict == target) * (target > 0)).sum() + assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" + return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy() + + @ staticmethod + def batch_intersection_union(predict, target, num_class): + # _, predict = torch.max(output, 1) + predict = predict + 1 + target = target + 1 + + predict = predict * (target > 0).long() + intersection = predict * (predict == target).long() + + area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1) + area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1) + area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1) + area_union = area_pred + area_lab - area_inter + assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" + return area_inter.cpu().numpy(), area_union.cpu().numpy() + diff --git a/avs.code/v1s.code/utils/misc.py b/avs.code/v1s.code/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9d66c31a4b9209b81a5b615386d29f246135c --- /dev/null +++ b/avs.code/v1s.code/utils/misc.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} + diff --git a/avs.code/v1s.code/utils/tensorboard.py b/avs.code/v1s.code/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..3519131463bb3f279eb97b2b44974a402482af42 --- /dev/null +++ b/avs.code/v1s.code/utils/tensorboard.py @@ -0,0 +1,135 @@ +import os + +import PIL +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision +import wandb + +# from utils.visualize import show_img + + +color_map = {"background": (0, 0, 0), "longitudinal": (128, 0, 0), "pothole": (0, 128, 0), + "alligator": (128, 128, 0), "transverse": (128, 0, 128), "ignore": (255, 255, 255)} + + +class Tensorboard: + def __init__(self, config): + if config.get('wandb_online', False): + key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '') + if key: + os.environ['WANDB_API_KEY'] = key + wandb.login(key=key, relogin=False) + self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'], + config=config, settings=wandb.Settings(code_dir="")) + else: + os.environ.setdefault("WANDB_MODE", "disabled") + self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'], + config=config, mode="disabled", + settings=wandb.Settings(code_dir="")) + + self._log_images = bool(config.get('wandb_online', False)) + + self.restore_transform = torchvision.transforms.Compose([ + DeNormalize(config['image_mean'], config['image_std']), + torchvision.transforms.ToPILImage()]) + + def upload_wandb_info(self, info_dict): + for i, info in enumerate(info_dict): + self.tensor_board.log({info: info_dict[info]}) + return + + + def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4): + if not self._log_images: + return + + def _batched_rgb(t): + """[N,C,H,W] or [C,H,W] float tensor on CPU.""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + if t.dim() == 3: + return t.unsqueeze(0) + if t.dim() == 4: + return t + raise ValueError("frames must be [C,H,W] or [N,C,H,W], got shape {}".format(tuple(t.shape))) + + def _batched_mask(t): + """[N,H,W] or [N,1,H,W] or [H,W].""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + while t.dim() > 3: + t = t.squeeze(1) + if t.dim() == 2: + t = t.unsqueeze(0) + if t.dim() != 3: + raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape))) + return t + + frames = _batched_rgb(frames) + pseudo_label_from_pred = _batched_mask(pseudo_label_from_pred) + pseudo_label_from_sam = _batched_mask(pseudo_label_from_sam) + + n = min(frames.shape[0], pseudo_label_from_pred.shape[0], pseudo_label_from_sam.shape[0], img_number) + frames = frames[:n] + pseudo_label_from_pred = pseudo_label_from_pred[:n] + pseudo_label_from_sam = pseudo_label_from_sam[:n] + + pseudo_label_from_sam = pseudo_label_from_sam.clone() + pseudo_label_from_pred = pseudo_label_from_pred.clone() + pseudo_label_from_sam[pseudo_label_from_sam == 255.] = 0.5 + pseudo_label_from_pred[pseudo_label_from_pred == 255.] = 0.5 + + denorm = self.restore_transform.transforms[0] + image_list = [] + label_list = [] + logits_list = [] + for i in range(n): + fi = frames[i].clone() + if fi.shape[0] == 3: + denorm(fi) + fi.clamp_(0.0, 1.0) + image_list.append(wandb.Image(fi, caption="id {}".format(str(i)))) + # wandb.Image expects torch tensors as [C, H, W] (it permutes CHW→HWC) + ms = pseudo_label_from_sam[i].squeeze() + mp = pseudo_label_from_pred[i].squeeze() + if ms.dim() == 2: + ms = ms.unsqueeze(0) + if mp.dim() == 2: + mp = mp.unsqueeze(0) + label_list.append(wandb.Image(ms, caption="id {}".format(str(i)))) + logits_list.append(wandb.Image(mp, caption="id {}".format(str(i)))) + + self.tensor_board.log({"image": image_list, "label": label_list, "logits": logits_list}) + + def de_normalize(self, image): + return [self.restore_transform(i.detach().cpu()) if (isinstance(i, torch.Tensor) and len(i.shape) == 3) + else colorize_mask(i.detach().cpu().numpy(), self.palette) + for i in image] + + def finish(self): + self.tensor_board.finish() + + +class DeNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + return tensor + + +def colorize_mask(mask, palette): + zero_pad = 256 * 3 - len(palette) + for i in range(zero_pad): + palette.append(0) + # palette[-6:-3] = [183, 65, 14] + new_mask = PIL.Image.fromarray(mask.astype(numpy.uint8)).convert('P') + new_mask.putpalette(palette) + return new_mask diff --git a/avs.code/v1s.code/utils/utils.py b/avs.code/v1s.code/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e72f27a7e2be77cea271001230195ef79f685351 --- /dev/null +++ b/avs.code/v1s.code/utils/utils.py @@ -0,0 +1,119 @@ +"""Optimizer helpers: split learning rates for AuralFuser train_* vs VGG backbone.""" +import torch +import copy +from typing import List, Dict, Set, Any + + +def manipulate_params(cfg, model): + weight_decay_norm = 0 + weight_decay_embed = 0 + defaults = {} + defaults["lr"] = cfg.lr + defaults["weight_decay"] = cfg.weight_decay + + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + params_training: List[Dict[str, Any]] = [] + params_finetuning: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + + train_prefixes = ( + "patch_embeds", + "f_blocks", + "a_blocks", + "fusion_modules", + "smooth_convs", + "train_proj_v1", + "train_proj_a1", + ) + + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + hyperparams = copy.copy(defaults) + if 'vgg' in module_name or 'vgg' in module_param_name: + hyperparams['lr'] *= 0.1 + params_finetuning.append({"params": [value], "name": [module_name], **hyperparams}) + elif ( + 'train' in module_name + or 'train' in module_param_name + or module_name.startswith(train_prefixes) + ): + if ( + "relative_position_bias_table" in module_param_name + or "pos_embed" in module_param_name + ): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, norm_module_types): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, torch.nn.Embedding): + hyperparams["weight_decay"] = 0.0 + params_training.append({"params": [value], "name": [module_name], **hyperparams}) + else: + print('undefined layer type.') + raise NotImplementedError + final_list = params_training + params_finetuning + assert len([p for p in model.parameters() if p.requires_grad]) == len(final_list), 'checksum confirmed not pass.' + return final_list + + +def group_weight(weight_group, module, weight_decay_value, lr): + group_decay = [] + group_no_decay = [] + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + for m in module.modules(): + if isinstance(m, torch.nn.Linear): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, norm_module_types): + if m.weight is not None: + group_no_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, torch.nn.Parameter): + group_no_decay.append(m) + elif isinstance(m, torch.nn.Embedding): + group_no_decay.append(m) + else: + print('undefined layer type find.') + raise NotImplementedError + + assert len(list(module.parameters())) == len(group_decay) + len( + group_no_decay) + weight_group.append(dict(params=group_decay, weight_deacy=weight_decay_value, lr=lr)) + weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) + return weight_group \ No newline at end of file diff --git a/avs.code/v2.code/configs/__init__.py b/avs.code/v2.code/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/avs.code/v2.code/configs/auralfuser/architecture.yaml b/avs.code/v2.code/configs/auralfuser/architecture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c3d06ca42335ce6bfc8064bbd5cfd44c8080a --- /dev/null +++ b/avs.code/v2.code/configs/auralfuser/architecture.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +aural_fuser: + patch_cfgs: + - [4, 4] + - [2, 2] + - [1, 1] + f_depths: [3, 6, 12] + block_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false + one_d_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false diff --git a/avs.code/v2.code/configs/config.py b/avs.code/v2.code/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..21e69377536f2a491f46e3e35803c78d448c54ba --- /dev/null +++ b/avs.code/v2.code/configs/config.py @@ -0,0 +1,84 @@ +import os +import numpy +from easydict import EasyDict + +# v1m.code package root (parent of this `configs/` directory) +_CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +# workspace root (parent of avs.code) +_WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT)) + +C = EasyDict() +config = C +cfg = C + +C.seed = 666 + +C.audio = EasyDict() +C.audio.FREEZE_AUDIO_EXTRACTOR = True +C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth') +C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False +C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False +C.train_vggish = False + +"""Root Directory Config""" +C.repo_name = 'AV' +C.root_dir = _CODE_ROOT + +"""Data Dir and Weight Dir""" +C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'AVSBench') +C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt') +C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml') + +"""Network Config""" +C.fix_bias = True +C.bn_eps = 1e-5 +C.bn_momentum = 0.1 + +"""Image Config""" +C.num_classes = 2 + +C.image_mean = numpy.array([0.485, 0.456, 0.406]) +C.image_std = numpy.array([0.229, 0.224, 0.225]) + + +C.image_size = 1024 +C.image_embedding_size = int(C.image_size / 16) +C.avsbench_size = (224, 224) + +C.scale_list = [.5, .75, 1., 1.25, 1.5] +C.ignore_index = 255 + +"""Train Config""" +C.lr = 7.5e-5 +C.batch_size = 8 +C.energy_weight = .05 + +C.lr_power = 0.9 +C.momentum = 0.9 +C.weight_decay = 0.05 + +C.num_workers = 4 + +"""Display Config""" +C.record_info_iter = 20 +C.display_iter = 50 + +"""Wandb Config""" +# Paste your W&B API key here, or set the WANDB_API_KEY environment variable instead. +C.wandb_key = "" + +# Your project [work_space] name +C.proj_name = "AVS-final-report" + +C.experiment_name = "v2-hiera-l" + + +# False = no wandb logging (see utils/tensorboard.py) +C.wandb_online = False + +"""Save Config""" +C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', C.experiment_name) + +import pathlib + +pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True) diff --git a/avs.code/v2.code/configs/sam2/sam2_hiera_b+.yaml b/avs.code/v2.code/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e0f10732134149f6a994be063d11fd7591c430 --- /dev/null +++ b/avs.code/v2.code/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,114 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False + diff --git a/avs.code/v2.code/configs/sam2/sam2_hiera_l.yaml b/avs.code/v2.code/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8478b3d4b8b16d8b22f6555cf7b1f00231d7fd59 --- /dev/null +++ b/avs.code/v2.code/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v2.code/configs/sam2/sam2_hiera_s.yaml b/avs.code/v2.code/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/avs.code/v2.code/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/avs.code/v2.code/configs/sam2/sam2_hiera_t.yaml b/avs.code/v2.code/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59e605b73c9777b70942538252d27a55ae8a7e1a --- /dev/null +++ b/avs.code/v2.code/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 224 # 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: false + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/avs.code/v2.code/configs/training/sam2_training_config.yaml b/avs.code/v2.code/configs/training/sam2_training_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55771e7232fe88c4ea445958956eca8174c2e872 --- /dev/null +++ b/avs.code/v2.code/configs/training/sam2_training_config.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +# Video transforms + +train_transforms: + - _target_: dataloader.sam2_dataset.transforms.ComposeAPI + transforms: + - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI + sizes: 1024 # ${scratch.resolution} + square: true + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI + - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +loss: + all: + _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 # 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + gpu_num: 4. + +# Contrastive loss (ContrastLoss); loaded in main.py / inference.py → hyp_param.contrastive_learning +contrastive_learning: + temperature: 0.10 + ignore_idx: 255 + ood_idx: 254 + max_views: 512 + proj_dim: 512 + sample_limits: 128 + total_limits: 15240 diff --git a/avs.code/v2.code/dataloader/audio/audio_augmentation.py b/avs.code/v2.code/dataloader/audio/audio_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..850d1577ea2bca4f8ec209edc201fb54968be928 --- /dev/null +++ b/avs.code/v2.code/dataloader/audio/audio_augmentation.py @@ -0,0 +1,23 @@ +import numpy + + +class Augmentation(object): + """Audio pre-step used by training/inference: int16 waveform -> float in [-1, 1]. + + The previous audiomentations-based transforms were commented out and never applied; + behavior is unchanged: only scaling by 1/32768. + """ + + def __init__(self, mono=True): + self.mono = mono + + def train_aug(self, x_, sr_): + x_ = x_ / 32768.0 + return x_ + + def test_process(self, x_): + x_ = x_ / 32768.0 + return x_ + + def __call__(self, x, sr, split): + return self.train_aug(x, sr) if split == "train" else self.test_process(x) diff --git a/avs.code/v2.code/dataloader/audio/audio_dataset.py b/avs.code/v2.code/dataloader/audio/audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8e8b276e8545aa55ef56295719a0ad2b167106 --- /dev/null +++ b/avs.code/v2.code/dataloader/audio/audio_dataset.py @@ -0,0 +1,38 @@ +import torch +import numpy +import os +from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples +import soundfile + + +class Audio(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split): + # temporarily set no augmentation. + self.augmentation = augmentation + self.directory_path = directory_path + self.split = split + + def load_audio_wave(self, file_index, file_index_mix): + audio_path = os.path.join(file_index, 'audio.wav') + wav_data, sample_rate = soundfile.read(audio_path, dtype='int16') + assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype + + if file_index_mix is not None: + audio_path2 = os.path.join(file_index_mix, 'audio.wav') + wav_data2, _ = soundfile.read(audio_path2, dtype='int16') + mix_lambda = numpy.random.beta(10, 10) + min_length = min(wav_data.shape[0], wav_data2.shape[0]) + wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda) + + wav_data = self.augmentation(wav_data, sample_rate, self.split) + audio_log_mel = torch.cat([waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(), + waveform_to_examples(wav_data[:, 1], sample_rate, True).detach()], dim=1) + + # for the vgg preprocess, we will need 5 seconds audio log. + if audio_log_mel.shape[0] < 5: + audio_log_mel = torch.cat([audio_log_mel, + audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)]) + return audio_log_mel + + def __len__(self): + return len(self.audio_list) diff --git a/avs.code/v2.code/dataloader/audio/preprocess_vgg/mel_features.py b/avs.code/v2.code/dataloader/audio/preprocess_vgg/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v2.code/dataloader/audio/preprocess_vgg/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_input.py b/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..9d58e81bc70a85138980128e033f271998794605 --- /dev/null +++ b/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from dataloader.audio.preprocess_vgg import mel_features +from dataloader.audio.preprocess_vgg import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_params.py b/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v2.code/dataloader/audio/preprocess_vgg/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v2.code/dataloader/dataset.py b/avs.code/v2.code/dataloader/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cb266ac17dea1fc3415b2f83da04be5aff982c05 --- /dev/null +++ b/avs.code/v2.code/dataloader/dataset.py @@ -0,0 +1,79 @@ +"""Fused audio-visual dataset for AVSBench-style indexing.""" +import os +import random +import PIL.Image +import numpy +import torch +from dataloader.visual.visual_dataset import Visual +from dataloader.audio.audio_dataset import Audio +import pandas + + +class AV(torch.utils.data.Dataset): + """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`.""" + + def __init__(self, split, augmentation, param, root_path=''): + # v2.code entry: always merge v1s + v1m + v2 from `avss_index/metadata.csv` (artifacts v2 pool). + # Visual/Audio get `root_path/v2` as base path; per-sample `load_data` uses full `file_path` (v1s|v1m|v2/uid). + v2_root = os.path.join(root_path, 'v2') + self.visual_dataset = Visual( + augmentation['visual'], + v2_root, + split, + param.image_size, + param.image_embedding_size, + ) + self.audio_dataset = Audio(augmentation['audio'], v2_root, split) + self.augment = augmentation + self.split = split + self.file_path = self.organise_files(self.split, root_path, csv_name_='avss_index/metadata.csv') + + def __getitem__(self, index): + mixing_prob = 0. # we omit this option. + other_index = random.randint(1, self.__len__()) - 1 if random.random() < mixing_prob and self.split == 'train' else None + frame, label, prompts = self.visual_dataset.load_data(self.file_path[index]) + if other_index is not None: + other_frame, other_label, other_prompts = self.visual_dataset.load_data(self.file_path[other_index]) + frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts) + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], self.file_path[other_index]) + else: + audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], None) + + assert other_index is None if self.split == 'test' else 1, print('no mix in validation.') + + return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': self.file_path[index], + 'prompts': prompts} + + def __len__(self): + return len(self.file_path) + + @staticmethod + def organise_files(split_, root_path_, csv_name_): + total_files = pandas.read_csv(os.path.join(root_path_, csv_name_)) + files_info_v2 = total_files[(total_files["split"] == split_) & (total_files["label"] == 'v2')]['uid'] + files_path_v2 = [os.path.join(root_path_, 'v2', files_name) for files_name in files_info_v2] + files_info_v1s = total_files[(total_files["split"] == split_) & (total_files["label"] == 'v1s')]['uid'] + files_path_v1s = [os.path.join(root_path_, 'v1s', files_name) for files_name in files_info_v1s] + files_info_v1m = total_files[(total_files["split"] == split_) & (total_files["label"] == 'v1m')]['uid'] + files_path_v1m = [os.path.join(root_path_, 'v1m', files_name) for files_name in files_info_v1m] + files_path = files_path_v1s + files_path_v1m + files_path_v2 + del total_files + return files_path + + @staticmethod + def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2): + mix_frame = frame1.clone() + mix_label = label1.clone() + bbx1, bby1, bbx2, bby2 = 0, 0, mix_label.shape[1] - 1, mix_label.shape[2] - 1 + + for i in range(0, mix_frame.shape[0]): + label_canvas_foreground = label2[i, bbx1:bbx2, bby1:bby2] > 0. + mix_frame[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground] = ( + frame2[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground]) + mix_label[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground] = ( + label2[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground]) + + return mix_frame, mix_label, prompts1 + + + diff --git a/avs.code/v2.code/dataloader/sam2_dataset/__init__.py b/avs.code/v2.code/dataloader/sam2_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v2.code/dataloader/sam2_dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v2.code/dataloader/sam2_dataset/transforms.py b/avs.code/v2.code/dataloader/sam2_dataset/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7731e59ba98a5465493e3a9c4b785eb4d4420ca2 --- /dev/null +++ b/avs.code/v2.code/dataloader/sam2_dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage +# from docutils.nodes import label +import numpy +from torchvision.transforms import InterpolationMode + +# from utils.data_utils import VideoDatapoint + + +def hflip(frames, labels, index): + # print(index) + # print(len(frames), frames[index].size, type(frames[index])) + # print(len(labels), labels[index].size, type(labels[index])) + frames[index] = F.hflip(frames[index]) + labels[index] = F.hflip(labels[index]) + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.hflip(obj.segment) + + return frames, labels + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(frames, labels, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + raise NotImplementedError + # cur_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + # size = get_size(cur_size, size, max_size) + + # old_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + if v2: + frames[index].data = Fv2.resize( + frames[index].data, size, antialias=True + ) + else: + frames[index] = F.resize(frames[index], size) + labels[index] = F.resize(labels[index], size) + # new_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + # h, w = size + # frames[index].size = (h, w) + return frames, labels + + +def pad(frames, index, padding, v2=False): + old_h, old_w = frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + frames[index].data = F.pad( + frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + frames[index].data = F.pad( + frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + frames[index].size = (h, w) + + for obj in frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return frames + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(frames)): + frames, labels = hflip(frames, labels, i) + return frames, labels + for i in range(len(frames)): + if random.random() < self.p: + frames, labels = hflip(frames, labels, i) + return frames, labels + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, frames, labels): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(frames)): + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + for i in range(len(frames)): + size = random.choice(self.sizes) + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + if self.v2: + raise NotImplementedError + # frames[img_idx] = Fv2.to_tensor(frames[img_idx]) + else: + frames[img_idx] = F.to_tensor(frames[img_idx]) + labels[img_idx] = torch.tensor(numpy.array(labels[img_idx]), dtype=torch.float) + return frames, labels + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + # if self.v2: + # img.data = Fv2.convert_image_dtype(img.data, torch.float32) + # img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + # else: + frames[img_idx] = F.normalize(frames[img_idx], mean=self.mean, std=self.std) + + return frames, labels + +''' + + + + + + + + +''' +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, frames, labels, **kwargs): + for t in self.transforms: + frames, labels = t(frames, labels, **kwargs) + return frames, labels + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img_idx in range(len(frames)): + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + for img_idx in range(len(frames)): + if random.random() < self.p: + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return frames, labels + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + label_fill_value=0., + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.fill_label = label_fill_value + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.' + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, frames, labels, **kwargs): + for _tentative in range(self.num_tentatives): + res_img, res_labels = self.transform_frames(frames, labels) + # if res is not None: + return res_img, res_labels + + # raise NotImplementedError + # if self.log_warning: + # logging.warning( + # f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + # ) + # return frames + + def transform_frames(self, frames, labels): + _, height, width = F.get_dimensions(frames[0]) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(frames): + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + frames[img_idx] = F.affine( + img, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + labels[img_idx] = F.affine( + labels[img_idx], + *affine_params, + # default: interpolation='nearest', + fill=self.fill_label, + ) + return frames, labels + + +''' +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, frames, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint +''' \ No newline at end of file diff --git a/avs.code/v2.code/dataloader/visual/visual_augmentation.py b/avs.code/v2.code/dataloader/visual/visual_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40aed7c8b8c08d50a46db122e1213bd4878afd --- /dev/null +++ b/avs.code/v2.code/dataloader/visual/visual_augmentation.py @@ -0,0 +1,140 @@ +import random + +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision.transforms.functional as F +import torchvision.transforms as transforms + + +class Augmentation(object): + def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255): + self.image_size = (image_height, image_width) + # self.image_norm = (image_mean, image_std) + # self.get_crop_pos = transforms.RandomCrop(self.image_size) + self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25) + self.gaussian_blurring = transforms.GaussianBlur((3, 3)) + self.scale_list = scale_list + + self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + self.to_tensor = transforms.ToTensor() + + self.ignore_index = ignore_index + + # self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + + # if setup == "avs" or setup == "avss" or setup == "avss_binary": + # # AVS + # self.scale_list = [.5, .75, 1.] + # self.color_jitter = None + # else: + # # COCO + # # self.scale_list = [.75, 1., 1.25, 1.5, 1.75, 2.] + # self.scale_list = [0.5,0.75,1.0,1.25,1.5,1.75,2.0] + + # def normalise(self, image): + # image = image / 255.0 + # image = image - self.image_norm[0] + # image = image / self.image_norm[1] + # return image + + def resize(self, image_, label_, size=None): + h_, w_ = self.image_size if size is None else size + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + def random_crop_with_padding(self, image_, label_): + w_, h_ = image_.size + if min(h_, w_) < min(self.image_size): + res_w_ = max(self.image_size[0] - w_, 0) + res_h_ = max(self.image_size[1] - h_, 0) + image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=(numpy.array(self.image_norm[0]) * 255.).tolist()) + # image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=self.ignore_index) # if error, define the padding value. + label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index) + + pos_ = self.get_crop_pos.get_params(image_, self.image_size) + image_ = F.crop(image_, *pos_) + label_ = F.crop(label_, *pos_) + + return image_, label_ + + # @staticmethod + def random_scales(self, image_, label_): + w_, h_ = image_.size + chosen_scale = random.choice(self.scale_list) + w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale) + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + @staticmethod + def random_flip_h(image_, label_): + chosen_flip = random.random() > 0.5 + image_ = F.hflip(image_) if chosen_flip else image_ + label_ = F.hflip(label_) if chosen_flip else label_ + return image_, label_ + + def augment_entire_clip(self, x_list, y_list): + degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item()) + shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()), + torch.empty(1).uniform_(float(-20.), float(20.)).item()] + dice = random.random() + for index, single_x in enumerate(x_list): + if dice <= 0.1: + single_x = F.rgb_to_grayscale(single_x, num_output_channels=3) + + single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.]) + single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.NEAREST, fill=[0.]) + x_list[index] = single_x + y_list[index] = single_y + + return x_list, y_list + + + + + def train_aug(self, x_, y_): + x_, y_ = self.random_flip_h(x_, y_) + # # x, y = self.random_scales(x, y) + x_, y_ = self.resize(x_, y_) + + if self.color_jitter is not None and random.random() < 0.5: + x_ = self.color_jitter(x_) + if self.gaussian_blurring is not None and random.random() < 0.5: + x_ = self.gaussian_blurring(x_) + + # x, y = self.random_crop_with_padding(x, y) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + # receive pseudo labels. + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def test_process(self, x_, y_): + # x = self.to_tensor(x) + # y = torch.tensor(numpy.asarray(y)).long() + + # following AVSbench setup, we fix image size (224, 224) + x_, y_ = self.resize(x_, y_) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def __call__(self, x, y, split): + return self.train_aug(x, y) if split == "train" \ + else self.test_process(x, y) + + + + + + + + + + + diff --git a/avs.code/v2.code/dataloader/visual/visual_dataset.py b/avs.code/v2.code/dataloader/visual/visual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..922aa12196dc74b0c3a6d957d10d19f842c4f97b --- /dev/null +++ b/avs.code/v2.code/dataloader/visual/visual_dataset.py @@ -0,0 +1,124 @@ +import os +import re +import PIL.Image +import matplotlib.pyplot as plt +import numpy +import torch +import pandas +import torchvision + + +class Visual(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size): + self.augment = augmentation + self.directory_path = directory_path + self.split = split + self.image_size = image_size + self.embedding_size = image_embedding_size + + def load_data(self, file_prefix): + frame_path = os.path.join(file_prefix, 'frames') + frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)] + label_path = os.path.join(file_prefix, 'labels_rgb') + label_path = [os.path.join(label_path, i) for i in os.listdir(label_path)] + + # if self.split == 'train': + # label_path += [os.path.join(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'), i) for i in + # os.listdir(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'))] + + frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0]))) + label_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.png')[0]))) + + frame = [PIL.Image.open(i) for i in frame_path] + label = [PIL.Image.open(i).convert('L') for i in label_path] + if self.split == 'train' and '/v1s/' in file_prefix: + # v1s: only first frame has mask files; pad with empty masks to match frame count (artifacts). + label += [PIL.Image.new('L', frame[0].size)] * (len(frame) - len(label)) + + # Length must match clip length (e.g. v2 @ 10 frames); first slot True matches artifacts' "frame 0" hint. + label_idx = torch.tensor([True] + [False] * (len(frame) - 1), dtype=torch.bool) + + # receive the prompts from the ground truth. + # prompts = {"point_coords": torch.nan, "point_labels": torch.nan, + # "masks": [None]*len(frame), "box_coords": [None]*len(frame)} + + prompts = {} + image_batch = [None]*len(frame) + label_batch = [None]*len(frame) + + if self.split == 'train': + # frame, label = self.augment.augment_entire_clip(frame, label) + frame, label = self.augment(frame, label) + + + for i in range(len(frame)): + if self.split == 'test': + curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split) + else: + curr_frame, curr_label = frame[i], label[i] + # if self.split == 'train' and i > 0: + # curr_label = curr_label / 255. + # curr_label[curr_label > 0.5] = 1 + # curr_label[curr_label < 0.5] = 0 + # # curr_label[(0.05 < curr_label) & (curr_label < 0.95)] = 255 + # # we temporarily make it to be hard mask; + # # curr_label = ((curr_label / 255.) - 0.5) * 2 + # # curr_label[curr_label >= 0.] = 1. + # # curr_label[curr_label < 0.] = 0. + # else: + curr_label[curr_label > 0.] = 1. + image_batch[i], label_batch[i] = curr_frame, curr_label + + # image_batch[i], label_batch[i] = self.augment(frame[i], label[i], split=self.split) + # note: we simply convert the code to binary mask in v1s, v1m; + # to some reason, we failed to load the label in `L' format and had to hardcoding here. + # label_batch[i][label_batch[i] > 0.] = 1. + + # prompts['box_coords'][i], prompts['masks'][i] = self.receive_other_prompts(label_batch[i]) + + # organise the prompts + # prompts.update({'masks': torch.stack(prompts['masks'], dim=0)}) + # prompts.update({'box_coords': torch.stack(prompts['box_coords'], dim=0)}) + # prompts.update({'point_labels': torch.stack(prompts['point_labels'], dim=0)}) + prompts.update({'label_index': label_idx}) + return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts + + def receive_other_prompts(self, y_): + # y_ = torch.zeros_like(y_) + if len(torch.unique(y_)) > 1: + # foreground point + points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0) + + # bbox prompt (left-top corner & right-bottom corner) + bbox_one = torch.min(points_foreground[:, 0]), torch.min(points_foreground[:, 1]) + bbox_fou = torch.max(points_foreground[:, 0]), torch.max(points_foreground[:, 1]) + bbox_coord = torch.tensor(bbox_one + bbox_fou, dtype=torch.float) + bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.squeeze().shape) + # mask prompt + low_mask = torchvision.transforms.functional.resize(y_.clone(), [self.embedding_size*4, self.embedding_size*4], + torchvision.transforms.InterpolationMode.NEAREST) + else: + # for the pure background situation. + bbox_coord = torch.zeros([4], dtype=torch.float).fill_(float('nan')) + low_mask = torch.zeros([1, self.embedding_size*4, self.embedding_size*4], dtype=torch.float).fill_(float('nan')) + + return bbox_coord, low_mask + + # we transfer the coords to SAM's input resolution (1024, 1024). + def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the sam2 model. + """ + h, w = orig_hw + coords = coords.clone().reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + coords = coords * self.image_size # unnormalize coords + return coords.reshape(4) + + + diff --git a/avs.code/v2.code/inference.py b/avs.code/v2.code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b8040ed9e7d011381d28bd8fc9701f554a2fb1fb --- /dev/null +++ b/avs.code/v2.code/inference.py @@ -0,0 +1,193 @@ +"""Distributed inference on the test set; runs the same three `process` modes as training validation.""" +import os +import pathlib +import torch +import numpy +import random +import argparse +from easydict import EasyDict + +# Avoid import failure when configs.config creates saved_dir without write permission. +_real_mkdir = pathlib.Path.mkdir + + +def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False): + try: + return _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok) + except PermissionError: + pass + + +pathlib.Path.mkdir = _safe_mkdir + + +def seed_it(seed): + random.seed(seed) + os.environ["PYTHONSEED"] = str(seed) + numpy.random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + torch.manual_seed(seed) + + +class _DummyTensorboard: + """Minimal Tensorboard stub so Trainer.valid runs without wandb logging.""" + + def upload_wandb_info(self, info_dict): + pass + + def upload_wandb_image(self, *args, **kwargs): + pass + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` + from hydra import compose + from omegaconf import OmegaConf + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + train_cfg = compose(config_name='training/sam2_training_config.yaml') + OmegaConf.resolve(train_cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda() + torch.cuda.set_device(hyp_param.local_rank) + ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu") + if not isinstance(ckpt_sd, dict): + raise TypeError("Checkpoint must be a state_dict dictionary.") + # Support both formats: + # 1) full-model checkpoint (keys like `v_model.*`, `aural_fuser.*`) + # 2) train-only checkpoint for aural_fuser (keys without `aural_fuser.` prefix) + if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd.keys()): + av_model.load_state_dict(ckpt_sd, strict=True) + else: + av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=False) + av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model) + av_model.eval() + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data import DataLoader, Subset + from torch.utils.data.distributed import DistributedSampler + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + audio_augmentation = AudioAugmentation(mono=True) + + dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path) + + max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0 + if max_batches > 0: + n_samples = min(max_batches * hyp_param.batch_size, len(dataset)) + dataset = Subset(dataset, range(n_samples)) + + sampler = DistributedSampler(dataset, shuffle=False) + test_dataloader = DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler, + num_workers=hyp_param.num_workers) + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + + metrics = { + "foreground_iou": ForegroundIoU(), + "foreground_f-score": ForegroundFScore(hyp_param.local_rank), + } + trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics) + + # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select + runs = [ + ("", "default (logits[:,0])"), + ("iou_select", "iou_select"), + ("iou_occ_select", "iou_occ_select"), + ] + results = [] + for process, label in runs: + fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process) + results.append((label, fiou, ffscore)) + torch.cuda.empty_cache() + + if hyp_param.local_rank <= 0: + print("\n========== inference (same three process flags as training valid) ==========") + for label, fiou, ffscore in results: + print(" {:32s} f_iou={} f_f-score={}".format(label, fiou, ffscore)) + print("=======================================================\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Inference: full test set + three process modes') + + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int, + help='Batch size (match training if needed)') + + parser.add_argument('--epochs', default=80, type=int, + help="unused") + + parser.add_argument('--lr', default=1e-5, type=float, + help="unused") + + parser.add_argument('--online', action="store_true", + help='unused') + + parser.add_argument( + '--inference_ckpt', type=str, default=None, + help='Trained AuralSAM2 checkpoint (.pth state_dict: full model or aural_fuser-only). ' + 'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). ' + 'Default if unset: avs.code/training_details/.../hiera_l.pth', + ) + parser.add_argument('--inference_max_batches', type=int, default=0, + help='0 = full test; >0 = first N batches only (debug)') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + + _repo = pathlib.Path(__file__).resolve().parent + # Repo root: .../AuralSAM2 (parent of avs.code) + _workspace = _repo.parent.parent + args.data_root_path = str(_workspace / 'AVSBench') + args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt') + args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth') + args.saved_dir = '/tmp/v2_infer_ckpt' + pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True) + if args.inference_ckpt is None: + args.inference_ckpt = str( + _repo.parent / 'training_details' / 'v2' / 'hiera_l' / 'hiera_l.pth' + ) + + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9901' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v2.code/loss/training/__init__.py b/avs.code/v2.code/loss/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8932ea6e0916b9f16cb514f1e64704440d554e --- /dev/null +++ b/avs.code/v2.code/loss/training/__init__.py @@ -0,0 +1,2 @@ +"""Training loss modules.""" + diff --git a/avs.code/v2.code/loss/training/contrastive_learning.py b/avs.code/v2.code/loss/training/contrastive_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..b4097c017ab93e2cf1d5b4e27d8b74a8ce26f134 --- /dev/null +++ b/avs.code/v2.code/loss/training/contrastive_learning.py @@ -0,0 +1,209 @@ +from abc import ABC + +import torch +import torch.nn as nn + + +class ContrastLoss(nn.Module, ABC): + def __init__(self, hyp_param): + super().__init__() + self.param = hyp_param + _defaults = { + "temperature": 0.10, + "ignore_idx": 255, + "ood_idx": 254, + "max_views": 512, + "proj_dim": 512, + "sample_limits": 128, + "total_limits": 64, + } + _raw = getattr(hyp_param, "contrastive_learning", None) or {} + _cfg = {**_defaults, **_raw} + self.temperature = _cfg["temperature"] + self.ignore_idx = _cfg["ignore_idx"] + self.ood_idx = _cfg["ood_idx"] + self.max_views = _cfg["max_views"] + self.proj_dim = _cfg["proj_dim"] + self.sample_limits = _cfg["sample_limits"] + self.total_limits = _cfg["total_limits"] + + def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx): + embedding_sample_list = [] + label_list = [] + embedding_sample_list_a = [] + label_list_a = [] + class_index_list = torch.unique(masks) + + if len(class_index_list) > 1: + for class_index in class_index_list[1:]: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3) + else: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3) + + sample_limits = self.sample_limits + embeddings = embeddings.permute(1, 0) + for class_index in class_index_list: + hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()] + easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()] + + hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] + selective_num_hard = min(sample_limits, hard_indices_num) + selective_num_easy = min(sample_limits, easy_indices_num) + + if (selective_num_hard + selective_num_easy) < sample_limits * 2: + if selective_num_hard > selective_num_easy: + selective_num_hard += sample_limits * 2 - selective_num_easy + else: + selective_num_easy += sample_limits * 2 - selective_num_hard + + hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] + embedding_sample_list.append(hard_indices[hard_chosen_indices]) + label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3) + + easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] + embedding_sample_list.append(easy_indices[easy_chosen_indices]) + label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3) + return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a + + def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions): + masks = masks.flatten(start_dim=1) + predictions = predictions.flatten(start_dim=1) + visual_embeddings = visual_embeddings.flatten(start_dim=-2) + + visual_embedding_sample_list = [] + visual_label_list = [] + audio_embedding_sample_list = [] + audio_label_list = [] + + for frame_idx in range(masks.shape[0]): + current_vision_feats = visual_embeddings[frame_idx] + current_masks = masks[frame_idx] + current_predictions = predictions[frame_idx] + current_audio_feats = audio_embeddings[frame_idx] + for layer_idx in range(3): + ( + selected_vision_embeddings, + selected_vision_labels, + selected_audio_embeddings, + selected_audio_labels, + ) = self.select_class_wise_samples( + current_vision_feats[layer_idx], + current_audio_feats[layer_idx], + current_predictions, + current_masks, + 0, + ) + visual_embedding_sample_list += selected_vision_embeddings + visual_label_list += selected_vision_labels + audio_embedding_sample_list += selected_audio_embeddings + audio_label_list += selected_audio_labels + + if len(visual_embedding_sample_list) == 0: + return 0.0 + + # Same as artifacts `loss/cl.py`: cat then squeeze. If only one row, squeeze drops batch dim and + # `info_nce` hits "2 vs 1" — keep at least 2D without adding a helper. + visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze() + if visual_embedding_sample_list.dim() == 1: + visual_embedding_sample_list = visual_embedding_sample_list.unsqueeze(0) + visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1) + audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze() + if audio_embedding_sample_list.dim() == 1: + audio_embedding_sample_list = audio_embedding_sample_list.unsqueeze(0) + audio_label_list = torch.cat(audio_label_list).unsqueeze(1) + + total_limits = self.total_limits + if visual_embedding_sample_list.shape[0] > total_limits: + rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits] + visual_embedding_sample_list = visual_embedding_sample_list[:rand_index] + visual_label_list = visual_label_list[:rand_index] + loss = self.info_nce( + visual_embedding_sample_list, + visual_label_list, + audio_embedding_sample_list, + audio_label_list, + ) + return loss + + def forward(self, embeddings, output_dicts, masks): + # Align with artifacts `loss/cl.py` forward: squeeze(1) on interp, loop over masks.shape[0], squeeze(-1) on audio. + predictions = torch.cat([i["multistep_pred_masks"] for i in output_dicts]) + predictions = torch.nn.functional.interpolate( + predictions, + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="bilinear", + align_corners=False, + ).squeeze(1) + masks = torch.nn.functional.interpolate( + masks.unsqueeze(1), + size=(int(self.param.image_size / 16), int(self.param.image_size / 16)), + mode="nearest", + ).squeeze(1) + visual_embeddings, audio_embeddings = embeddings + visual_embeddings = torch.cat( + [ + torch.cat( + [ + visual_embeddings[0][i].unsqueeze(0), + visual_embeddings[1][i].unsqueeze(0), + visual_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + audio_embeddings = torch.cat( + [ + torch.cat( + [ + audio_embeddings[0][i].unsqueeze(0), + audio_embeddings[1][i].unsqueeze(0), + audio_embeddings[2][i].unsqueeze(0), + ] + ).unsqueeze(0) + for i in range(masks.shape[0]) + ] + ) + return self.forward_audio_visual( + visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions + ) + + @staticmethod + def manipulate_cover_mask(a_label, current_mask): + a_label = a_label + 1 + visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1)) + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0 + current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0 + return current_mask + + def info_nce(self, anchors_, a_labels_, contras_, c_labels_): + c_labels_ = torch.cat([a_labels_, c_labels_]) + contras_ = torch.cat([anchors_, contras_]) + mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float() + + anchor_dot_contrast = torch.div( + torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), + self.temperature, + ) + + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + neg_mask = 1 - mask + + mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask) + mask = mask.fill_diagonal_(0.0) + + neg_logits = torch.exp(logits) * neg_mask + neg_logits = neg_logits.sum(1, keepdim=True) + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits + neg_logits) + + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any()) + return -mean_log_prob_pos.mean() + + diff --git a/avs.code/v2.code/loss/training/sam2_training_loss.py b/avs.code/v2.code/loss/training/sam2_training_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ce1b02c0dbbf5d7e771b314a4a537145e28978 --- /dev/null +++ b/avs.code/v2.code/loss/training/sam2_training_loss.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +CORE_LOSS_KEY = "core_loss" + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + inputs = inputs.sigmoid() + if loss_on_multimask: + assert inputs.dim() == 4 and targets.dim() == 4 + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + gpu_num=1, + ): + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.world_size = gpu_num + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + targets_batch.shape[1], device=targets_batch.device, dtype=torch.float + ) + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / self.world_size, min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 + + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + return reduced_loss + diff --git a/avs.code/v2.code/main.py b/avs.code/v2.code/main.py new file mode 100644 index 0000000000000000000000000000000000000000..90151d9412514ad79281d6a0ff12b5b9c98ecb5c --- /dev/null +++ b/avs.code/v2.code/main.py @@ -0,0 +1,166 @@ +"""DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss.""" +import os +import torch +import numpy +import random +import argparse +from easydict import EasyDict + + +def seed_it(seed): + """Fix RNGs and cuDNN for reproducible runs (rank offsets seed in DDP).""" + os.environ["PYTHONSEED"] = str(seed) + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + + torch.backends.cudnn.benchmark = False + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + # NCCL process group; world size = GPUs on this node + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus * 1 + ) + seed_it(local_rank + hyp_param.seed) + + torch.cuda.set_device(hyp_param.local_rank) + + import model.visual.sam2 # noqa: F401 — registers Hydra `configs` (initialize_config_module) + + from hydra import compose + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml) + transform_config_path = 'training/sam2_training_config.yaml' + + if 'hiera_t' in hyp_param.sam_config_path: + hyp_param.image_size = 224 + hyp_param.image_embedding_size = int(hyp_param.image_size / 16) + print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True) + + cfg = compose(config_name=transform_config_path) + OmegaConf.resolve(cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True) + + arch_h = compose(config_name='auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank) + + av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank], + find_unused_parameters=True) + + # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone) + from utils.utils import manipulate_params + parameter_list = manipulate_params(hyp_param, av_model.module.aural_fuser) + optimiser = torch.optim.AdamW(parameter_list, betas=(0.9, 0.999)) + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data.distributed import DistributedSampler + + compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0] + + audio_augmentation = AudioAugmentation(mono=True) + train_dataset = AV(split='train', augmentation={"visual": compose_api, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path) + + + visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index) + + audio_augmentation = AudioAugmentation(mono=True) + + random_sampler = DistributedSampler(train_dataset, shuffle=True) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hyp_param.batch_size, + sampler=random_sampler, + num_workers=hyp_param.num_workers, drop_last=True) + + test_dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation}, + param=hyp_param, root_path=hyp_param.data_root_path) + + order_sampler = DistributedSampler(test_dataset, shuffle=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=order_sampler, + num_workers=hyp_param.num_workers) + + + criterion = instantiate(cfg.loss, _recursive_=True)['all'] + from utils.tensorboard import Tensorboard + tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + metrics = {"foreground_iou": ForegroundIoU(), "foreground_f-score": ForegroundFScore(0 if hyp_param.local_rank <= 0 else hyp_param.local_rank)} + + trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics) + + + curr_best = 0. # checkpoint when IoU (iou_select mode) improves + + for epoch in range(hyp_param.epochs): + av_model.train() + av_model.module.freeze_sam_parameters() + random_sampler.set_epoch(epoch) + trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser) + + torch.distributed.barrier() + torch.cuda.empty_cache() + + av_model.eval() + # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate + curr_results1, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index') + curr_results, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select') + curr_results3, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select') + if hyp_param.local_rank <= 0 and curr_results > curr_best: + curr_best = curr_results + torch.save(av_model.module.aural_fuser.state_dict(), os.path.join(hyp_param.saved_dir, str(curr_results) + ".pth")) + torch.distributed.barrier() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Training') + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + + parser.add_argument("--local_rank", type=int, default=-1, + help='multi-process training for DDP') + + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + + parser.add_argument('--batch_size', default=1, type=int) + + parser.add_argument('--epochs', default=80, type=int, + help="total epochs that used for the training") + + parser.add_argument('--lr', default=1e-4, type=float, + help='Default HEAD Learning rate is same as others, ' + '*Note: in ddp training, lr will automatically times by n_gpu') + + parser.add_argument('--online', action="store_true", + help='switch on for visualization; switch off for debug') + + args = parser.parse_args() + + from configs.config import C + + args = EasyDict({**C, **vars(args)}) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9902' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/avs.code/v2.code/model/audio/torchvggish/mel_features.py b/avs.code/v2.code/model/audio/torchvggish/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/avs.code/v2.code/model/audio/torchvggish/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/avs.code/v2.code/model/audio/torchvggish/vggish.py b/avs.code/v2.code/model/audio/torchvggish/vggish.py new file mode 100644 index 0000000000000000000000000000000000000000..f01c22867c713bfd8713eee5665120b92602761d --- /dev/null +++ b/avs.code/v2.code/model/audio/torchvggish/vggish.py @@ -0,0 +1,193 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import hub + +from . import vggish_input, vggish_params + + +class VGG(nn.Module): + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.embeddings = nn.Sequential( + nn.Linear(512 * 4 * 6, 4096), + nn.ReLU(True), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Linear(4096, 128), + nn.ReLU(True)) + + def forward(self, x): + x = self.features(x) + + # Transpose the output from features to + # remain compatible with vggish embeddings + x = torch.transpose(x, 1, 3) + x = torch.transpose(x, 1, 2) + x = x.contiguous() + x = x.view(x.size(0), -1) + + return self.embeddings(x) + + +class Postprocessor(nn.Module): + """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a + numpy array in order to preserve the gradient. + + "The initial release of AudioSet included 128-D VGGish embeddings for each + segment of AudioSet. These released embeddings were produced by applying + a PCA transformation (technically, a whitening transform is included as well) + and 8-bit quantization to the raw embedding output from VGGish, in order to + stay compatible with the YouTube-8M project which provides visual embeddings + in the same format for a large set of YouTube videos. This class implements + the same PCA (with whitening) and quantization transformations." + """ + + def __init__(self): + """Constructs a postprocessor.""" + super(Postprocessor, self).__init__() + # Create empty matrix, for user's state_dict to load + self.pca_eigen_vectors = torch.empty( + (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), + dtype=torch.float, + ) + self.pca_means = torch.empty( + (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float + ) + + self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) + self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) + + def postprocess(self, embeddings_batch): + """Applies tensor postprocessing to a batch of embeddings. + + Args: + embeddings_batch: An tensor of shape [batch_size, embedding_size] + containing output from the embedding layer of VGGish. + + Returns: + A tensor of the same shape as the input, containing the PCA-transformed, + quantized, and clipped version of the input. + """ + assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( + embeddings_batch.shape, + ) + assert ( + embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE + ), "Bad batch shape: %r" % (embeddings_batch.shape,) + + # Apply PCA. + # - Embeddings come in as [batch_size, embedding_size]. + # - Transpose to [embedding_size, batch_size]. + # - Subtract pca_means column vector from each column. + # - Premultiply by PCA matrix of shape [output_dims, input_dims] + # where both are are equal to embedding_size in our case. + # - Transpose result back to [batch_size, embedding_size]. + pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() + + # Quantize by: + # - clipping to [min, max] range + clipped_embeddings = torch.clamp( + pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL + ) + # - convert to 8-bit in range [0.0, 255.0] + quantized_embeddings = torch.round( + (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) + * ( + 255.0 + / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) + ) + ) + return torch.squeeze(quantized_embeddings) + + def forward(self, x): + return self.postprocess(x) + + +def make_layers(): + layers = [] + in_channels = 1 + for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +def _vgg(): + return VGG(make_layers()) + + +# def _spectrogram(): +# config = dict( +# sr=16000, +# n_fft=400, +# n_mels=64, +# hop_length=160, +# window="hann", +# center=False, +# pad_mode="reflect", +# htk=True, +# fmin=125, +# fmax=7500, +# output_format='Magnitude', +# # device=device, +# ) +# return Spectrogram.MelSpectrogram(**config) + + +class VGGish(VGG): + def __init__(self, cfg, device=None): + super().__init__(make_layers()) + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH) + super().load_state_dict(state_dict) + print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}') + + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("device: ", device) + self.device = device + + self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL + self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA + if self.postprocess: + self.pproc = Postprocessor() + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH) + # TODO: Convert the state_dict to torch + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float + ) + state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float + ) + self.pproc.load_state_dict(state_dict) + self.to(self.device) + + def forward(self, x): + if self.preprocess: + print(">>> pre processing...") + x = self._preprocess(x) + x = x.to(self.device) + x = VGG.forward(self, x) + if self.postprocess: + print(">>> post processing...") + x = self._postprocess(x) + return x + + def _preprocess(self, x): + # if isinstance(x, np.ndarray): + # x = vggish_input.waveform_to_examples(x, fs) + if isinstance(x, str): + x = vggish_input.wavfile_to_examples(x) + else: + raise AttributeError + return x + + def _postprocess(self, x): + return self.pproc(x) diff --git a/avs.code/v2.code/model/audio/torchvggish/vggish_input.py b/avs.code/v2.code/model/audio/torchvggish/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..ede228b1fb630180f1f49244355d373fb3300f03 --- /dev/null +++ b/avs.code/v2.code/model/audio/torchvggish/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from . import mel_features +from . import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/avs.code/v2.code/model/audio/torchvggish/vggish_params.py b/avs.code/v2.code/model/audio/torchvggish/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/avs.code/v2.code/model/audio/torchvggish/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/avs.code/v2.code/model/aural_fuser.py b/avs.code/v2.code/model/aural_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..924810bfcf8bee5e285cab7d54e477daf254b85a --- /dev/null +++ b/avs.code/v2.code/model/aural_fuser.py @@ -0,0 +1,567 @@ +import math + +import torch +import torch.nn as nn +from model.audio.torchvggish import vggish +from timm.models.layers import DropPath, trunc_normal_ + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d): + super().__init__() + self.proj = nn.Sequential( + conv_layer(dim_in, proj_dim, kernel_size=1), + norm_act(proj_dim), + conv_layer(proj_dim, proj_dim, kernel_size=1), + ) + + def forward(self, x): + return torch.nn.functional.normalize(self.proj(x), p=2, dim=1) + +class AuralFuser(torch.nn.Module): + """Fuses VGGish audio with SAM2 FPN maps via patch embeds, fusion blocks, and projection heads.""" + + def __init__(self, hyp_param): + self.hyp_param = hyp_param + super().__init__() + self.vgg = vggish.VGGish(self.hyp_param.audio) + if not getattr(self.hyp_param, "train_vggish", False): + for p in self.vgg.parameters(): + p.requires_grad = False + + self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, + temperature=10000) + + # Populated in main.py / inference.py via Hydra compose('auralfuser/architecture.yaml') → hyp_param.aural_fuser + if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None: + raise ValueError( + "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser." + ) + arch_cfg = self.hyp_param.aural_fuser + + _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]] + _f_depths = arch_cfg["f_depths"] + _block_kw = dict(arch_cfg["block_kw"]) + _block_kw["norm_layer"] = nn.LayerNorm + _one_d_kw = dict(arch_cfg["one_d_kw"]) + _one_d_kw["norm_layer"] = nn.LayerNorm + self.patch_embeds = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs + ) + + self.f_blocks = nn.ModuleList( + nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths + ) + + self.a_blocks = nn.ModuleList( + nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3) + ) + + self.fusion_modules = nn.ModuleList( + AudioVisualFusionModule(in_channels=256, mode='dot') for _ in range(3) + ) + self.smooth_convs = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2) + ) + + self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128) + + self.train_proj_a1 = ProjectionHead(dim_in=256, norm_act=nn.BatchNorm1d, conv_layer=nn.Conv1d, proj_dim=128) + + @staticmethod + def positionalencoding1d(d_model, length): + if d_model % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model)) + pe = torch.zeros(length, d_model) + position = torch.arange(0, length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * + -(math.log(10000.0) / d_model))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe + + def forward(self, feature_dicts, spect=None): + image_embed_shape = [self.hyp_param.image_embedding_size] * 2 + H, W = image_embed_shape[0], image_embed_shape[1] + d = torch.cat( + [ + self.vgg(spect[:, 0, ...].unsqueeze(1)), + self.vgg(spect[:, 1, ...].unsqueeze(1)), + ], + dim=-1, + ) + length = d.shape[-1] + fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device) + fpn = list(feature_dicts["backbone_fpn"]) + patch_embeds = list(self.patch_embeds) + f_blocks = list(self.f_blocks) + a_blocks = list(self.a_blocks) + tpavi = list(self.fusion_modules) + smooths = [None, self.smooth_convs[0], self.smooth_convs[1]] + + feats = [None, None, None] + d_outputs = [] + + for i in range(3): + x = fpn[i] + x = patch_embeds[i](x) + x_pos = self.position_encoding_func(x) + x = x.flatten(2).permute(0, 2, 1) + x_pos = x_pos.flatten(2).permute(0, 2, 1) + + if i == 0: + x = x + x_pos + d = d + fix_audio_pos + else: + x = x + feats[i - 1] + x = smooths[i]( + x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W) + ).flatten(2).permute(0, 2, 1) + x = x + x_pos + d = d + fix_audio_pos + + for blks in f_blocks[i]: + x = blks(x, H, W, x_pos) + for blks in a_blocks[i]: + d = blks(d, fix_audio_pos) + + x = x + x_pos + d = d + fix_audio_pos + x, d_out, _, _ = tpavi[i](x, H, W, x_pos, d, length) + d = d_out + feats[i] = x + d_outputs.append(d_out) + + a, b, c = feats + d1, d2, d3 = d_outputs + + feature_residual = [a, b, c] + audio_out = [d1, d2, d3] + + proj_feature_out = [ + [ + self.train_proj_v1(a.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(b.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + self.train_proj_v1(c.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)), + ], + [ + self.train_proj_a1(d1.unsqueeze(-1)), + self.train_proj_a1(d2.unsqueeze(-1)), + self.train_proj_a1(d3.unsqueeze(-1)), + ], + ] + + return feature_residual, audio_out, proj_feature_out + + +class AudioVisualFusionModule(nn.Module): + def __init__(self, in_channels, inter_channels=None, mode='dot', + dimension=3): + super().__init__() + assert mode == 'dot' + self.mode = mode + self.dimension = dimension + + self.in_channels = in_channels + self.inter_channels = in_channels // 2 + + self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1) + self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1) + + self.norm_layer = nn.LayerNorm(in_channels) + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.W_z = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + bn(self.in_channels) + ) + nn.init.constant_(self.W_z[1].weight, 0) + nn.init.constant_(self.W_z[1].bias, 0) + + self.W_z2 = nn.Sequential( + nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + nn.BatchNorm1d(self.in_channels) + ) + nn.init.constant_(self.W_z2[1].weight, 0) + nn.init.constant_(self.W_z2[1].bias, 0) + self.norm_layer2 = nn.LayerNorm(self.in_channels) + + self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + + def forward(self, frame, H_x, W_x, tmp1, audio, tmp2): + frame = frame.permute(0, 2, 1) + frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x) + frame = frame.unsqueeze(2) + audio = self.align_channel(audio.unsqueeze(-1)) + + batch_size, _ = frame.size(0), frame.size(1) + q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels) + k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels) + v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels) + q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels) + k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels) + v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels) + f = torch.matmul(q_frame, k_audio.mT) + f_normalise = f / f.size(1) + + frame_attn = torch.matmul(f_normalise, v_audio) + + frame_attn = frame_attn.permute(0, 2, 1).contiguous() + frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:]) + frame_attn = self.W_z(frame_attn) + frame = frame_attn + frame + + frame = frame.permute(0, 2, 3, 4, 1) + frame = self.norm_layer(frame) + frame = frame.permute(0, 4, 1, 2, 3) + frame = frame.squeeze().flatten(start_dim=2).permute(0, 2, 1) + + a = torch.matmul(q_audio, k_frame.mT) + a_normalise = a / a.size(-1) + + audio_attn = torch.matmul(a_normalise, v_frame) + audio_attn = audio_attn.permute(0, 2, 1).contiguous() + + audio_attn = audio_attn.view(batch_size, self.inter_channels).unsqueeze(-1) + audio_attn = self.W_z2(audio_attn) + + audio = audio_attn + audio + + audio = self.norm_layer2(audio.squeeze()).squeeze() + + return frame, audio, frame_attn, audio_attn + + +class OneDBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OneDAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = OneDMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, _pos): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class OneDAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = x.unsqueeze(0) + + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.squeeze() + return x + + +class OneDMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, _pos): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + if not self.linear: + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x diff --git a/avs.code/v2.code/model/mymodel.py b/avs.code/v2.code/model/mymodel.py new file mode 100644 index 0000000000000000000000000000000000000000..35194cd584a4786f713447829592b15c7a366095 --- /dev/null +++ b/avs.code/v2.code/model/mymodel.py @@ -0,0 +1,102 @@ +import logging + +from typing import List, Optional, Tuple, Union + +import numpy +import numpy as np +import torch +from PIL.Image import Image + +from model.visual.sam2.modeling.sam2_base import SAM2Base + +from model.visual.sam2.modeling.backbones.hieradet import Hiera +from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck +from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + +from model.visual.sam2.modeling.memory_attention import MemoryAttention +from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer +from model.visual.sam2.modeling.sam.transformer import RoPEAttention +from model.visual.sam2.modeling.memory_encoder import MemoryEncoder +from model.visual.sam2.modeling.memory_encoder import MaskDownSampler +from model.visual.sam2.modeling.memory_encoder import Fuser +from model.visual.sam2.modeling.memory_encoder import CXBlock + +from model.visual.sam2.utils.transforms import SAM2Transforms +from model.visual.sam2.modeling.backbones.hieradet import do_pool +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + + +class AVmodel(torch.nn.Module): + """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head.""" + + def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, ): + super().__init__() + self.param = param + self.mask_threshold = mask_threshold + self._bb_feat_sizes = [(int(self.param.image_size / 4), int(self.param.image_size / 4)), + (int(self.param.image_size / 8), int(self.param.image_size / 8)), + (int(self.param.image_size / 16), int(self.param.image_size / 16))] + + from model.visual.sam2.build_sam import build_sam2_visual_predictor + self.v_model = build_sam2_visual_predictor(self.param.sam_config_path, self.param.backbone_weight, + apply_postprocessing=True, mode='train') + self._transforms = SAM2Transforms( + resolution=self.v_model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + from model.aural_fuser import AuralFuser + self.aural_fuser = AuralFuser(hyp_param=self.param) + + + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.v_model.num_feature_levels:] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.v_model.num_feature_levels:] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def forward_frame(self, frame_): + frame = torch.nn.functional.interpolate(frame_, (self.param.image_size, self.param.image_size), + antialias=True, align_corners=False, mode='bilinear') + return self.v_model.image_encoder(frame) + + def forward(self, frames, spect, prompts, sam_process=False): + """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path.""" + backbone_feats = self.v_model.forward_image(frames, pre_compute=False) + audio_residual_feats = self.aural_fuser(backbone_feats, spect) + visual_resfeats, audio_resfeats, proj_feats = audio_residual_feats + + map_res = visual_resfeats[::-1] + vec_res = audio_resfeats[::-1] + + av_feats = (map_res, vec_res) + backbone_feats = self.v_model.precompute_high_res_features(backbone_feats) + backbone_feats = self.v_model.dont_prepare_prompt_inputs(backbone_feats, num_frames=frames.shape[0], + cond_frame=int(frames.shape[0]/2) if self.training else 0) + outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats) + return outputs, proj_feats + + @property + def device(self) -> torch.device: + return self.v_model.device + + def freeze_sam_parameters(self): + self.v_model.eval() + for name, parameter in self.v_model.named_parameters(): + parameter.requires_grad = False diff --git a/avs.code/v2.code/model/visual/sam2/__init__.py b/avs.code/v2.code/model/visual/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46a1cecc55b6fd02a5ce6c66d9cc8a77343156db --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("configs", version_base="1.2") diff --git a/avs.code/v2.code/model/visual/sam2/build_sam.py b/avs.code/v2.code/model/visual/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..69f68c2e672d35d925aeb496cac918c1ee913dde --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/build_sam.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf +''' +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) +''' + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_visual_predictor( + config_file, + ckpt_path=None, + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + # visual + hydra_overrides = [] + # "++model._target_=model.visual.sam2.organised_sam2_train.SAM2Train", + # ] + # hydra_overrides = [ + # "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + # ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + + # dynamically fall back to multi-mask if the single mask is not stable + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + # "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +# def build_sam2_video_predictor_hf(model_id, **kwargs): +# config_name, ckpt_path = _hf_download(model_id) +# return build_sam2_video_predictor( +# config_file=config_name, ckpt_path=ckpt_path, **kwargs +# ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/avs.code/v2.code/model/visual/sam2/modeling/__init__.py b/avs.code/v2.code/model/visual/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v2.code/model/visual/sam2/modeling/backbones/__init__.py b/avs.code/v2.code/model/visual/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v2.code/model/visual/sam2/modeling/backbones/hieradet.py b/avs.code/v2.code/model/visual/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb6633c9c752cbefe2fc6043c81fb79bc659465 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from model.visual.sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/avs.code/v2.code/model/visual/sam2/modeling/backbones/image_encoder.py b/avs.code/v2.code/model/visual/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/avs.code/v2.code/model/visual/sam2/modeling/backbones/utils.py b/avs.code/v2.code/model/visual/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/avs.code/v2.code/model/visual/sam2/modeling/memory_attention.py b/avs.code/v2.code/model/visual/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..11f4ccb1904f022c18f8a02b9590a66bd57bb8f1 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from model.visual.sam2.modeling.sam.transformer import RoPEAttention + +from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/avs.code/v2.code/model/visual/sam2/modeling/memory_encoder.py b/avs.code/v2.code/model/visual/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e1143cc0d5774ff96108203e404f678f14b0a23 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/avs.code/v2.code/model/visual/sam2/modeling/position_encoding.py b/avs.code/v2.code/model/visual/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..52ac22674d5d4fdd9e83b6bdf034bff56d04bc0d --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam/__init__.py b/avs.code/v2.code/model/visual/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam/mask_decoder.py b/avs.code/v2.code/model/visual/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..078f21cc2ec41805eebec677e6e27771335deaa4 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + audio_res_features_=audio_res_features + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features_: Optional[List[torch.Tensor]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens, audio_res_features_) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam/prompt_encoder.py b/avs.code/v2.code/model/visual/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..038cebcc072ae7c0f3f83061061be3edba04d0f8 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + # we only utilise sounding as prompt. + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + return sparse_embeddings, dense_embeddings + diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam/transformer.py b/avs.code/v2.code/model/visual/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..31916550afeccb66f4427cee7ec4a7a2d66913a5 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam/transformer.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from model.visual.sam2.modeling.sam2_utils import MLP +from model.visual.sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + audio_res: [], + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + visual_res, audio_res = audio_res + + # Prepare queries + queries = point_embedding + keys = image_embedding + # Apply transformer blocks and final layernorm + for i, layer in enumerate(self.layers): + keys = keys + visual_res[i] + queries[:, 2:6] = queries[:, 2:6] + audio_res[i] + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + queries[:, 2:6] = queries[:, 2:6] + audio_res[-1] + keys = keys + visual_res[-1] + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam2_base.py b/avs.code/v2.code/model/visual/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab890394064172b8719e8a06ee0a47d995fd585 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam2_base.py @@ -0,0 +1,940 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder +from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder +from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer +from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + + #### this is for Version 2.0 + # self.hidden_dim = memory_attention.d_model + #### this is for Version 2.1 + # self.hidden_dim = image_encoder.neck.d_model + self.hidden_dim = 256 # well, it is always 256 anyway. + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + ### we fix the use_mask_input_as_output_without_sam to be turned off. + self.use_mask_input_as_output_without_sam = False + + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + audio_res=None + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + ''' + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + raise NotImplementedError + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + raise NotImplementedError + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ''' + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=None, + boxes=None, + masks=None, + ) + + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + audio_res_features=audio_res + ) + ''' + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + ''' + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # comment this line temporarily. + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + + # don't train occlusion at the moment, command temporarily. + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def precompute_high_res_features(self, backbone_out): + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def forward_image(self, img_batch: torch.Tensor, pre_compute=True): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + return backbone_out if not pre_compute else self.precompute_high_res_features(backbone_out) + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # for t_pos in range(1, min(self.num_maskmem, frame_idx)): + # out = output_dict["non_cond_frame_outputs"].get(t_pos, None) + # t_pos_and_prevs.append((t_pos, out)) + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + # default false. + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + # the Following lines will never be triggered. + raise NotImplementedError + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + raise NotImplementedError + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + # it will be used in sam2.1 + # raise NotImplementedError + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/avs.code/v2.code/model/visual/sam2/modeling/sam2_utils.py b/avs.code/v2.code/model/visual/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19133558dd657bbcf67f851011d45bd4999cab0a --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/avs.code/v2.code/model/visual/sam2/organised_sam2_train.py b/avs.code/v2.code/model/visual/sam2/organised_sam2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..607c3ad22ba7dcb7eb74c30e1283f68c4808450e --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/organised_sam2_train.py @@ -0,0 +1,811 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from model.visual.sam2.modeling.sam2_base import SAM2Base +from model.visual.sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from utils.misc import concat_points + +from utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + ''' + vision_feats + torch.Size([65536, 5, 32]) + torch.Size([16384, 5, 64]) + torch.Size([4096, 5, 256]) + ''' + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + @staticmethod + def dont_prepare_prompt_inputs(backbone_out, num_frames=5, cond_frame=0): + backbone_out["gt_masks_per_frame"] = {} + backbone_out["num_frames"] = num_frames + backbone_out["use_pt_input"] = False + # always start from the first frame. + backbone_out["init_cond_frames"] = [cond_frame] + backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != cond_frame] + # backbone_out["init_cond_frames"] = [] + # backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)] + + backbone_out["mask_inputs_per_frame"] = {} + backbone_out["point_inputs_per_frame"] = {} + backbone_out["frames_to_add_correction_pt"] = [] + return backbone_out + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False): + # img_feats_already_computed = True. + """Forward video tracking on each frame (and sample correction clicks).""" + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + + av_v_feats, av_a_feats = audio_res + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = stage_id + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats] # add unsqueeze to maintain single sample. + current_vision_pos_embeds = [x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds] # add unsqueeze to maintain single sample. + current_av_v_feats = [x[img_ids] for x in av_v_feats] + current_av_a_feats = [x[img_ids] for x in av_a_feats] + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step_wo_prompt( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, # backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=None, # backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=None, # backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=None, # frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + audio_res=(current_av_v_feats, current_av_a_feats), + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + + return all_frame_outputs + + def track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + audio_res=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + + current_out, sam_outputs, high_res_features, pix_feat = self._track_step_wo_prompt( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + ''' + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + ''' + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + 666., # point_inputs, + run_mem_encoder, + # we follow SAM2 predictor, if we have multiple masks output, we only utilise the first one to perform + # the memory rope attention. + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res=None + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # False + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # current_vision_feats[-1] = current_vision_feats[-1] + self.no_mem_embed + # pix_feat = current_vision_feats[-1].permute(1, 2, 0) + # pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + + # we do not apply any prompts except audio. + ''' + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + # if prev_sam_mask_logits is not None: + # assert point_inputs is not None and mask_inputs is None + # mask_inputs = prev_sam_mask_logits + + ## comment this line, as we don't use points as prompts. + # multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + ''' + + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=True, + audio_res=audio_res + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/avs.code/v2.code/model/visual/sam2/utils/__init__.py b/avs.code/v2.code/model/visual/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/avs.code/v2.code/model/visual/sam2/utils/misc.py b/avs.code/v2.code/model/visual/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/avs.code/v2.code/model/visual/sam2/utils/transforms.py b/avs.code/v2.code/model/visual/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4fa6a3e4d2e2a0dde7f87e4991daff338467c4 --- /dev/null +++ b/avs.code/v2.code/model/visual/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from model.visual.sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/avs.code/v2.code/tools/build_avsbench_v2_merge_subset.py b/avs.code/v2.code/tools/build_avsbench_v2_merge_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..6a172492ead739c736ea5acd4d8df7df68bfa928 --- /dev/null +++ b/avs.code/v2.code/tools/build_avsbench_v2_merge_subset.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Build merge-debug tree under AVSBench/v2/ (name = v2): + + AVSBench/v2/avss_index/metadata.csv + AVSBench/v2/v1s/<20 uids>/ + AVSBench/v2/v1m/<20 uids>/ + AVSBench/v2/v2/<20 uids>/ # v2-protocol clips from ~/Downloads/v2.zip + +Each modality: 16 train + 4 test rows (20 clips). Full AVSBench is used as source for v1s/v1m copies. +""" +from __future__ import annotations + +import csv +import os +import shutil +import zipfile + +import pandas as pd + +_WORKSPACE = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +AVSBENCH = os.path.join(_WORKSPACE, "AVSBench") +SUBSET_ROOT = os.path.join(AVSBENCH, "v2") # data_root for merge run +FULL_META = os.path.join(AVSBENCH, "avss_index", "metadata.csv") +ZIP_PATH = os.path.expanduser("~/Downloads/v2.zip") +PER_LABEL = 20 +TRAIN_N, TEST_N = 16, 4 + + +def _pick_rows_from_full_metadata(label: str) -> list[dict]: + df = pd.read_csv(FULL_META) + picked: list[dict] = [] + for split, need in ("train", TRAIN_N), ("test", TEST_N): + got = 0 + sub = df[(df["split"] == split) & (df["label"] == label)] + for _, row in sub.iterrows(): + uid = str(row["uid"]) + src = os.path.join(AVSBENCH, label, uid) + if not os.path.isdir(src): + continue + picked.append({k: row[k] for k in row.index}) + got += 1 + if got >= need: + break + if got < need: + raise SystemExit( + f"not enough existing {label}/{split} under {AVSBENCH}: need {need}, got {got}" + ) + return picked + + +def _list_v2_uids(z: zipfile.ZipFile) -> list[str]: + uids: set[str] = set() + for name in z.namelist(): + if not name.startswith("v2/") or name.endswith("/"): + continue + parts = name.split("/") + if len(parts) >= 3 and parts[1]: + uids.add(parts[1]) + return sorted(uids) + + +def _extract_v2_from_zip(uids: list[str]) -> None: + allowed = set(uids) + with zipfile.ZipFile(ZIP_PATH, "r") as z: + for info in z.infolist(): + n = info.filename + if not n.startswith("v2/"): + continue + parts = n.split("/") + if len(parts) < 3 or parts[1] not in allowed: + continue + if "/labels_semantic/" in n: + continue + if "/frames/" in n and n.endswith(".jpg"): + pass + elif "/labels_rgb/" in n and n.endswith(".png"): + pass + elif n.endswith("/audio.wav"): + pass + else: + continue + dest = os.path.join(SUBSET_ROOT, "v2", parts[1], *parts[2:]) + os.makedirs(os.path.dirname(dest), exist_ok=True) + with z.open(info, "r") as src, open(dest, "wb") as out: + shutil.copyfileobj(src, out) + + +def _copy_clip(label: str, uid: str) -> None: + src = os.path.join(AVSBENCH, label, uid) + dst = os.path.join(SUBSET_ROOT, label, uid) + if os.path.isdir(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + + +def main() -> None: + if not os.path.isfile(FULL_META): + raise SystemExit(f"missing full metadata: {FULL_META}") + if not os.path.isfile(ZIP_PATH): + raise SystemExit(f"missing v2 zip: {ZIP_PATH}") + + shutil.rmtree(SUBSET_ROOT, ignore_errors=True) + os.makedirs(os.path.join(SUBSET_ROOT, "avss_index"), exist_ok=True) + for sub in ("v1s", "v1m", "v2"): + os.makedirs(os.path.join(SUBSET_ROOT, sub), exist_ok=True) + + all_rows: list[dict] = [] + + for label in ("v1s", "v1m"): + rows = _pick_rows_from_full_metadata(label) + assert len(rows) == PER_LABEL + for r in rows: + _copy_clip(label, str(r["uid"])) + all_rows.extend(rows) + + with zipfile.ZipFile(ZIP_PATH, "r") as z: + uids_all = _list_v2_uids(z) + if len(uids_all) < PER_LABEL: + raise SystemExit(f"v2.zip has only {len(uids_all)} clips") + v2_uids = uids_all[:PER_LABEL] + _extract_v2_from_zip(v2_uids) + + for i, uid in enumerate(v2_uids): + split = "train" if i < TRAIN_N else "test" + vid = uid.rsplit("_", 2)[0] if uid.count("_") >= 2 else uid + all_rows.append( + { + "vid": vid, + "uid": uid, + "s_min": 0, + "s_sec": 0, + "a_obj": "v2subset", + "split": split, + "label": "v2", + } + ) + + meta_out = os.path.join(SUBSET_ROOT, "avss_index", "metadata.csv") + with open(meta_out, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=["vid", "uid", "s_min", "s_sec", "a_obj", "split", "label"]) + w.writeheader() + for r in all_rows: + w.writerow({k: r.get(k, "") for k in w.fieldnames}) + + print("subset data_root:", SUBSET_ROOT) + print("metadata:", meta_out) + print("v2 uids:", v2_uids) + print("rows:", len(all_rows)) + + +if __name__ == "__main__": + main() diff --git a/avs.code/v2.code/tools/mini_debug_train.py b/avs.code/v2.code/tools/mini_debug_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a73f1bdf645441d6c4fa4c8d906b59287016e0be --- /dev/null +++ b/avs.code/v2.code/tools/mini_debug_train.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""DDP smoke test: 1 epoch on AVSBench/v2 merge subset (20+20+20 clips). + +Build first:: + + cd /path/to/v2.code && python3 tools/build_avsbench_v2_merge_subset.py + +Then:: + + cd /path/to/v2.code && python3 tools/mini_debug_train.py +""" +from __future__ import annotations + +import os +import sys + +# Avoid MKL + libgomp conflict on some conda stacks before numpy/torch import. +os.environ.setdefault("MKL_THREADING_LAYER", "GNU") +import numpy # noqa: F401, E402 + +_REPO = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +os.chdir(_REPO) +sys.path.insert(0, _REPO) +_WORKSPACE = os.path.dirname(_REPO) +_MERGE_DATA = os.path.join(_WORKSPACE, "AVSBench", "v2") + + +def _patch_config() -> None: + import configs.config as cfg # noqa: E402 + + cfg.C.data_root_path = _MERGE_DATA + cfg.C.saved_dir = os.path.join("/tmp", "v2_mini_debug_ckpt") + os.makedirs(cfg.C.saved_dir, exist_ok=True) + cfg.C.epochs = 1 + cfg.C.batch_size = 1 + cfg.C.num_workers = 0 + cfg.C.wandb_online = False + cfg.C.gpus = 1 + + +if __name__ == "__main__": + if not os.path.isdir(_MERGE_DATA): + raise SystemExit( + f"missing {_MERGE_DATA} — run: python3 {_REPO}/tools/build_avsbench_v2_merge_subset.py" + ) + if not os.path.isfile(os.path.join(_MERGE_DATA, "avss_index", "metadata.csv")): + raise SystemExit(f"missing metadata.csv under {_MERGE_DATA}") + + _patch_config() + + import torch # noqa: E402 + from easydict import EasyDict # noqa: E402 + + from configs.config import C # noqa: E402 + + hyp = EasyDict(dict(C)) + hyp.gpus = 1 + hyp.batch_size = 1 + hyp.epochs = 1 + hyp.num_workers = 0 + hyp.wandb_online = False + hyp.data_root_path = _MERGE_DATA + hyp.saved_dir = os.path.join("/tmp", "v2_mini_debug_ckpt") + os.makedirs(hyp.saved_dir, exist_ok=True) + + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "9912") + + from main import main as train_main # noqa: E402 + + torch.multiprocessing.spawn(train_main, nprocs=hyp.gpus, args=(hyp.gpus, hyp)) diff --git a/avs.code/v2.code/tools/remap_aural_ckpt_keys.py b/avs.code/v2.code/tools/remap_aural_ckpt_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb8d6086a854b1b0ab011542eafd99f4bf8a3bf --- /dev/null +++ b/avs.code/v2.code/tools/remap_aural_ckpt_keys.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Remap legacy checkpoint keys: rename audio_prompter.* to the current AuralFuser layout (aural_fuser.*), +and drop duplicate weights under training_layers / finetuning_layers. + +Usage: + python tools/remap_aural_ckpt_keys.py /path/to/model.pth [--in-place] [--no-backup] + +By default writes _remapped.pth; --in-place overwrites the input (after a .bak backup unless --no-backup). +""" +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import torch + +# Matches AuralFuser ModuleList names (old train_* indices start at 1; new indices are 0-based). +_REPLACEMENTS: list[tuple[str, str]] = [ + ("train_f_patch_embed1", "patch_embeds.0"), + ("train_f_patch_embed2", "patch_embeds.1"), + ("train_f_patch_embed3", "patch_embeds.2"), + ("train_f_a_block1", "fusion_modules.0"), + ("train_f_a_block2", "fusion_modules.1"), + ("train_f_a_block3", "fusion_modules.2"), + ("train_f_block1", "f_blocks.0"), + ("train_f_block2", "f_blocks.1"), + ("train_f_block3", "f_blocks.2"), + ("train_a_block1", "a_blocks.0"), + ("train_a_block2", "a_blocks.1"), + ("train_a_block3", "a_blocks.2"), + ("train_smooth1", "smooth_convs.0"), + ("train_smooth2", "smooth_convs.1"), +] + + +def remap_state_dict(sd: dict) -> dict: + out: dict = {} + dropped = 0 + for k, v in sd.items(): + if k.startswith("audio_prompter."): + if ".training_layers." in k or ".finetuning_layers." in k: + dropped += 1 + continue + nk = k.replace("audio_prompter.", "aural_fuser.", 1) + for old, new in _REPLACEMENTS: + nk = nk.replace(old, new) + out[nk] = v + else: + out[k] = v + if dropped: + print(f"Dropped duplicate keys: {dropped} (training_layers / finetuning_layers)") + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("ckpt", type=Path, help="Input .pth (full-model state_dict)") + ap.add_argument( + "-o", "--output", type=Path, default=None, + help="Output path; default _remapped.pth", + ) + ap.add_argument("--in-place", action="store_true", help="Overwrite input file") + ap.add_argument("--no-backup", action="store_true", help="Skip .bak when using --in-place") + args = ap.parse_args() + + ckpt_path: Path = args.ckpt.resolve() + if not ckpt_path.is_file(): + raise SystemExit(f"File not found: {ckpt_path}") + + print(f"Loading: {ckpt_path}") + sd = torch.load(ckpt_path, map_location="cpu") + if not isinstance(sd, dict): + raise SystemExit("Expected top-level checkpoint to be a state_dict dict") + + n_old_ap = sum(1 for k in sd if k.startswith("audio_prompter.")) + if n_old_ap == 0: + print("Warning: no audio_prompter.* keys found; checkpoint may already be remapped.") + + new_sd = remap_state_dict(sd) + n_af = sum(1 for k in new_sd if k.startswith("aural_fuser.")) + print(f"aural_fuser key count: {n_af}") + + if args.in_place: + out = ckpt_path + if not args.no_backup: + bak = ckpt_path.with_suffix(ckpt_path.suffix + ".bak") + print(f"Backup -> {bak}") + shutil.copy2(ckpt_path, bak) + else: + out = args.output or ckpt_path.with_name(ckpt_path.stem + "_remapped.pth") + + torch.save(new_sd, out) + print(f"Saved: {out} ({len(new_sd)} tensor keys)") + + +if __name__ == "__main__": + main() diff --git a/avs.code/v2.code/trainer/train.py b/avs.code/v2.code/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..71d6d16caa9350c7f114189ef1b68c8f793039d9 --- /dev/null +++ b/avs.code/v2.code/trainer/train.py @@ -0,0 +1,166 @@ +"""Training and validation loop for the AV segmentation model.""" +import numpy +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class Trainer: + """Wraps train/valid steps with optional loss, metrics, and logging.""" + + def __init__(self, hyp_param, loss, tensorboard, metrics): + self.param = hyp_param + self.loss = loss + self.tensorboard = tensorboard + self.metrics = metrics + from loss.training.contrastive_learning import ContrastLoss + self.cl = ContrastLoss(self.param) + + @torch.no_grad() + def valid(self, epoch, dataloader, model, process=''): + """Evaluate foreground IoU / F-score. `process` selects SAM multimask decoding (see branch below).""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + iou_pool = [None] * self.param.gpus + fscore_pool = [None] * self.param.gpus + + data_iter = iter(dataloader) + for batch_index in tbar: + items = next(data_iter) + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, _ = model.module(frame, spect, prompt_dicts, sam_process=True) + logits = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) + ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) + occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) + # process: '' = first multimask; iou_select = argmax IoU head; iou_occ_select = + objectness gate + if process == 'iou_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + elif process == 'iou_occ_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] + logits[occ_scores.squeeze() < 0, ...] = 0. + else: + logits = logits[:, 0, ...] + + masks = logits > 0. + foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou(masks.squeeze().long(), + label.squeeze().long(), + get_entire_list=True) + + foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score(logits.squeeze(), + label.squeeze(), + get_entire_list=True) + torch.distributed.all_gather_object(iou_pool, foreground_iou_rank) + torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank) + foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum( + [i['foreground_iou'][1] for i in iou_pool]) + foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum( + [i['foreground_f-score'][1] for i in fscore_pool]) + + if self.param.local_rank <= 0: + tbar.set_description('epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(epoch, + numpy.round( + foreground_iou.cpu().numpy(), + 5), + numpy.round( + foreground_f_score, + 5))) + torch.cuda.empty_cache() + + final_iou = foreground_iou + final_fscore = foreground_f_score + if self.param.local_rank <= 0 and self.tensorboard is not None: + self.tensorboard.upload_wandb_info({"valid.f_iou/{}".format(process): final_iou, + "valid.f_f-score/{}".format(process): final_fscore}) + + def _to_float(x): + if isinstance(x, torch.Tensor): + return float(x.detach().cpu().item()) + return float(x) + + return numpy.round(_to_float(final_iou), 5), numpy.round(_to_float(final_fscore), 5) + + def train(self, epoch, dataloader, model, optimiser): + """One epoch: SAM frozen, AuralFuser + heads trained with composite loss + contrastive term.""" + if not isinstance(dataloader, DataLoader): + raise TypeError( + "train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." + ) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + + data_iter = iter(dataloader) + for batch_index in tbar: + current_index = dataloader_length * epoch + batch_index + items = next(data_iter) + + frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False) + + # v1s: only first frame is supervised (artifacts). Any sample in the batch may be v1s (shuffle order). + _ids = items['id'] + _id_list = _ids if isinstance(_ids, (list, tuple)) else [_ids] + if any("/v1s/" in str(x) for x in _id_list): + outputs = outputs[0:1] + label = label[0:1, ...] + vision_feats, audio_feats = proj_feats + proj_feats = ([t[0:1] for t in vision_feats], [t[0:1] for t in audio_feats]) + + loss_dict = self.loss(outputs, label.unsqueeze(1)) + cl_loss = self.cl(proj_feats, outputs, label) + + optimiser.zero_grad() + (loss_dict['core_loss'] + cl_loss).backward() + optimiser.step() + + current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9 + for params_lr in optimiser.param_groups: + names = params_lr.get("name", []) + if names and any("vgg" in n for n in names): + params_lr['lr'] = current_lr * 0.1 + else: + params_lr['lr'] = current_lr + + if self.param.local_rank <= 0: + logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs]) + foreground_iou = self.metrics['foreground_iou'].calculate_iou((logits > 0)[:, 0, ...].long(), + label.long()) + + self.tensorboard.upload_wandb_info({"loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(), + "lr": optimiser.param_groups[0]['lr'], + "loss_dice": loss_dict['loss_dice'], + "loss_focal": loss_dict['loss_mask'], + "loss_contras": cl_loss.item()}) + tbar.set_description('epoch {} | loss {}, f_iou {}'.format(epoch, loss_dict['core_loss'].item(), + foreground_iou.item())) + ''' + if batch_index % 200 == 0: + pred_mask = (logits > 0)[:, 0, ...].long() + n_vis = min(4, frame.shape[0], pred_mask.shape[0], label.shape[0]) + self.tensorboard.upload_wandb_image( + frame[:n_vis], pred_mask[:n_vis], label[:n_vis].long() + ) + ''' + return diff --git a/avs.code/v2.code/utils/data_utils.py b/avs.code/v2.code/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7a98f8ec73e6e5dafd1e395b48a98575e5afb1 --- /dev/null +++ b/avs.code/v2.code/utils/data_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage + + +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/avs.code/v2.code/utils/foreground_fscore.py b/avs.code/v2.code/utils/foreground_fscore.py new file mode 100644 index 0000000000000000000000000000000000000000..ea20b84d2304ca0bd9981fd1a3c254111e3d0ac4 --- /dev/null +++ b/avs.code/v2.code/utils/foreground_fscore.py @@ -0,0 +1,90 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundFScore(AverageMeter): + def __init__(self, rank): + self.local_rank = rank + super(ForegroundFScore, self).__init__('foreground_f-score') + + def _eval_pr(self, y_pred, y, num, cuda_flag=True): + if cuda_flag: + prec, recall = torch.zeros(num).cuda(self.local_rank), torch.zeros(num).cuda(self.local_rank) + thlist = torch.linspace(0, 1 - 1e-10, num).cuda(self.local_rank) + else: + prec, recall = torch.zeros(num), torch.zeros(num) + thlist = torch.linspace(0, 1 - 1e-10, num) + for i in range(num): + y_temp = (y_pred >= thlist[i]).float() + tp = (y_temp * y).sum() + prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) + return prec, recall + + def calculate_f_score(self, pred, gt, pr_num=255, get_entire_list=False): + + r""" + param: + pred: size [N x H x W] + gt: size [N x H x W] + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + # print('=> eval [FMeasure]..') + pred = torch.sigmoid(pred) # =======================================[important] + N = pred.size(0) + beta2 = 0.3 + avg_f, img_num = 0.0, 0 + score = torch.zeros(pr_num) + # fLog = open(os.path.join(measure_path, 'FMeasure.txt'), 'w') + # print("{} videos in this batch".format(N)) + + for img_id in range(N): + # examples with totally black GTs are out of consideration + if torch.mean(gt[img_id].float()) == 0.0: + continue + prec, recall = self._eval_pr(pred[img_id], gt[img_id], pr_num) + f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) + f_score[f_score != f_score] = 0 # for Nan + avg_f += f_score + img_num += 1 + score = avg_f / img_num + # print('score: ', score) + # fLog.close() + self.add({'foreground_f-score': score.max().item()}) + return self.get('foreground_f-score') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundFScore, self).__init__('foreground_f-score') + + diff --git a/avs.code/v2.code/utils/foreground_iou.py b/avs.code/v2.code/utils/foreground_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..e01eeb081eee8ebfa1fcb6618d05b9d57c02f817 --- /dev/null +++ b/avs.code/v2.code/utils/foreground_iou.py @@ -0,0 +1,69 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundIoU(AverageMeter): + def __init__(self): + super(ForegroundIoU, self).__init__('foreground_iou') + + def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False): + r""" + param (both hard mask): + pred: size [N x H x W], type: int + target: size [N x H x W], type: int + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.' + assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.' + + N = pred.size(0) + num_pixels = pred.size(-1) * pred.size(-2) + no_obj_flag = (target.sum(2).sum(1) == 0) + + inter = (pred * target).sum(2).sum(1) + union = torch.max(pred, target).sum(2).sum(1) + + inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) + inter[no_obj_flag] = inter_no_obj[no_obj_flag] + union[no_obj_flag] = num_pixels + + iou = torch.sum(inter / (union+eps)) / N + + self.add({'foreground_iou': iou}) + return self.get('foreground_iou') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundIoU, self).__init__('foreground_iou') + diff --git a/avs.code/v2.code/utils/iou.py b/avs.code/v2.code/utils/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..211488b780887a8efd84361bafc6b09bfad4c345 --- /dev/null +++ b/avs.code/v2.code/utils/iou.py @@ -0,0 +1,76 @@ +import torch +import numpy + + +class BinaryMIoU(object): + def __init__(self, ignore_index): + self.num_classes = 2 + self.ignore_index = ignore_index + self.inter, self.union = 0, 0 + self.correct, self.label = 0, 0 + self.iou = numpy.array([0 for _ in range(self.num_classes)]) + self.acc = 0.0 + + def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_): + # calculates the overall miou and acc + self.correct = self.correct + curr_correct_ + self.label = self.label + curr_label_ + self.inter = self.inter + curr_inter_ + self.union = self.union + curr_union_ + self.acc = 1.0 * self.correct / (numpy.spacing(1) + self.label) + self.iou = 1.0 * self.inter / (numpy.spacing(1) + self.union) + return numpy.round(self.iou, 4), numpy.round(self.acc, 4) + # if class_list is None: + # return numpy.round(self.iou.mean().item(), 4), \ + # numpy.round(self.acc, 4) + # else: + # return numpy.round(self.iou[class_list].mean().item(), 4), \ + # numpy.round(self.acc, 4) + + @staticmethod + def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_): + curr_acc = 1.0 * curr_correct_ / (numpy.spacing(1) + curr_label_) + curr_iou = 1.0 * curr_inter_ / (numpy.spacing(1) + curr_union_) + return curr_iou, curr_acc + + def __call__(self, x, y): + curr_correct, curr_label, curr_inter, curr_union = self.calculate_current_sample(x, y) + return (self.get_metric_results(curr_correct, curr_label, curr_inter, curr_union), + self.get_current_image_results(curr_correct, curr_label, curr_inter, curr_union)) + + def calculate_current_sample(self, output, target): + # output => BxCxHxW (logits) + # target => Bx1xHxW + target[target == self.ignore_index] = -1 + correct, labeled = self.batch_pix_accuracy(output.data, target) + inter, union = self.batch_intersection_union(output.data, target, self.num_classes) + return [numpy.round(correct, 5), numpy.round(labeled, 5), numpy.round(inter, 5), numpy.round(union, 5)] + + @ staticmethod + def batch_pix_accuracy(predict, target): + # _, predict = torch.max(output, 1) + + predict = predict.int() + 1 + target = target.int() + 1 + + pixel_labeled = (target > 0).sum() + pixel_correct = ((predict == target) * (target > 0)).sum() + assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" + return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy() + + @ staticmethod + def batch_intersection_union(predict, target, num_class): + # _, predict = torch.max(output, 1) + predict = predict + 1 + target = target + 1 + + predict = predict * (target > 0).long() + intersection = predict * (predict == target).long() + + area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1) + area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1) + area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1) + area_union = area_pred + area_lab - area_inter + assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" + return area_inter.cpu().numpy(), area_union.cpu().numpy() + diff --git a/avs.code/v2.code/utils/misc.py b/avs.code/v2.code/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9d66c31a4b9209b81a5b615386d29f246135c --- /dev/null +++ b/avs.code/v2.code/utils/misc.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} + diff --git a/avs.code/v2.code/utils/tensorboard.py b/avs.code/v2.code/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..75fdceabf8f9bcefb3e3c00aa1136097197afdf5 --- /dev/null +++ b/avs.code/v2.code/utils/tensorboard.py @@ -0,0 +1,140 @@ +import os + +import PIL +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision +try: + import wandb +except ImportError: # pragma: no cover + wandb = None + +# from utils.visualize import show_img + + +color_map = {"background": (0, 0, 0), "longitudinal": (128, 0, 0), "pothole": (0, 128, 0), + "alligator": (128, 128, 0), "transverse": (128, 0, 128), "ignore": (255, 255, 255)} + + +class _DummyWandb: + def log(self, *args, **kwargs): + return None + + +class Tensorboard: + def __init__(self, config): + self._log_images = bool(config.get('wandb_online', False)) + if not self._log_images or wandb is None or not hasattr(wandb, "init"): + self.tensor_board = _DummyWandb() + self._log_images = False + elif config.get('wandb_online', False): + key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '') + if key: + os.environ['WANDB_API_KEY'] = key + wandb.login(key=key, relogin=False) + self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'], + config=config, settings=wandb.Settings(code_dir=".")) + + self.restore_transform = torchvision.transforms.Compose([ + DeNormalize(config['image_mean'], config['image_std']), + torchvision.transforms.ToPILImage()]) + + def upload_wandb_info(self, info_dict): + for i, info in enumerate(info_dict): + self.tensor_board.log({info: info_dict[info]}) + return + + + def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4): + if not self._log_images: + return + + def _batched_rgb(t): + """[N,C,H,W] or [C,H,W] float tensor on CPU.""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + if t.dim() == 3: + return t.unsqueeze(0) + if t.dim() == 4: + return t + raise ValueError("frames must be [C,H,W] or [N,C,H,W], got shape {}".format(tuple(t.shape))) + + def _batched_mask(t): + """[N,H,W] or [N,1,H,W] or [H,W].""" + if not isinstance(t, torch.Tensor): + t = torch.as_tensor(t) + t = t.detach().cpu().float() + while t.dim() > 3: + t = t.squeeze(1) + if t.dim() == 2: + t = t.unsqueeze(0) + if t.dim() != 3: + raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape))) + return t + + frames = _batched_rgb(frames) + pseudo_label_from_pred = _batched_mask(pseudo_label_from_pred) + pseudo_label_from_sam = _batched_mask(pseudo_label_from_sam) + + n = min(frames.shape[0], pseudo_label_from_pred.shape[0], pseudo_label_from_sam.shape[0], img_number) + frames = frames[:n] + pseudo_label_from_pred = pseudo_label_from_pred[:n] + pseudo_label_from_sam = pseudo_label_from_sam[:n] + + pseudo_label_from_sam = pseudo_label_from_sam.clone() + pseudo_label_from_pred = pseudo_label_from_pred.clone() + pseudo_label_from_sam[pseudo_label_from_sam == 255.] = 0.5 + pseudo_label_from_pred[pseudo_label_from_pred == 255.] = 0.5 + + denorm = self.restore_transform.transforms[0] + image_list = [] + label_list = [] + logits_list = [] + for i in range(n): + fi = frames[i].clone() + if fi.shape[0] == 3: + denorm(fi) + fi.clamp_(0.0, 1.0) + image_list.append(wandb.Image(fi, caption="id {}".format(str(i)))) + # wandb.Image expects torch tensors as [C, H, W] (it permutes CHW→HWC) + ms = pseudo_label_from_sam[i].squeeze() + mp = pseudo_label_from_pred[i].squeeze() + if ms.dim() == 2: + ms = ms.unsqueeze(0) + if mp.dim() == 2: + mp = mp.unsqueeze(0) + label_list.append(wandb.Image(ms, caption="id {}".format(str(i)))) + logits_list.append(wandb.Image(mp, caption="id {}".format(str(i)))) + + self.tensor_board.log({"image": image_list, "label": label_list, "logits": logits_list}) + + def de_normalize(self, image): + return [self.restore_transform(i.detach().cpu()) if (isinstance(i, torch.Tensor) and len(i.shape) == 3) + else colorize_mask(i.detach().cpu().numpy(), self.palette) + for i in image] + + def finish(self): + self.tensor_board.finish() + + +class DeNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + return tensor + + +def colorize_mask(mask, palette): + zero_pad = 256 * 3 - len(palette) + for i in range(zero_pad): + palette.append(0) + # palette[-6:-3] = [183, 65, 14] + new_mask = PIL.Image.fromarray(mask.astype(numpy.uint8)).convert('P') + new_mask.putpalette(palette) + return new_mask diff --git a/avs.code/v2.code/utils/utils.py b/avs.code/v2.code/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e72f27a7e2be77cea271001230195ef79f685351 --- /dev/null +++ b/avs.code/v2.code/utils/utils.py @@ -0,0 +1,119 @@ +"""Optimizer helpers: split learning rates for AuralFuser train_* vs VGG backbone.""" +import torch +import copy +from typing import List, Dict, Set, Any + + +def manipulate_params(cfg, model): + weight_decay_norm = 0 + weight_decay_embed = 0 + defaults = {} + defaults["lr"] = cfg.lr + defaults["weight_decay"] = cfg.weight_decay + + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + params_training: List[Dict[str, Any]] = [] + params_finetuning: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + + train_prefixes = ( + "patch_embeds", + "f_blocks", + "a_blocks", + "fusion_modules", + "smooth_convs", + "train_proj_v1", + "train_proj_a1", + ) + + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + hyperparams = copy.copy(defaults) + if 'vgg' in module_name or 'vgg' in module_param_name: + hyperparams['lr'] *= 0.1 + params_finetuning.append({"params": [value], "name": [module_name], **hyperparams}) + elif ( + 'train' in module_name + or 'train' in module_param_name + or module_name.startswith(train_prefixes) + ): + if ( + "relative_position_bias_table" in module_param_name + or "pos_embed" in module_param_name + ): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, norm_module_types): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, torch.nn.Embedding): + hyperparams["weight_decay"] = 0.0 + params_training.append({"params": [value], "name": [module_name], **hyperparams}) + else: + print('undefined layer type.') + raise NotImplementedError + final_list = params_training + params_finetuning + assert len([p for p in model.parameters() if p.requires_grad]) == len(final_list), 'checksum confirmed not pass.' + return final_list + + +def group_weight(weight_group, module, weight_decay_value, lr): + group_decay = [] + group_no_decay = [] + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + for m in module.modules(): + if isinstance(m, torch.nn.Linear): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, norm_module_types): + if m.weight is not None: + group_no_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, torch.nn.Parameter): + group_no_decay.append(m) + elif isinstance(m, torch.nn.Embedding): + group_no_decay.append(m) + else: + print('undefined layer type find.') + raise NotImplementedError + + assert len(list(module.parameters())) == len(group_decay) + len( + group_no_decay) + weight_group.append(dict(params=group_decay, weight_deacy=weight_decay_value, lr=lr)) + weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) + return weight_group \ No newline at end of file diff --git a/ckpts/avs/v1m/0.75336.pth b/ckpts/avs/v1m/0.75336.pth new file mode 100644 index 0000000000000000000000000000000000000000..0890e540e44aff444c3006b9ac89598efc72026c --- /dev/null +++ b/ckpts/avs/v1m/0.75336.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9e9007862d95dd2705a1b873bf942fa08a931eab2c8f3556316995516f7c1b5 +size 521935136 diff --git a/ckpts/avs/v1m/nohup.out b/ckpts/avs/v1m/nohup.out new file mode 100644 index 0000000000000000000000000000000000000000..218ec622f8eee3736129877a4197c577e866934f --- /dev/null +++ b/ckpts/avs/v1m/nohup.out @@ -0,0 +1,596 @@ +wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information. +==> Load pretrained VGGish parameters from /root/autodl-tmp/avs/v1/v1m.code/ckpts/vggish-10086976.pth +device: cuda +==> Load pretrained VGGish parameters from /root/autodl-tmp/avs/v1/v1m.code/ckpts/vggish-10086976.pth +device: cuda +==> Load pretrained VGGish parameters from /root/autodl-tmp/avs/v1/v1m.code/ckpts/vggish-10086976.pth +device: cuda +wandb: Currently logged in as: pyedog1976. Use `wandb login --relogin` to force relogin +wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc +wandb: - Waiting for wandb.init()... wandb: \ Waiting for wandb.init()... wandb: Tracking run with wandb version 0.18.3 +wandb: Run data is saved locally in /root/autodl-tmp/avs/v1/wandb/run-20260421_050129-gzp5dmwi +wandb: Run `wandb offline` to turn off syncing. +wandb: Syncing run v1m-hiera-l +wandb: ⭐️ View project at https://wandb.ai/pyedog1976/AVS-final-report +wandb: 🚀 View run at https://wandb.ai/pyedog1976/AVS-final-report/runs/gzp5dmwi +==> Load pretrained VGGish parameters from /root/autodl-tmp/avs/v1/v1m.code/ckpts/vggish-10086976.pth +device: cuda + 0%| | 0/74 [00:00 [gpus] +./run_ref_train.sh [gpus] +``` +The experiments are implemented by 4 GPUs by default. + +## 🔍 Inference (example) + +```bash +cd avs.code/v2.code +python inference.py --gpus 1 --batch_size 1 --inference_ckpt /absolute/path/to/checkpoint.pth +``` + +## 📊 Training Logs (Reproducibility) + +Some examples of training details, please see [this wandb link](https://wandb.ai/pyedog1976/AVS-final-report/workspace?nw=nwuserpyedog1976). + +In details, after clicking the run (e.g., [v1m-hiera-l](https://wandb.ai/pyedog1976/AVS-final-report/runs/gzp5dmwi/logs?nw=nwuserpyedog1976)), you can checkout: + +1) overall information (e.g., command line, hardware information and training time). +2) training curves and validation visualisation. +3) output logs. + + +## 💾 Checkpoints +We release both checkpoints and training logs in this [Google Drive link](https://drive.google.com/drive/folders/1n0HaCHMn48KaImXvX2mu4qKHUQg4mo9R?usp=sharing). + + diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..51f2d513992bfb653e58b7add987468446b3f503 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,116 @@ +# Installation + +The project is based on Python and PyTorch. We usually run experiments with multi-GPU training. + +Tested runtime: +- Python `3.12.3` +- PyTorch `2.8.0+cu128` + +## 📥 Clone the Git repo + +``` shell +$ https://github.com/yyliu01/AuralSAM2 +$ cd AuralSAM2 +``` + +## 🧩 Install dependencies + +1) create conda env from yaml +```shell +$ conda env create -f docs/auralsam2.yml +``` + +2) activate env +```shell +$ conda activate auralsam2 +``` + +3) install PyTorch (recommended: match tested runtime) +```shell +# CUDA 12.8 (tested): +$ pip install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +``` + +4) install python packages (if needed) +```shell +$ pip install -r docs/requirements.txt +``` + +## 🗂️ Prepare dataset + +### AVSBench (`avs.code`) + +1) download and prepare AVSBench under repository root. +2) ensure the dataset root path is: + - `AVSBench/` + - `AVSBench/avss_index/metadata.csv` (and subset folders `v1s/`, `v1m/`, `v2/`) + +### Ref-AVS (`ref-avs.code`) + +1) download and prepare the Ref-AVS (REFAVS) dataset under repository root. +2) ensure the dataset root path is: + - `REFAVS/` + - `REFAVS/metadata.csv` (splits: `train`, `test_s`, `test_u`, `test_n`) + + +### Checkpoints (shared) + +Prepare under repository root: + +- `ckpts/sam_ckpts/sam2_hiera_large.pt` +- `ckpts/vggish-10086976.pth` + +## 🏗️ Workspace structure + +```shell +AuralSAM2/ +├── avs.code/ +│ ├── v1s.code/ +│ ├── v1m.code/ +│ └── v2.code/ +├── ref-avs.code/ +├── scripts/ +│ ├── run_avs_train.sh +│ └── run_ref_train.sh +├── AVSBench/ +│ ├── avss_index +│ │ ├── metadata.csv +│ │ ├── metadata_v1m_man.csv +│ │ └── metadata_v2_man.csv +│ ├── v1m +│ │ ├── 01uIJMwnUvA_0 +│ │ ├── 0WxgIKuetYI_0 +│ │ ... (419 more) +│ ├── v1s +│ │ ├── --FenyW2i_4_5000_10000 +│ │ ├── --ZHUMfueO0_5000_10000 +│ │ ... (4927 more) +│ └── v2 +│ ├── --KCIeTv6PM_14000_24000 +│ ├── --iSerV5DbY_68000_78000 +│ ... (5995 more) +├── REFAVS/ +│ ├── gt_mask +│ │ ├── --KCIeTv6PM_14000_24000 +│ │ ├── --iSerV5DbY_68000_78000 +│ │ ... (~4000 more) +│ ├── media +│ │ ├── --KCIeTv6PM_14000_24000 +│ │ ├── --iSerV5DbY_68000_78000 +│ │ ... (~4300 more) +│ └── metadata.csv +├── ckpts/ +│ ├── sam_ckpts/ +│ │ └── sam2_hiera_large.pt +│ └── vggish-10086976.pth +└── docs/ + ├── installation.md + ├── before_start.md + ├── requirements.txt + └── auralsam2.yml +``` + +## 📝 Notes + +- use `docs/before_start.md` for training and inference commands. +- if wandb is not needed, disable online logging in your config. diff --git a/docs/overview.png b/docs/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..62f8247f972743545f2e8d2b1d8b59742d6d0959 --- /dev/null +++ b/docs/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e707a6d77fa1ffff5416027e6da0126a3d84d77c04d0127caae35722ebf7965 +size 144143 diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d0d2a626752910d9eaf06cba3e341302075f958 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,159 @@ +absl-py==2.3.1 +annotated-doc==0.0.4 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.10.0 +argon2-cffi==25.1.0 +argon2-cffi-bindings==25.1.0 +arrow==1.3.0 +asttokens==3.0.0 +async-lru==2.0.5 +attrs==25.3.0 +babel==2.17.0 +beautifulsoup4==4.13.4 +bleach==6.2.0 +click==8.3.2 +comm==0.2.3 +contourpy==1.3.3 +cycler==0.12.1 +debugpy==1.8.16 +decorator==5.2.1 +defusedxml==0.7.1 +easydict==1.13 +executing==2.2.0 +fastjsonschema==2.21.1 +filelock==3.18.0 +fonttools==4.59.0 +fqdn==1.5.1 +fsspec==2025.7.0 +gitdb==4.0.12 +GitPython==3.1.46 +grpcio==1.74.0 +h11==0.16.0 +hf-xet==1.4.3 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.11.0 +hydra-core==1.3.2 +iopath==0.1.10 +ipykernel==6.30.1 +ipython==9.4.0 +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.7 +isoduration==20.11.0 +jedi==0.19.2 +Jinja2==3.1.6 +json5==0.12.0 +jsonpointer==2.1 +jsonschema==4.25.0 +jsonschema-specifications==2025.4.1 +jupyter-events==0.12.0 +jupyter-lsp==2.2.6 +jupyter_client==8.6.3 +jupyter_core==5.8.1 +jupyter_server==2.16.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.4.5 +jupyterlab-language-pack-zh-CN==4.4.post0 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupyterlab_widgets==3.0.15 +kiwisolver==1.4.8 +lark==1.2.2 +llvmlite==0.47.0 +Markdown==3.8.2 +markdown-it-py==4.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.5 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mistune==3.1.3 +mpmath==1.3.0 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.5 +notebook_shim==0.2.4 +numba==0.65.0 +numpy<2 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.3 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.8.90 +omegaconf==2.3.0 +overrides==7.7.0 +pandas==3.0.2 +pandocfilters==1.5.1 +parso==0.8.4 +pexpect==4.9.0 +pillow==11.3.0 +portalocker==3.2.0 +prometheus_client==0.22.1 +prompt_toolkit==3.0.51 +protobuf==7.34.1 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pydantic==2.13.3 +pydantic_core==2.46.3 +Pygments==2.19.2 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +python-json-logger==3.3.0 +PyYAML==6.0.2 +pyzmq==27.0.1 +referencing==0.36.2 +resampy==0.4.3 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rfc3987-syntax==1.1.0 +rich==15.0.0 +rpds-py==0.26.0 +safetensors==0.7.0 +Send2Trash==1.8.3 +sentry-sdk==2.58.0 +setuptools==69.5.1 +shellingham==1.5.4 +six==1.17.0 +smmap==5.0.3 +sniffio==1.3.1 +soundfile==0.13.1 +soupsieve==2.7 +stack-data==0.6.3 +supervisor==4.2.5 +sympy==1.14.0 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +terminado==0.18.1 +timm==1.0.26 +tinycss2==1.4.0 +tornado==6.5.1 +traitlets==5.14.3 +triton==3.4.0 +typer==0.24.1 +types-python-dateutil==2.9.0.20250708 +typing-inspection==0.4.2 +typing_extensions==4.14.1 +uri-template==1.3.0 +wcwidth==0.2.13 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.8.0 +Werkzeug==3.1.3 +wheel==0.43.0 +widgetsnbextension==4.0.14 +transformers==5.6.2 +audiomentations==0.39.0 +wandb==0.26.1 + diff --git a/ref-avs.code/configs/auralfuser/architecture.yaml b/ref-avs.code/configs/auralfuser/architecture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c3d06ca42335ce6bfc8064bbd5cfd44c8080a --- /dev/null +++ b/ref-avs.code/configs/auralfuser/architecture.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +aural_fuser: + patch_cfgs: + - [4, 4] + - [2, 2] + - [1, 1] + f_depths: [3, 6, 12] + block_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false + one_d_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false diff --git a/ref-avs.code/configs/config.py b/ref-avs.code/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..994ead1b8c6f1bf85c57737f2bb2ae402c297525 --- /dev/null +++ b/ref-avs.code/configs/config.py @@ -0,0 +1,52 @@ +"""Ref-AVS training / inference defaults (paths relative to repo root).""" +import os +import pathlib +import numpy +from easydict import EasyDict + +_CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +_WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT)) + +C = EasyDict() +config = C +cfg = C + +C.seed = 666 + +C.audio = EasyDict() +C.audio.FREEZE_AUDIO_EXTRACTOR = True +C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth') +C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False +C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False +C.train_vggish = False + +C.root_dir = _CODE_ROOT + +# REFAVS layout: REFAVS/metadata.csv, REFAVS/media//... +C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'REFAVS') +C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt') +C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml') + +C.num_classes = 2 +C.image_mean = numpy.array([0.485, 0.456, 0.406]) +C.image_std = numpy.array([0.229, 0.224, 0.225]) +C.image_size = 1024 +C.image_embedding_size = int(C.image_size / 16) +C.scale_list = [.5, .75, 1., 1.25, 1.5] +C.ignore_index = 255 + +C.lr = 7.5e-5 +C.batch_size = 8 +C.lr_power = 0.9 +C.momentum = 0.9 +C.weight_decay = 0.05 +C.num_workers = 4 + +# Paste W&B API key here or set WANDB_API_KEY in the environment. +C.wandb_key = "" +C.proj_name = "AVS-final-report" +C.experiment_name = "ref-hiera-l" +C.wandb_online = False + +C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'exp', C.experiment_name) +pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True) diff --git a/ref-avs.code/configs/sam2/sam2_hiera_b+.yaml b/ref-avs.code/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e0f10732134149f6a994be063d11fd7591c430 --- /dev/null +++ b/ref-avs.code/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,114 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False + diff --git a/ref-avs.code/configs/sam2/sam2_hiera_l.yaml b/ref-avs.code/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8478b3d4b8b16d8b22f6555cf7b1f00231d7fd59 --- /dev/null +++ b/ref-avs.code/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/ref-avs.code/configs/sam2/sam2_hiera_s.yaml b/ref-avs.code/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/ref-avs.code/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/ref-avs.code/configs/sam2/sam2_hiera_t.yaml b/ref-avs.code/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59e605b73c9777b70942538252d27a55ae8a7e1a --- /dev/null +++ b/ref-avs.code/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 224 # 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: false + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/ref-avs.code/configs/training/sam2_training_config.yaml b/ref-avs.code/configs/training/sam2_training_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29df1199d79c6a9031b82e23aa4b40df99064650 --- /dev/null +++ b/ref-avs.code/configs/training/sam2_training_config.yaml @@ -0,0 +1,60 @@ +# @package _global_ + +# Video transforms +train_transforms: + - _target_: dataloader.sam2_dataset.transforms.ComposeAPI + transforms: + - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI + sizes: 1024 + square: true + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI + - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +loss: + all: + _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + gpu_num: 4. + +contrastive_learning: + temperature: 0.10 + ignore_idx: 255 + ood_idx: 254 + max_views: 512 + proj_dim: 512 + sample_limits: 64 + total_limits: 15240 diff --git a/ref-avs.code/dataloader/audio/audio_and_text_dataset.py b/ref-avs.code/dataloader/audio/audio_and_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4a8bc3a4a4efacffa37483c37f0e4f6e3d030e --- /dev/null +++ b/ref-avs.code/dataloader/audio/audio_and_text_dataset.py @@ -0,0 +1,35 @@ +"""Load REFAVS audio (log-mel) and pass through referring expression strings for the text encoder.""" +import os + +import numpy +import soundfile +import torch + +from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples + + +class AudioAndText(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split): + self.augmentation = augmentation + self.directory_path = directory_path + self.split = split + + def load_audio_wave(self, file_index, text_expression): + audio_path = os.path.join(self.directory_path, 'media', file_index, 'audio.wav') + wav_data, sample_rate = soundfile.read(audio_path, dtype='int16') + assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype + wav_data = self.augmentation(wav_data, sample_rate, self.split) + if len(wav_data.shape) < 2: + wav_data = wav_data[:, numpy.newaxis] + wav_data = numpy.repeat(wav_data, axis=-1, repeats=2) + + audio_log_mel = torch.cat([ + waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(), + waveform_to_examples(wav_data[:, 1], sample_rate, True).detach(), + ], dim=1) + + # VGGish expects at least 5 temporal segments. + if audio_log_mel.shape[0] < 5: + pad = audio_log_mel[-1].unsqueeze(0).repeat(5 - audio_log_mel.shape[0], 1, 1, 1) + audio_log_mel = torch.cat([audio_log_mel, pad]) + return audio_log_mel, text_expression diff --git a/ref-avs.code/dataloader/audio/audio_augmentation.py b/ref-avs.code/dataloader/audio/audio_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..850d1577ea2bca4f8ec209edc201fb54968be928 --- /dev/null +++ b/ref-avs.code/dataloader/audio/audio_augmentation.py @@ -0,0 +1,23 @@ +import numpy + + +class Augmentation(object): + """Audio pre-step used by training/inference: int16 waveform -> float in [-1, 1]. + + The previous audiomentations-based transforms were commented out and never applied; + behavior is unchanged: only scaling by 1/32768. + """ + + def __init__(self, mono=True): + self.mono = mono + + def train_aug(self, x_, sr_): + x_ = x_ / 32768.0 + return x_ + + def test_process(self, x_): + x_ = x_ / 32768.0 + return x_ + + def __call__(self, x, sr, split): + return self.train_aug(x, sr) if split == "train" else self.test_process(x) diff --git a/ref-avs.code/dataloader/audio/audio_dataset.py b/ref-avs.code/dataloader/audio/audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8e8b276e8545aa55ef56295719a0ad2b167106 --- /dev/null +++ b/ref-avs.code/dataloader/audio/audio_dataset.py @@ -0,0 +1,38 @@ +import torch +import numpy +import os +from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples +import soundfile + + +class Audio(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split): + # temporarily set no augmentation. + self.augmentation = augmentation + self.directory_path = directory_path + self.split = split + + def load_audio_wave(self, file_index, file_index_mix): + audio_path = os.path.join(file_index, 'audio.wav') + wav_data, sample_rate = soundfile.read(audio_path, dtype='int16') + assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype + + if file_index_mix is not None: + audio_path2 = os.path.join(file_index_mix, 'audio.wav') + wav_data2, _ = soundfile.read(audio_path2, dtype='int16') + mix_lambda = numpy.random.beta(10, 10) + min_length = min(wav_data.shape[0], wav_data2.shape[0]) + wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda) + + wav_data = self.augmentation(wav_data, sample_rate, self.split) + audio_log_mel = torch.cat([waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(), + waveform_to_examples(wav_data[:, 1], sample_rate, True).detach()], dim=1) + + # for the vgg preprocess, we will need 5 seconds audio log. + if audio_log_mel.shape[0] < 5: + audio_log_mel = torch.cat([audio_log_mel, + audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)]) + return audio_log_mel + + def __len__(self): + return len(self.audio_list) diff --git a/ref-avs.code/dataloader/audio/preprocess_vgg/mel_features.py b/ref-avs.code/dataloader/audio/preprocess_vgg/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/ref-avs.code/dataloader/audio/preprocess_vgg/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_input.py b/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..9d58e81bc70a85138980128e033f271998794605 --- /dev/null +++ b/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from dataloader.audio.preprocess_vgg import mel_features +from dataloader.audio.preprocess_vgg import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_params.py b/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/ref-avs.code/dataloader/audio/preprocess_vgg/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/ref-avs.code/dataloader/dataset.py b/ref-avs.code/dataloader/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fd43f7e17e9978fc5d36ec172796badde6f7f1 --- /dev/null +++ b/ref-avs.code/dataloader/dataset.py @@ -0,0 +1,68 @@ +"""Ref-AVS dataset: frames, masks, log-mel audio, and referring expressions.""" +import os +import numpy +import torch +import pandas + +from dataloader.visual.visual_dataset import Visual +from dataloader.audio.audio_and_text_dataset import AudioAndText + + +class AV(torch.utils.data.Dataset): + """Pairs ``Visual`` with ``AudioAndText`` via REFAVS ``metadata.csv``.""" + + def __init__(self, split, augmentation, param, root_path=''): + self.visual_dataset = Visual( + augmentation['visual'], root_path, split, + param.image_size, param.image_embedding_size, + ) + self.audio_and_text_dataset = AudioAndText(augmentation['audio'], root_path, split) + self.split = split + self.file_path = self.organise_files(self.split, root_path, csv_name_='metadata.csv') + + def __getitem__(self, index): + vid, fid, exp, _ = self.file_path[index] + frame, label, prompts = self.visual_dataset.load_data(vid, fid) + audio_mel, text_feature = self.audio_and_text_dataset.load_audio_wave(vid, exp) + return { + 'frame': frame, + 'label': label, + 'spectrogram': audio_mel, + 'text': text_feature, + 'id': self.file_path[index], + 'prompts': prompts, + } + + def __len__(self): + return len(self.file_path) + + @staticmethod + def organise_files(split_, root_path_, csv_name_): + total_files = pandas.read_csv(os.path.join(root_path_, csv_name_)) + if split_ == 'test_n': + rows = zip( + total_files[total_files['split'] == split_]['uid'], + total_files[total_files['split'] == split_]['fid'], + total_files[total_files['split'] == split_]['exp'], + ) + return [ + [name.rsplit('_', 2)[0], object_id, expression, 0] + for name, object_id, expression in rows + ] + + rows = zip( + total_files[total_files['split'] == split_]['vid'], + total_files[total_files['split'] == split_]['fid'], + total_files[total_files['split'] == split_]['exp'], + ) + file_path = [[vid, fid, expression, 0] for vid, fid, expression in rows] + + if split_ == 'train': + null_uids = list(total_files[total_files['split'] == split_]['uid']) + assert len(null_uids) == len(file_path) + for idx, row in enumerate(file_path): + if 'null_' in null_uids[idx]: + row[0] = null_uids[idx].rsplit('_', 2)[0] + row[-1] = null_uids[idx].rsplit('_', 2)[1] + + return file_path diff --git a/ref-avs.code/dataloader/sam2_dataset/__init__.py b/ref-avs.code/dataloader/sam2_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/ref-avs.code/dataloader/sam2_dataset/sam2_datasets.py b/ref-avs.code/dataloader/sam2_dataset/sam2_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..6deda056bea555fc07ace455ccc62c606a7b81c9 --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/sam2_datasets.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Callable, Iterable, List, Optional, Sequence + +import torch + +from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset + +from torch.utils.data.distributed import DistributedSampler + + +class MixedDataLoader: + def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor): + """ + Args: + dataloaders (List[DataLoader]): List of DataLoaders to be mixed. + mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from + + """ + assert len(dataloaders) == mixing_prob.shape[0] + self.dataloaders = dataloaders + self.mixing_prob = mixing_prob + # Iterator state + self._iter_dls = None + self._iter_mixing_prob = None + self.random_generator = torch.Generator() + + def __len__(self): + return sum([len(d) for d in self.dataloaders]) + + def __iter__(self): + # Synchronize dataloader seeds + self.random_generator.manual_seed(42) + self._iter_dls = [iter(loader) for loader in self.dataloaders] + self._iter_mixing_prob = self.mixing_prob.clone() + return self + + def __next__(self): + """ + Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted. + """ + if self._iter_dls is None: + raise TypeError(f"{type(self).__name__} object is not an iterator") + + while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob. + dataset_idx = self._iter_mixing_prob.multinomial( + 1, generator=self.random_generator + ).item() + try: + item = next(self._iter_dls[dataset_idx]) + return item + except StopIteration: + # No more iterations for this dataset, set it's mixing probability to zero and try again. + self._iter_mixing_prob[dataset_idx] = 0 + except Exception as e: + # log and raise any other unexpected error. + logging.error(e) + raise e + + # Exhausted all iterators + raise StopIteration + + +class TorchTrainMixedDataset: + def __init__( + self, + datasets: List[Dataset], + batch_sizes: List[int], + num_workers: int, + shuffle: bool, + pin_memory: bool, + drop_last: bool, + collate_fn: Optional[Callable] = None, + worker_init_fn: Optional[Callable] = None, + phases_per_epoch: int = 1, + dataset_prob: Optional[List[float]] = None, + ) -> None: + """ + Args: + datasets (List[Dataset]): List of Datasets to be mixed. + batch_sizes (List[int]): Batch sizes for each dataset in the list. + num_workers (int): Number of workers per dataloader. + shuffle (bool): Whether or not to shuffle data. + pin_memory (bool): If True, use pinned memory when loading tensors from disk. + drop_last (bool): Whether or not to drop the last batch of data. + collate_fn (Callable): Function to merge a list of samples into a mini-batch. + worker_init_fn (Callable): Function to init each dataloader worker. + phases_per_epoch (int): Number of phases per epoch. + dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0 + """ + + self.datasets = datasets + self.batch_sizes = batch_sizes + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + self.collate_fn = collate_fn + self.worker_init_fn = worker_init_fn + assert len(self.datasets) > 0 + for dataset in self.datasets: + assert not isinstance(dataset, IterableDataset), "Not supported" + # `RepeatFactorWrapper` requires calling set_epoch first to get its length + self._set_dataset_epoch(dataset, 0) + self.phases_per_epoch = phases_per_epoch + self.chunks = [None] * len(datasets) + if dataset_prob is None: + # If not provided, assign each dataset a probability proportional to its length. + dataset_lens = [ + (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs)) + for d, bs in zip(datasets, batch_sizes) + ] + total_len = sum(dataset_lens) + dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens]) + else: + assert len(dataset_prob) == len(datasets) + dataset_prob = torch.tensor(dataset_prob) + + logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}") + assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0" + self.dataset_prob = dataset_prob + + def _set_dataset_epoch(self, dataset, epoch: int) -> None: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + def get_loader(self, epoch) -> Iterable: + dataloaders = [] + for d_idx, (dataset, batch_size) in enumerate( + zip(self.datasets, self.batch_sizes) + ): + if self.phases_per_epoch > 1: + # Major epoch that looops over entire dataset + # len(main_epoch) == phases_per_epoch * len(epoch) + main_epoch = epoch // self.phases_per_epoch + + # Phase with in the main epoch + local_phase = epoch % self.phases_per_epoch + + # Start of new data-epoch or job is resumed after preemtion. + if local_phase == 0 or self.chunks[d_idx] is None: + # set seed for dataset epoch + # If using RepeatFactorWrapper, this step currectly re-samples indices before chunking. + self._set_dataset_epoch(dataset, main_epoch) + + # Separate random generator for subset sampling + g = torch.Generator() + g.manual_seed(main_epoch) + self.chunks[d_idx] = torch.chunk( + torch.randperm(len(dataset), generator=g), + self.phases_per_epoch, + ) + + dataset = Subset(dataset, self.chunks[d_idx][local_phase]) + else: + self._set_dataset_epoch(dataset, epoch) + + sampler = DistributedSampler(dataset, shuffle=self.shuffle) + sampler.set_epoch(epoch) + + batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last) + dataloaders.append( + DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + worker_init_fn=self.worker_init_fn, + ) + ) + return MixedDataLoader(dataloaders, self.dataset_prob) diff --git a/ref-avs.code/dataloader/sam2_dataset/transforms.py b/ref-avs.code/dataloader/sam2_dataset/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7731e59ba98a5465493e3a9c4b785eb4d4420ca2 --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage +# from docutils.nodes import label +import numpy +from torchvision.transforms import InterpolationMode + +# from utils.data_utils import VideoDatapoint + + +def hflip(frames, labels, index): + # print(index) + # print(len(frames), frames[index].size, type(frames[index])) + # print(len(labels), labels[index].size, type(labels[index])) + frames[index] = F.hflip(frames[index]) + labels[index] = F.hflip(labels[index]) + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.hflip(obj.segment) + + return frames, labels + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(frames, labels, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + raise NotImplementedError + # cur_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + # size = get_size(cur_size, size, max_size) + + # old_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + if v2: + frames[index].data = Fv2.resize( + frames[index].data, size, antialias=True + ) + else: + frames[index] = F.resize(frames[index], size) + labels[index] = F.resize(labels[index], size) + # new_size = ( + # frames[index].data.size()[-2:][::-1] + # if v2 + # else frames[index].data.size + # ) + + # for obj in frames[index].objects: + # if obj.segment is not None: + # obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + # h, w = size + # frames[index].size = (h, w) + return frames, labels + + +def pad(frames, index, padding, v2=False): + old_h, old_w = frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + frames[index].data = F.pad( + frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + frames[index].data = F.pad( + frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + frames[index].size = (h, w) + + for obj in frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return frames + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(frames)): + frames, labels = hflip(frames, labels, i) + return frames, labels + for i in range(len(frames)): + if random.random() < self.p: + frames, labels = hflip(frames, labels, i) + return frames, labels + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, frames, labels): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(frames)): + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + for i in range(len(frames)): + size = random.choice(self.sizes) + frames, labels = resize( + frames, labels, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return frames, labels + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + if self.v2: + raise NotImplementedError + # frames[img_idx] = Fv2.to_tensor(frames[img_idx]) + else: + frames[img_idx] = F.to_tensor(frames[img_idx]) + labels[img_idx] = torch.tensor(numpy.array(labels[img_idx]), dtype=torch.float) + return frames, labels + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, frames, labels, **kwargs): + for img_idx in range(len(frames)): + # if self.v2: + # img.data = Fv2.convert_image_dtype(img.data, torch.float32) + # img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + # else: + frames[img_idx] = F.normalize(frames[img_idx], mean=self.mean, std=self.std) + + return frames, labels + +''' + + + + + + + + +''' +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, frames, labels, **kwargs): + for t in self.transforms: + frames, labels = t(frames, labels, **kwargs) + return frames, labels + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img_idx in range(len(frames)): + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + for img_idx in range(len(frames)): + if random.random() < self.p: + frames[img_idx] = self.Grayscale(frames[img_idx]) + return frames, labels + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, frames, labels, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return frames, labels + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + label_fill_value=0., + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.fill_label = label_fill_value + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.' + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, frames, labels, **kwargs): + for _tentative in range(self.num_tentatives): + res_img, res_labels = self.transform_frames(frames, labels) + # if res is not None: + return res_img, res_labels + + # raise NotImplementedError + # if self.log_warning: + # logging.warning( + # f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + # ) + # return frames + + def transform_frames(self, frames, labels): + _, height, width = F.get_dimensions(frames[0]) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(frames): + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + frames[img_idx] = F.affine( + img, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + labels[img_idx] = F.affine( + labels[img_idx], + *affine_params, + # default: interpolation='nearest', + fill=self.fill_label, + ) + return frames, labels + + +''' +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, frames, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint +''' \ No newline at end of file diff --git a/ref-avs.code/dataloader/sam2_dataset/utils.py b/ref-avs.code/dataloader/sam2_dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a658df234c3dcf74404f844b5be793b0545485ed --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" + +from typing import Iterable + +import torch +from torch.utils.data import ( + ConcatDataset as TorchConcatDataset, + Dataset, + Subset as TorchSubset, +) + + +class ConcatDataset(TorchConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__(datasets) + + self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) + + def set_epoch(self, epoch: int): + for dataset in self.datasets: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + +class Subset(TorchSubset): + def __init__(self, dataset, indices) -> None: + super(Subset, self).__init__(dataset, indices) + + self.repeat_factors = dataset.repeat_factors[indices] + assert len(indices) == len(self.repeat_factors) + + +# Adapted from Detectron2 +class RepeatFactorWrapper(Dataset): + """ + Thin wrapper around a dataset to implement repeat factor sampling. + The underlying dataset must have a repeat_factors member to indicate the per-image factor. + Set it to uniformly ones to disable repeat factor sampling + """ + + def __init__(self, dataset, seed: int = 0): + self.dataset = dataset + self.epoch_ids = None + self._seed = seed + + # Split into whole number (_int_part) and fractional (_frac_part) parts. + self._int_part = torch.trunc(dataset.repeat_factors) + self._frac_part = dataset.repeat_factors - self._int_part + + def _get_epoch_indices(self, generator): + """ + Create a list of dataset indices (with repeats) to use for one epoch. + + Args: + generator (torch.Generator): pseudo random number generator used for + stochastic rounding. + + Returns: + torch.Tensor: list of dataset indices to use in one epoch. Each index + is repeated based on its calculated repeat factor. + """ + # Since repeat factors are fractional, we use stochastic rounding so + # that the target repeat factor is achieved in expectation over the + # course of training + rands = torch.rand(len(self._frac_part), generator=generator) + rep_factors = self._int_part + (rands < self._frac_part).float() + # Construct a list of indices in which we repeat images as specified + indices = [] + for dataset_index, rep_factor in enumerate(rep_factors): + indices.extend([dataset_index] * int(rep_factor.item())) + return torch.tensor(indices, dtype=torch.int64) + + def __len__(self): + if self.epoch_ids is None: + # Here we raise an error instead of returning original len(self.dataset) avoid + # accidentally using unwrapped length. Otherwise it's error-prone since the + # length changes to `len(self.epoch_ids)`changes after set_epoch is called. + raise RuntimeError("please call set_epoch first to get wrapped length") + # return len(self.dataset) + + return len(self.epoch_ids) + + def set_epoch(self, epoch: int): + g = torch.Generator() + g.manual_seed(self._seed + epoch) + self.epoch_ids = self._get_epoch_indices(g) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __getitem__(self, idx): + if self.epoch_ids is None: + raise RuntimeError( + "Repeat ids haven't been computed. Did you forget to call set_epoch?" + ) + + return self.dataset[self.epoch_ids[idx]] diff --git a/ref-avs.code/dataloader/sam2_dataset/vos_dataset.py b/ref-avs.code/dataloader/sam2_dataset/vos_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e9d39fe184cf0d86fbf22b5385dc05988cab83 --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/vos_dataset.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import random +from copy import deepcopy + +import numpy as np + +import torch +from iopath.common.file_io import g_pathmgr +from PIL import Image as PILImage +from torchvision.datasets.vision import VisionDataset + +from training.dataset.vos_raw_dataset import VOSRawDataset +from training.dataset.vos_sampler import VOSSampler +from training.dataset.vos_segment_loader import JSONSegmentLoader + +from training.utils.data_utils import Frame, Object, VideoDatapoint + +MAX_RETRIES = 100 + + +class VOSDataset(VisionDataset): + def __init__( + self, + transforms, + training: bool, + video_dataset: VOSRawDataset, + sampler: VOSSampler, + multiplier: int, + always_target=True, + target_segments_available=True, + ): + self._transforms = transforms + self.training = training + self.video_dataset = video_dataset + self.sampler = sampler + + self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) + self.repeat_factors *= multiplier + print(f"Raw dataset length = {len(self.video_dataset)}") + + self.curr_epoch = 0 # Used in case data loader behavior changes across epochs + self.always_target = always_target + self.target_segments_available = target_segments_available + + def _get_datapoint(self, idx): + + for retry in range(MAX_RETRIES): + try: + if isinstance(idx, torch.Tensor): + idx = idx.item() + # sample a video + video, segment_loader = self.video_dataset.get_video(idx) + # sample frames and object indices to be used in a datapoint + sampled_frms_and_objs = self.sampler.sample( + video, segment_loader, epoch=self.curr_epoch + ) + break # Succesfully loaded video + except Exception as e: + if self.training: + logging.warning( + f"Loading failed (id={idx}); Retry {retry} with exception: {e}" + ) + idx = random.randrange(0, len(self.video_dataset)) + else: + # Shouldn't fail to load a val video + raise e + + datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) + for transform in self._transforms: + datapoint = transform(datapoint, epoch=self.curr_epoch) + return datapoint + + def construct(self, video, sampled_frms_and_objs, segment_loader): + """ + Constructs a VideoDatapoint sample to pass to transforms + """ + sampled_frames = sampled_frms_and_objs.frames + sampled_object_ids = sampled_frms_and_objs.object_ids + + images = [] + rgb_images = load_images(sampled_frames) + # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) + for frame_idx, frame in enumerate(sampled_frames): + w, h = rgb_images[frame_idx].size + images.append( + Frame( + data=rgb_images[frame_idx], + objects=[], + ) + ) + # We load the gt segments associated with the current frame + if isinstance(segment_loader, JSONSegmentLoader): + segments = segment_loader.load( + frame.frame_idx, obj_ids=sampled_object_ids + ) + else: + segments = segment_loader.load(frame.frame_idx) + for obj_id in sampled_object_ids: + # Extract the segment + if obj_id in segments: + assert ( + segments[obj_id] is not None + ), "None targets are not supported" + # segment is uint8 and remains uint8 throughout the transforms + segment = segments[obj_id].to(torch.uint8) + else: + # There is no target, we either use a zero mask target or drop this object + if not self.always_target: + continue + segment = torch.zeros(h, w, dtype=torch.uint8) + + images[frame_idx].objects.append( + Object( + object_id=obj_id, + frame_index=frame.frame_idx, + segment=segment, + ) + ) + return VideoDatapoint( + frames=images, + video_id=video.video_id, + size=(h, w), + ) + + def __getitem__(self, idx): + return self._get_datapoint(idx) + + def __len__(self): + return len(self.video_dataset) + + +def load_images(frames): + all_images = [] + cache = {} + for frame in frames: + if frame.data is None: + # Load the frame rgb data from file + path = frame.image_path + if path in cache: + all_images.append(deepcopy(all_images[cache[path]])) + continue + with g_pathmgr.open(path, "rb") as fopen: + all_images.append(PILImage.open(fopen).convert("RGB")) + cache[path] = len(all_images) - 1 + else: + # The frame rgb data has already been loaded + # Convert it to a PILImage + all_images.append(tensor_2_PIL(frame.data)) + + return all_images + + +def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: + data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 + data = data.astype(np.uint8) + return PILImage.fromarray(data) diff --git a/ref-avs.code/dataloader/sam2_dataset/vos_raw_dataset.py b/ref-avs.code/dataloader/sam2_dataset/vos_raw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..44fe893717a3e3bd85b043baa33d349b52b4b34e --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/vos_raw_dataset.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import logging +import os +from dataclasses import dataclass + +from typing import List, Optional + +import pandas as pd + +import torch + +from iopath.common.file_io import g_pathmgr + +from omegaconf.listconfig import ListConfig + +from training.dataset.vos_segment_loader import ( + JSONSegmentLoader, + MultiplePNGSegmentLoader, + PalettisedPNGSegmentLoader, + SA1BSegmentLoader, +) + + +@dataclass +class VOSFrame: + frame_idx: int + image_path: str + data: Optional[torch.Tensor] = None + is_conditioning_only: Optional[bool] = False + + +@dataclass +class VOSVideo: + video_name: str + video_id: int + frames: List[VOSFrame] + + def __len__(self): + return len(self.frames) + + +class VOSRawDataset: + def __init__(self): + pass + + def get_video(self, idx): + raise NotImplementedError() + + +class PNGRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + is_palette=True, + single_object_mode=False, + truncate_video=-1, + frames_sampling_mult=False, + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.sample_rate = sample_rate + self.is_palette = is_palette + self.single_object_mode = single_object_mode + self.truncate_video = truncate_video + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + if self.single_object_mode: + # single object mode + self.video_names = sorted( + [ + os.path.join(video_name, obj) + for video_name in self.video_names + for obj in os.listdir(os.path.join(self.gt_folder, video_name)) + ] + ) + + if frames_sampling_mult: + video_names_mult = [] + for video_name in self.video_names: + num_frames = len(os.listdir(os.path.join(self.img_folder, video_name))) + video_names_mult.extend([video_name] * num_frames) + self.video_names = video_names_mult + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + if self.single_object_mode: + video_frame_root = os.path.join( + self.img_folder, os.path.dirname(video_name) + ) + else: + video_frame_root = os.path.join(self.img_folder, video_name) + + video_mask_root = os.path.join(self.gt_folder, video_name) + + if self.is_palette: + segment_loader = PalettisedPNGSegmentLoader(video_mask_root) + else: + segment_loader = MultiplePNGSegmentLoader( + video_mask_root, self.single_object_mode + ) + + all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg"))) + if self.truncate_video > 0: + all_frames = all_frames[: self.truncate_video] + frames = [] + for _, fpath in enumerate(all_frames[:: self.sample_rate]): + fid = int(os.path.basename(fpath).split(".")[0]) + frames.append(VOSFrame(fid, image_path=fpath)) + video = VOSVideo(video_name, idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class SA1BRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + num_frames=1, + mask_area_frac_thresh=1.1, # no filtering by default + uncertain_iou=-1, # no filtering by default + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.num_frames = num_frames + self.mask_area_frac_thresh = mask_area_frac_thresh + self.uncertain_iou = uncertain_iou # stability score + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + subset = [ + path.split(".")[0] for path in subset if path.endswith(".jpg") + ] # remove extension + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files and it exists + self.video_names = [ + video_name for video_name in subset if video_name not in excluded_files + ] + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + video_frame_path = os.path.join(self.img_folder, video_name + ".jpg") + video_mask_path = os.path.join(self.gt_folder, video_name + ".json") + + segment_loader = SA1BSegmentLoader( + video_mask_path, + mask_area_frac_thresh=self.mask_area_frac_thresh, + video_frame_path=video_frame_path, + uncertain_iou=self.uncertain_iou, + ) + + frames = [] + for frame_idx in range(self.num_frames): + frames.append(VOSFrame(frame_idx, image_path=video_frame_path)) + video_name = video_name.split("_")[-1] # filename is sa_{int} + # video id needs to be image_id to be able to load correct annotation file during eval + video = VOSVideo(video_name, int(video_name), frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class JSONRawDataset(VOSRawDataset): + """ + Dataset where the annotation in the format of SA-V json files + """ + + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + rm_unannotated=True, + ann_every=1, + frames_fps=24, + ): + self.gt_folder = gt_folder + self.img_folder = img_folder + self.sample_rate = sample_rate + self.rm_unannotated = rm_unannotated + self.ann_every = ann_every + self.frames_fps = frames_fps + + # Read and process excluded files if provided + excluded_files = [] + if excluded_videos_list_txt is not None: + if isinstance(excluded_videos_list_txt, str): + excluded_videos_lists = [excluded_videos_list_txt] + elif isinstance(excluded_videos_list_txt, ListConfig): + excluded_videos_lists = list(excluded_videos_list_txt) + else: + raise NotImplementedError + + for excluded_videos_list_txt in excluded_videos_lists: + with open(excluded_videos_list_txt, "r") as f: + excluded_files.extend( + [os.path.splitext(line.strip())[0] for line in f] + ) + excluded_files = set(excluded_files) + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + def get_video(self, video_idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[video_idx] + video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json") + segment_loader = JSONSegmentLoader( + video_json_path=video_json_path, + ann_every=self.ann_every, + frames_fps=self.frames_fps, + ) + + frame_ids = [ + int(os.path.splitext(frame_name)[0]) + for frame_name in sorted( + os.listdir(os.path.join(self.img_folder, video_name)) + ) + ] + + frames = [ + VOSFrame( + frame_id, + image_path=os.path.join( + self.img_folder, f"{video_name}/%05d.jpg" % (frame_id) + ), + ) + for frame_id in frame_ids[:: self.sample_rate] + ] + + if self.rm_unannotated: + # Eliminate the frames that have not been annotated + valid_frame_ids = [ + i * segment_loader.ann_every + for i, annot in enumerate(segment_loader.frame_annots) + if annot is not None and None not in annot + ] + frames = [f for f in frames if f.frame_idx in valid_frame_ids] + + video = VOSVideo(video_name, video_idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) diff --git a/ref-avs.code/dataloader/sam2_dataset/vos_sampler.py b/ref-avs.code/dataloader/sam2_dataset/vos_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad84b759d0f66191a84017d17140d128b634ca0 --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/vos_sampler.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +from dataclasses import dataclass +from typing import List + +from training.dataset.vos_segment_loader import LazySegments + +MAX_RETRIES = 1000 + + +@dataclass +class SampledFramesAndObjects: + frames: List[int] + object_ids: List[int] + + +class VOSSampler: + def __init__(self, sort_frames=True): + # frames are ordered by frame id when sort_frames is True + self.sort_frames = sort_frames + + def sample(self, video): + raise NotImplementedError() + + +class RandomUniformSampler(VOSSampler): + def __init__( + self, + num_frames, + max_num_objects, + reverse_time_prob=0.0, + ): + self.num_frames = num_frames + self.max_num_objects = max_num_objects + self.reverse_time_prob = reverse_time_prob + + def sample(self, video, segment_loader, epoch=None): + + for retry in range(MAX_RETRIES): + if len(video.frames) < self.num_frames: + raise Exception( + f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." + ) + start = random.randrange(0, len(video.frames) - self.num_frames + 1) + frames = [video.frames[start + step] for step in range(self.num_frames)] + if random.uniform(0, 1) < self.reverse_time_prob: + # Reverse time + frames = frames[::-1] + + # Get first frame object ids + visible_object_ids = [] + loaded_segms = segment_loader.load(frames[0].frame_idx) + if isinstance(loaded_segms, LazySegments): + # LazySegments for SA1BRawDataset + visible_object_ids = list(loaded_segms.keys()) + else: + for object_id, segment in segment_loader.load( + frames[0].frame_idx + ).items(): + if segment.sum(): + visible_object_ids.append(object_id) + + # First frame needs to have at least a target to track + if len(visible_object_ids) > 0: + break + if retry >= MAX_RETRIES - 1: + raise Exception("No visible objects") + + object_ids = random.sample( + visible_object_ids, + min(len(visible_object_ids), self.max_num_objects), + ) + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) + + +class EvalSampler(VOSSampler): + """ + VOS Sampler for evaluation: sampling all the frames and all the objects in a video + """ + + def __init__( + self, + ): + super().__init__() + + def sample(self, video, segment_loader, epoch=None): + """ + Sampling all the frames and all the objects + """ + if self.sort_frames: + # ordered by frame id + frames = sorted(video.frames, key=lambda x: x.frame_idx) + else: + # use the original order + frames = video.frames + object_ids = segment_loader.load(frames[0].frame_idx).keys() + if len(object_ids) == 0: + raise Exception("First frame of the video has no objects") + + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) diff --git a/ref-avs.code/dataloader/sam2_dataset/vos_segment_loader.py b/ref-avs.code/dataloader/sam2_dataset/vos_segment_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..27e17010cc8b010e103c3ac399689d80da7cfde9 --- /dev/null +++ b/ref-avs.code/dataloader/sam2_dataset/vos_segment_loader.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import json +import os + +import numpy as np +import pandas as pd +import torch + +from PIL import Image as PILImage + +try: + from pycocotools import mask as mask_utils +except: + pass + + +class JSONSegmentLoader: + def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None): + # Annotations in the json are provided every ann_every th frame + self.ann_every = ann_every + # Ids of the objects to consider when sampling this video + self.valid_obj_ids = valid_obj_ids + with open(video_json_path, "r") as f: + data = json.load(f) + if isinstance(data, list): + self.frame_annots = data + elif isinstance(data, dict): + masklet_field_name = "masklet" if "masklet" in data else "masks" + self.frame_annots = data[masklet_field_name] + if "fps" in data: + if isinstance(data["fps"], list): + annotations_fps = int(data["fps"][0]) + else: + annotations_fps = int(data["fps"]) + assert frames_fps % annotations_fps == 0 + self.ann_every = frames_fps // annotations_fps + else: + raise NotImplementedError + + def load(self, frame_id, obj_ids=None): + assert frame_id % self.ann_every == 0 + rle_mask = self.frame_annots[frame_id // self.ann_every] + + valid_objs_ids = set(range(len(rle_mask))) + if self.valid_obj_ids is not None: + # Remove the masklets that have been filtered out for this video + valid_objs_ids &= set(self.valid_obj_ids) + if obj_ids is not None: + # Only keep the objects that have been sampled + valid_objs_ids &= set(obj_ids) + valid_objs_ids = sorted(list(valid_objs_ids)) + + # Construct rle_masks_filtered that only contains the rle masks we are interested in + id_2_idx = {} + rle_mask_filtered = [] + for obj_id in valid_objs_ids: + if rle_mask[obj_id] is not None: + id_2_idx[obj_id] = len(rle_mask_filtered) + rle_mask_filtered.append(rle_mask[obj_id]) + else: + id_2_idx[obj_id] = None + + # Decode the masks + raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute( + 2, 0, 1 + ) # (num_obj, h, w) + segments = {} + for obj_id in valid_objs_ids: + if id_2_idx[obj_id] is None: + segments[obj_id] = None + else: + idx = id_2_idx[obj_id] + segments[obj_id] = raw_segments[idx] + return segments + + def get_valid_obj_frames_ids(self, num_frames_min=None): + # For each object, find all the frames with a valid (not None) mask + num_objects = len(self.frame_annots[0]) + + # The result dict associates each obj_id with the id of its valid frames + res = {obj_id: [] for obj_id in range(num_objects)} + + for annot_idx, annot in enumerate(self.frame_annots): + for obj_id in range(num_objects): + if annot[obj_id] is not None: + res[obj_id].append(int(annot_idx * self.ann_every)) + + if num_frames_min is not None: + # Remove masklets that have less than num_frames_min valid masks + for obj_id, valid_frames in list(res.items()): + if len(valid_frames) < num_frames_min: + res.pop(obj_id) + + return res + + +class PalettisedPNGSegmentLoader: + def __init__(self, video_png_root): + """ + SegmentLoader for datasets with masks stored as palettised PNGs. + video_png_root: the folder contains all the masks stored in png + """ + self.video_png_root = video_png_root + # build a mapping from frame id to their PNG mask path + # note that in some datasets, the PNG paths could have more + # than 5 digits, e.g. "00000000.png" instead of "00000.png" + png_filenames = os.listdir(self.video_png_root) + self.frame_id_to_png_filename = {} + for filename in png_filenames: + frame_id, _ = os.path.splitext(filename) + self.frame_id_to_png_filename[int(frame_id)] = filename + + def load(self, frame_id): + """ + load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # check the path + mask_path = os.path.join( + self.video_png_root, self.frame_id_to_png_filename[frame_id] + ) + + # load the mask + masks = PILImage.open(mask_path).convert("P") + masks = np.array(masks) + + object_id = pd.unique(masks.flatten()) + object_id = object_id[object_id != 0] # remove background (0) + + # convert into N binary segmentation masks + binary_segments = {} + for i in object_id: + bs = masks == i + binary_segments[i] = torch.from_numpy(bs) + + return binary_segments + + def __len__(self): + return + + +class MultiplePNGSegmentLoader: + def __init__(self, video_png_root, single_object_mode=False): + """ + video_png_root: the folder contains all the masks stored in png + single_object_mode: whether to load only a single object at a time + """ + self.video_png_root = video_png_root + self.single_object_mode = single_object_mode + # read a mask to know the resolution of the video + if self.single_object_mode: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0] + else: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0] + tmp_mask = np.array(PILImage.open(tmp_mask_path)) + self.H = tmp_mask.shape[0] + self.W = tmp_mask.shape[1] + if self.single_object_mode: + self.obj_id = ( + int(video_png_root.split("/")[-1]) + 1 + ) # offset by 1 as bg is 0 + else: + self.obj_id = None + + def load(self, frame_id): + if self.single_object_mode: + return self._load_single_png(frame_id) + else: + return self._load_multiple_pngs(frame_id) + + def _load_single_png(self, frame_id): + """ + load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png") + binary_segments = {} + + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + # if png doesn't exist, empty mask + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[self.obj_id] = torch.from_numpy(mask > 0) + return binary_segments + + def _load_multiple_pngs(self, frame_id): + """ + load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # get the path + all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*"))) + num_objects = len(all_objects) + assert num_objects > 0 + + # load the masks + binary_segments = {} + for obj_folder in all_objects: + # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder + obj_id = int(obj_folder.split("/")[-1]) + obj_id = obj_id + 1 # offset 1 as bg is 0 + mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png") + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[obj_id] = torch.from_numpy(mask > 0) + + return binary_segments + + def __len__(self): + return + + +class LazySegments: + """ + Only decodes segments that are actually used. + """ + + def __init__(self): + self.segments = {} + self.cache = {} + + def __setitem__(self, key, item): + self.segments[key] = item + + def __getitem__(self, key): + if key in self.cache: + return self.cache[key] + rle = self.segments[key] + mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0] + self.cache[key] = mask + return mask + + def __contains__(self, key): + return key in self.segments + + def __len__(self): + return len(self.segments) + + def keys(self): + return self.segments.keys() + + +class SA1BSegmentLoader: + def __init__( + self, + video_mask_path, + mask_area_frac_thresh=1.1, + video_frame_path=None, + uncertain_iou=-1, + ): + with open(video_mask_path, "r") as f: + self.frame_annots = json.load(f) + + if mask_area_frac_thresh <= 1.0: + # Lazily read frame + orig_w, orig_h = PILImage.open(video_frame_path).size + area = orig_w * orig_h + + self.frame_annots = self.frame_annots["annotations"] + + rle_masks = [] + for frame_annot in self.frame_annots: + if not frame_annot["area"] > 0: + continue + if ("uncertain_iou" in frame_annot) and ( + frame_annot["uncertain_iou"] < uncertain_iou + ): + # uncertain_iou is stability score + continue + if ( + mask_area_frac_thresh <= 1.0 + and (frame_annot["area"] / area) >= mask_area_frac_thresh + ): + continue + rle_masks.append(frame_annot["segmentation"]) + + self.segments = LazySegments() + for i, rle in enumerate(rle_masks): + self.segments[i] = rle + + def load(self, frame_idx): + return self.segments diff --git a/ref-avs.code/dataloader/visual/visual_augmentation.py b/ref-avs.code/dataloader/visual/visual_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40aed7c8b8c08d50a46db122e1213bd4878afd --- /dev/null +++ b/ref-avs.code/dataloader/visual/visual_augmentation.py @@ -0,0 +1,140 @@ +import random + +import matplotlib.pyplot as plt +import numpy +import torch +import torchvision.transforms.functional as F +import torchvision.transforms as transforms + + +class Augmentation(object): + def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255): + self.image_size = (image_height, image_width) + # self.image_norm = (image_mean, image_std) + # self.get_crop_pos = transforms.RandomCrop(self.image_size) + self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25) + self.gaussian_blurring = transforms.GaussianBlur((3, 3)) + self.scale_list = scale_list + + self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + self.to_tensor = transforms.ToTensor() + + self.ignore_index = ignore_index + + # self.normalise = transforms.Normalize(mean=image_mean, std=image_std) + + # if setup == "avs" or setup == "avss" or setup == "avss_binary": + # # AVS + # self.scale_list = [.5, .75, 1.] + # self.color_jitter = None + # else: + # # COCO + # # self.scale_list = [.75, 1., 1.25, 1.5, 1.75, 2.] + # self.scale_list = [0.5,0.75,1.0,1.25,1.5,1.75,2.0] + + # def normalise(self, image): + # image = image / 255.0 + # image = image - self.image_norm[0] + # image = image / self.image_norm[1] + # return image + + def resize(self, image_, label_, size=None): + h_, w_ = self.image_size if size is None else size + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + def random_crop_with_padding(self, image_, label_): + w_, h_ = image_.size + if min(h_, w_) < min(self.image_size): + res_w_ = max(self.image_size[0] - w_, 0) + res_h_ = max(self.image_size[1] - h_, 0) + image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=(numpy.array(self.image_norm[0]) * 255.).tolist()) + # image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=self.ignore_index) # if error, define the padding value. + label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index) + + pos_ = self.get_crop_pos.get_params(image_, self.image_size) + image_ = F.crop(image_, *pos_) + label_ = F.crop(label_, *pos_) + + return image_, label_ + + # @staticmethod + def random_scales(self, image_, label_): + w_, h_ = image_.size + chosen_scale = random.choice(self.scale_list) + w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale) + image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC) + label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST) + return image_, label_ + + @staticmethod + def random_flip_h(image_, label_): + chosen_flip = random.random() > 0.5 + image_ = F.hflip(image_) if chosen_flip else image_ + label_ = F.hflip(label_) if chosen_flip else label_ + return image_, label_ + + def augment_entire_clip(self, x_list, y_list): + degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item()) + shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()), + torch.empty(1).uniform_(float(-20.), float(20.)).item()] + dice = random.random() + for index, single_x in enumerate(x_list): + if dice <= 0.1: + single_x = F.rgb_to_grayscale(single_x, num_output_channels=3) + + single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.]) + single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0,0], scale=1., + interpolation=transforms.InterpolationMode.NEAREST, fill=[0.]) + x_list[index] = single_x + y_list[index] = single_y + + return x_list, y_list + + + + + def train_aug(self, x_, y_): + x_, y_ = self.random_flip_h(x_, y_) + # # x, y = self.random_scales(x, y) + x_, y_ = self.resize(x_, y_) + + if self.color_jitter is not None and random.random() < 0.5: + x_ = self.color_jitter(x_) + if self.gaussian_blurring is not None and random.random() < 0.5: + x_ = self.gaussian_blurring(x_) + + # x, y = self.random_crop_with_padding(x, y) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + # receive pseudo labels. + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def test_process(self, x_, y_): + # x = self.to_tensor(x) + # y = torch.tensor(numpy.asarray(y)).long() + + # following AVSbench setup, we fix image size (224, 224) + x_, y_ = self.resize(x_, y_) + + x_ = self.normalise(self.to_tensor(x_)).type(torch.float32) + y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float) + return x_, y_ + + def __call__(self, x, y, split): + return self.train_aug(x, y) if split == "train" \ + else self.test_process(x, y) + + + + + + + + + + + diff --git a/ref-avs.code/dataloader/visual/visual_dataset.py b/ref-avs.code/dataloader/visual/visual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a103a0234445c1f7ffb238594ef896438cdab8b5 --- /dev/null +++ b/ref-avs.code/dataloader/visual/visual_dataset.py @@ -0,0 +1,112 @@ +import os +import re +import PIL.Image +import matplotlib.pyplot as plt +import numpy +import torch +import pandas +import torchvision + + +class Visual(torch.utils.data.Dataset): + def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size): + self.augment = augmentation + self.directory_path = directory_path + self.split = split + self.image_size = image_size + self.embedding_size = image_embedding_size + + def get_frame_and_label(self, file_prefix, object_id): + # if self.split == 'null': + # frame_path = os.path.join(self.directory_path, 'media_cross', file_prefix, 'frames') + # frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)] + # frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0]))) + # # dummy empty label. + # frame = [PIL.Image.open(i) for i in frame_path] + # label = [PIL.Image.new('L', frame[0].size)] * len(frame) + # else: + frame_path = os.path.join(self.directory_path, 'media', file_prefix, 'frames') + label_path = os.path.join(self.directory_path, 'gt_mask', file_prefix, 'fid_{}'.format(str(object_id))) + frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)] + label_path = [os.path.join(label_path, i) for i in os.listdir(label_path)] + frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0]))) + label_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.png')[0]))) + frame = [PIL.Image.open(i) for i in frame_path] + label = [PIL.Image.open(i).convert('L') for i in label_path] + return frame, label + + def load_data(self, file_prefix, object_id): + frame, label = self.get_frame_and_label(file_prefix, object_id) + label_idx = torch.tensor(list([1] * 10), dtype=torch.bool) + + prompts = {} + image_batch = [None]*len(frame) + label_batch = [None]*len(frame) + + if self.split == 'train': + # apply sam2 augmentation. + frame, label = self.augment(frame, label) + + for i in range(len(frame)): + if 'test_' in self.split: + # note: there is no augmentation in here. + curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split) + else: + curr_frame, curr_label = frame[i], label[i] + + curr_label[curr_label > 0.] = 1. + image_batch[i], label_batch[i] = curr_frame, curr_label + + # image_batch[i], label_batch[i] = self.augment(frame[i], label[i], split=self.split) + # note: we simply convert the code to binary mask in v1s, v1m; + # to some reason, we failed to load the label in `L' format and had to hardcoding here. + # label_batch[i][label_batch[i] > 0.] = 1. + + # prompts['box_coords'][i], prompts['masks'][i] = self.receive_other_prompts(label_batch[i]) + + # organise the prompts + # prompts.update({'masks': torch.stack(prompts['masks'], dim=0)}) + # prompts.update({'box_coords': torch.stack(prompts['box_coords'], dim=0)}) + # prompts.update({'point_labels': torch.stack(prompts['point_labels'], dim=0)}) + prompts.update({'label_index': label_idx}) + return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts + + def receive_other_prompts(self, y_): + # y_ = torch.zeros_like(y_) + if len(torch.unique(y_)) > 1: + # foreground point + points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0) + + # bbox prompt (left-top corner & right-bottom corner) + bbox_one = torch.min(points_foreground[:, 0]), torch.min(points_foreground[:, 1]) + bbox_fou = torch.max(points_foreground[:, 0]), torch.max(points_foreground[:, 1]) + bbox_coord = torch.tensor(bbox_one + bbox_fou, dtype=torch.float) + bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.squeeze().shape) + # mask prompt + low_mask = torchvision.transforms.functional.resize(y_.clone(), [self.embedding_size*4, self.embedding_size*4], + torchvision.transforms.InterpolationMode.NEAREST) + else: + # for the pure background situation. + bbox_coord = torch.zeros([4], dtype=torch.float).fill_(float('nan')) + low_mask = torch.zeros([1, self.embedding_size*4, self.embedding_size*4], dtype=torch.float).fill_(float('nan')) + + return bbox_coord, low_mask + + # we transfer the coords to SAM's input resolution (1024, 1024). + def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the sam2 model. + """ + h, w = orig_hw + coords = coords.clone().reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + coords = coords * self.image_size # unnormalize coords + return coords.reshape(4) + + + diff --git a/ref-avs.code/inference.py b/ref-avs.code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7a5f188b2132e112f1173b43d6d357ab9f92e6 --- /dev/null +++ b/ref-avs.code/inference.py @@ -0,0 +1,201 @@ +"""Distributed inference on Ref-AVS (test_s / test_u / test_n); uses Trainer.valid / valid_null like main.py.""" +import os +import pathlib +import argparse +import random + +import numpy +import torch +from easydict import EasyDict + + +_real_mkdir = pathlib.Path.mkdir + + +def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False): + try: + return _real_mkdir(self, mode, parents, exist_ok=exist_ok) + except PermissionError: + pass + + +pathlib.Path.mkdir = _safe_mkdir + + +def seed_it(seed): + random.seed(seed) + os.environ["PYTHONSEED"] = str(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + + +class _DummyTensorboard: + """Minimal Tensorboard stub so Trainer.valid / valid_null run without wandb logging.""" + + def upload_wandb_info(self, info_dict): + pass + + def upload_wandb_image(self, *args, **kwargs): + pass + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + rank=hyp_param.local_rank, + world_size=hyp_param.gpus, + ) + seed_it(local_rank + hyp_param.seed) + torch.cuda.set_device(hyp_param.local_rank) + + import model.visual.sam2 # noqa: F401 — registers Hydra config store + from hydra import compose + from omegaconf import OmegaConf + + arch_h = compose(config_name='configs/auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + train_cfg = compose(config_name='configs/training/sam2_training_config.yaml') + OmegaConf.resolve(train_cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True) + + hyp_param.image_size = 1024 + hyp_param.image_embedding_size = int(hyp_param.image_size / 16) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank) + if not hyp_param.inference_ckpt: + raise ValueError("--inference_ckpt is required for inference.") + + ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu") + if not isinstance(ckpt_sd, dict): + raise TypeError("Checkpoint must be a state_dict dictionary.") + if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd): + av_model.load_state_dict(ckpt_sd, strict=True) + else: + av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True) + + av_model = torch.nn.parallel.DistributedDataParallel( + av_model, device_ids=[hyp_param.local_rank], find_unused_parameters=False, + ) + av_model.eval() + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data import DataLoader, Subset + from torch.utils.data.distributed import DistributedSampler + + visual_aug = VisualAugmentation( + hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index, + ) + audio_aug = AudioAugmentation(mono=True) + + max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0 + val_batch_size = getattr(hyp_param, "inference_val_batch_size", 4) + + def _test_loader(split): + ds = AV( + split=split, + augmentation={"visual": visual_aug, "audio": audio_aug}, + param=hyp_param, + root_path=hyp_param.data_root_path, + ) + if max_batches > 0: + n_samples = min(max_batches * val_batch_size, len(ds)) + ds = Subset(ds, range(n_samples)) + sampler = DistributedSampler(ds, shuffle=False) + return DataLoader( + ds, + batch_size=val_batch_size, + sampler=sampler, + num_workers=hyp_param.num_workers, + ) + + test_s_loader = _test_loader('test_s') + test_u_loader = _test_loader('test_u') + test_n_loader = _test_loader('test_n') + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + from utils.foreground_s import ForegroundS + + metrics = { + "foreground_iou": ForegroundIoU(), + "foreground_f-score": ForegroundFScore(hyp_param.local_rank), + "foreground_s": ForegroundS(), + } + trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics) + + test_s_iou, test_s_f = trainer.valid( + epoch=0, dataloader=test_s_loader, model=av_model, process='test_s', + ) + torch.cuda.empty_cache() + + test_u_iou, test_u_f = trainer.valid( + epoch=0, dataloader=test_u_loader, model=av_model, process='test_u', + ) + torch.cuda.empty_cache() + + test_n_s = trainer.valid_null( + epoch=0, dataloader=test_n_loader, model=av_model, process='test_n', + ) + torch.cuda.empty_cache() + + if hyp_param.local_rank <= 0: + print("\n========== Ref-AVS inference (same splits / metrics as training valid) ==========") + print(" test_s f_iou={} f_f-score={}".format(test_s_iou, test_s_f)) + print(" test_u f_iou={} f_f-score={}".format(test_u_iou, test_u_f)) + print(" test_n f_s={}".format(test_n_s)) + print("=======================================================\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Ref-AVS inference: test_s / test_u / test_n') + + parser.add_argument('--local_rank', type=int, default=-1, + help='multi-process training for DDP') + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + parser.add_argument('--batch_size', default=1, type=int, + help='unused at inference (validation uses inference_val_batch_size)') + parser.add_argument('--epochs', default=80, type=int, help='unused') + parser.add_argument('--lr', default=1e-5, type=float, help='unused') + parser.add_argument('--online', action='store_true', help='unused') + parser.add_argument( + '--inference_ckpt', type=str, required=True, + help='Trained AuralFuser checkpoint (.pth). SAM2 from backbone_weight in configs.', + ) + parser.add_argument('--inference_max_batches', type=int, default=0, + help='0 = full split; >0 = first N batches per split (debug)') + parser.add_argument('--inference_val_batch_size', type=int, default=4, + help='Validation batch size (default 4, same as main.py _test_loader)') + + args = parser.parse_args() + + from configs.config import C + args = EasyDict({**C, **vars(args)}) + + _repo = pathlib.Path(__file__).resolve().parent + _workspace = _repo.parent + args.data_root_path = str(_workspace / 'REFAVS') + args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt') + args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth') + args.saved_dir = '/tmp/ref_avs_infer_ckpt' + pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True) + + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9902' + + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/ref-avs.code/loss/misc.py b/ref-avs.code/loss/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46 --- /dev/null +++ b/ref-avs.code/loss/misc.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +from typing import List, Optional + +import torch +import torch.distributed as dist +import torchvision +from torch import Tensor + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True diff --git a/ref-avs.code/loss/training/contrastive_learning.py b/ref-avs.code/loss/training/contrastive_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..b45e48c52727cdb0d0f3e678b91e1ced349a5bda --- /dev/null +++ b/ref-avs.code/loss/training/contrastive_learning.py @@ -0,0 +1,545 @@ +"""Contrastive loss used during SAM2 + fusion training (config from Hydra `contrastive_learning`, tmp.code style).""" +import torch +from abc import ABC +import torch.nn as nn + + +class ContrastLoss(nn.Module, ABC): + def __init__(self, hyp_param): + super(ContrastLoss, self).__init__() + self.param = hyp_param + _defaults = { + "temperature": 0.10, + "ignore_idx": 255, + "ood_idx": 254, + "max_views": 512, + "proj_dim": 512, + "sample_limits": 64, + "total_limits": 15240, + } + _raw = getattr(hyp_param, "contrastive_learning", None) or {} + _cfg = {**_defaults, **_raw} + self.temperature = _cfg["temperature"] + self.ignore_idx = _cfg["ignore_idx"] + self.ood_idx = _cfg["ood_idx"] + self.max_views = _cfg["max_views"] + self.proj_dim = _cfg["proj_dim"] + self.sample_limits = _cfg["sample_limits"] + self.total_limits = _cfg["total_limits"] + + def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx): + embedding_sample_list = [] + label_list = [] + embedding_sample_list_a = [] + label_list_a = [] + class_index_list = torch.unique(masks) + # means not silence + if len(class_index_list) > 1: + for class_index in class_index_list[1:]: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3) + else: + embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) + label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3) + + # contras_list = [] + # contras_label_list = [] + sample_limits = self.sample_limits + # we only have 0, 1 + embeddings = embeddings.permute(1, 0) + for class_index in class_index_list: + hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()] + easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()] + + hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] + + # the number that is selected to the contrastive learning. + selective_num_hard = min(sample_limits, hard_indices_num) + selective_num_easy = min(sample_limits, easy_indices_num) + + if (selective_num_hard + selective_num_easy) < sample_limits * 2: + if selective_num_hard > selective_num_easy: + selective_num_hard += sample_limits * 2 - selective_num_easy + else: + selective_num_easy += sample_limits * 2 - selective_num_hard + + # skip if contains too limited samples. + # if selective_num_hard < 10 and selective_num_easy < 10: + # continue + hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] + embedding_sample_list.append(hard_indices[hard_chosen_indices]) + label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3) + + # add negative features to list. + easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] + embedding_sample_list.append(easy_indices[easy_chosen_indices]) + label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3) + return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a + + def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions): + masks = masks.flatten(start_dim=1) + predictions = predictions.flatten(start_dim=1) + visual_embeddings = visual_embeddings.flatten(start_dim=-2) + + visual_embedding_sample_list = [] + visual_label_list = [] + audio_embedding_sample_list = [] + audio_label_list = [] + + for frame_idx in range(masks.shape[0]): + current_vision_feats = visual_embeddings[frame_idx] + current_masks = masks[frame_idx] + current_predictions = predictions[frame_idx] + current_audio_feats = audio_embeddings[frame_idx] + for layer_idx in range(3): + (selected_vision_embeddings, selected_vision_labels, + selected_audio_embeddings, selected_audio_labels) = self.select_class_wise_samples(current_vision_feats[layer_idx], + current_audio_feats[layer_idx], + current_predictions, + current_masks, + 0) + + visual_embedding_sample_list += selected_vision_embeddings + visual_label_list += selected_vision_labels + + audio_embedding_sample_list += selected_audio_embeddings + audio_label_list += selected_audio_labels + + if len(visual_embedding_sample_list) == 0: return 0. + visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze() + visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1) + audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze() + audio_label_list = torch.cat(audio_label_list).unsqueeze(1) + + # print(visual_embedding_sample_list.shape, visual_label_list.shape) + # print(audio_embedding_sample_list.shape, audio_label_list.shape) + # exit(1) + total_limits = self.total_limits + if visual_embedding_sample_list.shape[0] > total_limits: + rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits] + visual_embedding_sample_list = visual_embedding_sample_list[:rand_index] + visual_label_list = visual_label_list[:rand_index] + loss = self.info_nce(visual_embedding_sample_list, visual_label_list, audio_embedding_sample_list, + audio_label_list) + return loss + + + # proof the q-project CAN BE the projector head of the contrastive learning. + # At the moment, I do believe the ATTENTION is the another format of the contrastive learning. + # First experiment: ignore the sound, only work on the projected vision mask. + def forward(self, embeddings, output_dicts, masks): + predictions = torch.cat([i['multistep_pred_masks'] for i in output_dicts]) + predictions = torch.nn.functional.interpolate(predictions, size=(int(self.param.image_size/16), int(self.param.image_size/16)), + mode='bilinear', align_corners=False).squeeze(1) + masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(int(self.param.image_size/16), int(self.param.image_size/16)), + mode='nearest').squeeze(1) + visual_embeddings, audio_embeddings = embeddings + # if len(predictions.shape) < 3 and len(masks.shape) < 3: + # predictions = predictions.unsqueeze(0) + # masks = masks.unsqueeze(0) + + visual_embeddings = torch.cat([torch.cat([visual_embeddings[0][i].unsqueeze(0), + visual_embeddings[1][i].unsqueeze(0), + visual_embeddings[2][i].unsqueeze(0)]).unsqueeze(0) + for i in range(masks.shape[0])]) + audio_embeddings = torch.cat([torch.cat([audio_embeddings[0][i].unsqueeze(0), + audio_embeddings[1][i].unsqueeze(0), + audio_embeddings[2][i].unsqueeze(0)]).unsqueeze(0) + for i in range(masks.shape[0])]) + + # dict_keys(['point_inputs', 'mask_inputs', 'multistep_pred_masks', 'multistep_pred_masks_high_res', + # 'multistep_pred_multimasks', 'multistep_pred_multimasks_high_res', 'multistep_pred_ious', + # 'multistep_point_inputs', 'multistep_object_score_logits', 'pred_masks', 'pred_masks_high_res', + # 'maskmem_features', 'maskmem_pos_enc']) + return self.forward_audio_visual(visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions) + + # def forward_visual_only(self, visual_embeddings, masks, predictions): + # masks = masks.flatten(start_dim=1) + # predictions = predictions.flatten(start_dim=1) + # visual_embeddings = visual_embeddings.flatten(start_dim=-2) + # + # visual_embedding_sample_list = [] + # visual_label_list = [] + # audio_embedding_sample_list = [] + # audio_label_list = [] + # + # for frame_idx in range(masks.shape[0]): + # current_vision_feats = visual_embeddings[frame_idx] + # current_masks = masks[frame_idx] + # current_predictions = predictions[frame_idx] + # for layer_idx in range(3): + # current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats[layer_idx], + # None, + # current_predictions, + # current_masks, + # frame_idx) + # visual_embedding_sample_list += current_select_embeddings + # visual_label_list += current_select_labels + # + # + # + # if len(embedding_sample_list) == 0: return 0. + # embedding_sample_list = torch.cat(embedding_sample_list, dim=0).squeeze() + # label_list = torch.cat(label_list, dim=0).unsqueeze(-1) + # total_limits = 15240 + # if embedding_sample_list.shape[0] > total_limits: + # rand_index = torch.randperm(embedding_sample_list.shape[0])[total_limits] + # embedding_sample_list = embedding_sample_list[:rand_index] + # label_list = label_list[:rand_index] + # loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, + # label_list) + # return loss + + + """ + # embeddings_size = (int(self.param.image_size/16), int(self.param.image_size/16)) + # masks = torch.nn.functional.interpolate(masks.float(), embeddings_size, mode='nearest') + # masks = masks.flatten(start_dim=1) + # predictions = torch.nn.functional.interpolate(predictions.float(), embeddings_size, mode='nearest') + # predictions = predictions.flatten(start_dim=1) + # + # embedding_sample_list = [] + # label_list = [] + # contras_sample_list = [] + # contras_label_list = [] + + # temp3. + # embedding_visual, embedding_audio = embeddings + # embedding_visual = torch.nn.functional.normalize(embedding_visual, p=2, dim=1) + # embedding_audio = torch.nn.functional.normalize(embedding_audio, p=2, dim=1) + # embedding_visual = embedding_visual.reshape(self.param.batch_size, int(embedding_visual.shape[0]/self.param.batch_size), + # *embedding_visual.shape[-2:]) + # + # embedding_audio = embedding_audio.reshape(self.param.batch_size, int(embedding_audio.shape[0]/self.param.batch_size), + # *embedding_audio.shape[-2:]) + # masks = masks.reshape(self.param.batch_size, int(masks.shape[0]/self.param.batch_size), + # masks.shape[-1]) + # predictions = predictions.reshape(self.param.batch_size, int(predictions.shape[0]/self.param.batch_size), + # predictions.shape[-1]) + # + # for batch_idx in range(masks.shape[0]): + # current_video_clip_embed = embedding_visual[batch_idx] + # current_video_clip_masks = masks[batch_idx] + # current_video_clip_preds = predictions[batch_idx] + # current_audio_clip_embed = embedding_audio[batch_idx] + # # print(current_video_clip_embed.shape, current_audio_clip_embed.shape, current_video_clip_masks.shape, current_video_clip_preds.shape) + # # exit(1) + # for sample_idx in range(masks.shape[1]): + # current_vision_feats = current_video_clip_embed[batch_idx] + # current_audio_feats = current_audio_clip_embed[batch_idx] + # current_masks = current_video_clip_masks[batch_idx] + # current_predictions = current_video_clip_preds[batch_idx] + # current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats, + # current_audio_feats, + # current_predictions, + # current_masks, + # batch_idx) + + # temp2. + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.reshape(self.param.batch_size, int(embeddings.shape[0]/self.param.batch_size), + *embeddings.shape[-2:]) + masks = masks.reshape(self.param.batch_size, int(masks.shape[0]/self.param.batch_size), + masks.shape[-1]) + predictions = predictions.reshape(self.param.batch_size, int(predictions.shape[0]/self.param.batch_size), + predictions.shape[-1]) + + for batch_idx in range(masks.shape[0]): + current_video_clip_embed = embeddings[batch_idx] + current_video_clip_masks = masks[batch_idx] + current_video_clip_preds = predictions[batch_idx] + # current_audio_clip_feats = + for sample_idx in range(masks.shape[1]): + current_vision_feats = current_video_clip_embed[batch_idx] + current_masks = current_video_clip_masks[batch_idx] + current_predictions = current_video_clip_preds[batch_idx] + current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats, + current_predictions, + current_masks, + batch_idx) + embedding_sample_list += current_select_embeddings + label_list += current_select_labels + # hard_indices = current_vision_feats[(current_masks != current_predictions).nonzero()] + # easy_indices = current_vision_feats[(current_masks == current_predictions).nonzero()] + # + # hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] + # + # # the number that is selected to the contrastive learning. + # selective_num_hard = min(sample_limits, hard_indices_num) + # selective_num_easy = min(sample_limits, easy_indices_num) + # # skip if contains too limited samples. + # if selective_num_hard < 10 or selective_num_easy < 10: + # continue + # + # hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] + # embedding_sample_list.append(hard_indices[hard_chosen_indices]) + # label_list.append(current_masks[hard_chosen_indices] + batch_idx * 1e3) + # + # # add negative features to list. + # easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] + # embedding_sample_list.append(easy_indices[easy_chosen_indices]) + # label_list.append(current_masks[easy_chosen_indices] + batch_idx * 1e3) + + if len(embedding_sample_list) == 0: return 0. + embedding_sample_list = torch.cat(embedding_sample_list, dim=0).squeeze() + label_list = torch.cat(label_list, dim=0).unsqueeze(-1) + total_limits = self.total_limits + if embedding_sample_list.shape[0] > total_limits: + rand_index = torch.randperm(embedding_sample_list.shape[0])[total_limits] + embedding_sample_list = embedding_sample_list[:rand_index] + label_list = label_list[:rand_index] + loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, + label_list) + + # temp. + # sample_limits = 500 + # for batch_idx in range(masks.shape[0]): + # # go through 3 layers embeddings. + # for j in range(len(embeddings)): + # current_vision_feats_list = embeddings[j] + # current_vision_feats = torch.nn.functional.normalize(current_vision_feats_list[batch_idx], p=2, dim=1) + # current_masks = masks[batch_idx] + # positive_indices = current_vision_feats[current_masks > 0, ...] + # negative_indices = current_vision_feats[current_masks == 0, ...] + # positive_indices_num, negative_indices_num = positive_indices.shape[0], negative_indices.shape[0] + # + # # the number that is selected to the contrastive learning. + # selective_num = min(sample_limits, positive_indices_num, negative_indices_num) + # if selective_num < 50: continue # skip if contains too limited samples. + # + # embedding_sample_list.append(positive_indices[torch.randperm(positive_indices_num)[:selective_num]]) + # label_list.append(torch.tensor([batch_idx + (self.param.local_rank * 100)] * selective_num, + # device=positive_indices.device)) + # + # # add negative features to list. + # negative_sample_list.append(negative_indices[torch.randperm(negative_indices_num)[:selective_num]]) + # negative_label_list.append(torch.tensor([-1] * selective_num, device=negative_indices.device)) + # + # if len(embedding_sample_list) == 0: return 0. + # embedding_sample_list = torch.cat(embedding_sample_list, dim=0) + # negative_sample_list = torch.cat(negative_sample_list, dim=0) + # label_list = torch.cat(label_list) + # negative_label_list = torch.cat(negative_label_list) + # + # loss = self.info_nce(embedding_sample_list, label_list.unsqueeze(-1), + # torch.cat([embedding_sample_list, negative_sample_list], dim=0), + # torch.cat([label_list, negative_label_list]).unsqueeze(-1)) + + # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] + # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] + # + # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) + # torch.distributed.all_gather(output_list_labels, label_list) + # + # output_list_embeddings = torch.cat(output_list_embeddings) + # output_list_labels = torch.cat(output_list_labels, dim=1) + # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) + return loss + """ + # q_max. + # def forward(self, embeddings, masks): + # # for single-sounding obj. only, with first idx mask. + # masks = torch.nn.functional.interpolate(masks.float(), (64, 64), mode='bilinear', align_corners=False) + # masks = masks.flatten(start_dim=1) + # # embedding_sample_list = torch.zeros([masks.shape[0], 128]).to(self.param.local_rank) + # embedding_sample_list = [] + # label_list = [] + # + # negative_sample_list = [] + # negative_label_list = [] + # sample_limits = 20 + # for batch_idx in range(masks.shape[0]): + # # go through 3 layers embeddings. + # for j in range(len(embeddings)): + # current_vision_feats_list, current_audio_feats_list = embeddings[j] + # current_audio_feats = torch.nn.functional.normalize(current_audio_feats_list[batch_idx], p=2, dim=1) + # current_vision_feats = torch.nn.functional.normalize(current_vision_feats_list[batch_idx], p=2, dim=1) + # current_masks = masks[batch_idx] + # + # # add following features to list. + # embedding_sample_list.append(current_vision_feats[current_masks > 0, ...].max(dim=0)[0].unsqueeze(0)) + # label_list.append(batch_idx + (self.param.local_rank * 100)) + # + # embedding_sample_list.append(current_audio_feats) + # label_list.append(batch_idx + (self.param.local_rank * 100)) + # + # # add negative features to list. + # negative_num = min(current_vision_feats[current_masks == 0, ...].shape[0], sample_limits) + # if negative_num < 5: continue # skip if contains too limited samples. + # rand_idx = torch.randperm(current_vision_feats[current_masks == 0, ...].shape[0])[:negative_num] + # negative_sample_list.append(current_vision_feats[current_masks == 0, ...][rand_idx]) + # negative_label_list.append(torch.tensor([-1] * negative_num, device=current_vision_feats.device)) + # + # embedding_sample_list = torch.cat(embedding_sample_list) + # label_list = torch.tensor(label_list, device=masks.device) + # negative_sample_list = torch.cat(negative_sample_list, dim=0) + # negative_label_list = torch.cat(negative_label_list) + # + # loss = self.info_nce(embedding_sample_list, label_list.unsqueeze(-1), + # torch.cat([embedding_sample_list, negative_sample_list], dim=0), + # torch.cat([label_list, negative_label_list]).unsqueeze(-1)) + # + # # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] + # # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] + # # + # # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) + # # torch.distributed.all_gather(output_list_labels, label_list) + # # + # # output_list_embeddings = torch.cat(output_list_embeddings) + # # output_list_labels = torch.cat(output_list_labels, dim=1) + # # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) + # return loss + + # attention mean. + # def forward(self, embeddings): + # embedding_sample_list = [] + # label_list = [] + # for layer_embeddings in embeddings: + # embedding_sample_list.append(torch.nn.functional.normalize(layer_embeddings, p=2, dim=1)) + # # currently we only utilise single frame. + # label_list.append(torch.tensor(list(range(0, 1 + 1)) * self.param.batch_size) + (self.param.local_rank * 100)) + # embedding_sample_list = torch.cat(embedding_sample_list).cuda(self.param.local_rank) + # label_list = torch.cat(label_list).cuda(self.param.local_rank).unsqueeze(0) + # + # """ + # all gather implementation. + # """ + # """ + # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] + # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] + # + # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) + # torch.distributed.all_gather(output_list_labels, label_list) + # + # output_list_embeddings = torch.cat(output_list_embeddings) + # output_list_labels = torch.cat(output_list_labels, dim=1) + # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) + # """ + # loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, label_list) + # # frame_token_semantic_attn = torch.nn.functional.normalize(frame_token_semantic_attn.squeeze(), p=2, dim=1) + # # audio_token_attn = torch.nn.functional.normalize(audio_token_attn, p=2, dim=1) + # # city_gt = torch.nn.functional.interpolate(city_gt.unsqueeze(1).float(), size=city_proj.shape[2:], + # # mode='nearest').squeeze().long() + # # + # # ood_gt = torch.nn.functional.interpolate(ood_gt.unsqueeze(1).float(), size=ood_proj.shape[2:], + # # mode='nearest').squeeze().long() + # # + # # # normalise the embed results + # # city_proj = torch.nn.functional.normalize(city_proj, p=2, dim=1) + # # ood_proj = torch.nn.functional.normalize(ood_proj, p=2, dim=1) + # + # # randomly extract embed samples within a batch + # # anchor_embeds, anchor_labels, contrs_embeds, contrs_labels = self.extraction_samples(city_proj, city_gt, + # # ood_proj, ood_gt) + # # + # # # calculate the CoroCL + # # loss = self.info_nce(anchors_=anchor_embeds, a_labels_=anchor_labels.unsqueeze(1), contras_=contrs_embeds, + # # c_labels_=contrs_labels.unsqueeze(1)) if anchor_embeds.nelement() > 0 else \ + # # torch.tensor([.0], device=city_proj.device) + # + # return loss + @staticmethod + def manipulate_cover_mask(a_label, current_mask): + # shifting current visual index value + # background:=1, foreground:=2. + a_label = a_label + 1 + visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1)) + # kicked out the positive value in same visual class. + current_mask[:visual_mask.shape[1], :visual_mask.shape[0]][visual_mask == 1.] = 0 + current_mask[:visual_mask.shape[1], :visual_mask.shape[0]][visual_mask == 4.] = 0 + + return current_mask + + # The implementation of cross-image contrastive learning is based on: + # https://github.com/tfzhou/ContrastiveSeg/blob/287e5d3069ce6d7a1517ddf98e004c00f23f8f99/lib/loss/loss_contrast.py + def info_nce(self, anchors_, a_labels_, contras_, c_labels_): + c_labels_ = torch.cat([a_labels_, c_labels_]) + contras_ = torch.cat([anchors_, contras_]) + # calculates the binary mask: same category => 1, different categories => 0 + mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float() + + # calculates the dot product + anchor_dot_contrast = torch.div(torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), + self.temperature) + + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # calculates the negative mask + neg_mask = 1 - mask + + # avoid the self duplicate issue + mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask) + mask = mask.fill_diagonal_(0.) + + # sum the negative odot results + neg_logits = torch.exp(logits) * neg_mask + neg_logits = neg_logits.sum(1, keepdim=True) + + exp_logits = torch.exp(logits) + + # log_prob -> log(exp(x))-log(exp(x) + exp(y)) + # log_prob -> log{exp(x)/[exp(x)+exp(y)]} + log_prob = logits - torch.log(exp_logits + neg_logits) + # log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # calculate the info-nce based on the positive samples (under same categories) + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + # mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs.sum(1) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any()) + return - mean_log_prob_pos.mean() + + # def extraction_samples(self, city_embd, city_label, ood_embd, ood_label): + # # reformat the matrix + # city_embd = city_embd.flatten(start_dim=2).permute(0, 2, 1) + # city_label = city_label.flatten(start_dim=1) + # ood_embd = ood_embd.flatten(start_dim=2).permute(0, 2, 1) + # ood_label = ood_label.flatten(start_dim=1) + # + # # define different types of embeds + # city_positive = city_embd[city_label == self.ood_idx] + # city_negative = city_embd[(city_label != self.ood_idx) & (city_label != self.ignore_idx)] + # ood_positive = ood_embd[ood_label == self.ood_idx] + # ood_negative = ood_embd[(ood_label != self.ood_idx) & (ood_label != self.ignore_idx)] + # + # # define the number of choice + # sample_num = int(min(self.max_views, city_positive.shape[0], ood_positive.shape[0], + # city_negative.shape[0], ood_negative.shape[0])) + # + # # randomly extract the anchor set with {city_ood, city_inlier} + # city_positive_anchor = city_positive[torch.randperm(city_positive.shape[0])][:sample_num] + # city_negative_anchor = city_negative[torch.randperm(city_negative.shape[0])][:sample_num] + # + # anchor_embed = torch.cat([city_positive_anchor, city_negative_anchor], dim=0) + # + # anchor_label = torch.cat([torch.empty(city_positive_anchor.shape[0], + # device=city_positive_anchor.device).fill_(1.), + # torch.empty(city_negative_anchor.shape[0], + # device=city_negative_anchor.device).fill_(0.)]) + # + # # randomly extract the contras set with {city_ood, city_inlier, coco_ood, coco_inlier} + # city_positive_contras = city_positive_anchor.clone() + # city_negative_contras = city_negative_anchor.clone() + # ood_positive_contras = ood_positive[torch.randperm(ood_positive.shape[0])][:sample_num] + # ood_negative_contras = ood_negative[torch.randperm(ood_negative.shape[0])][:sample_num] + # + # contrs_embed = torch.cat([city_positive_contras, city_negative_contras, + # ood_positive_contras, ood_negative_contras], dim=0) + # + # contrs_label = torch.cat([torch.empty(city_positive_contras.shape[0], + # device=city_positive_contras.device).fill_(1.), + # torch.empty(city_negative_contras.shape[0], + # device=city_negative_contras.device).fill_(0.), + # torch.empty(ood_positive_contras.shape[0], + # device=ood_positive_contras.device).fill_(1.), + # torch.empty(ood_negative_contras.shape[0], + # device=ood_negative_contras.device).fill_(0.)]) + # + # return anchor_embed, anchor_label, contrs_embed, contrs_label + + + diff --git a/ref-avs.code/loss/training/sam2_training_loss.py b/ref-avs.code/loss/training/sam2_training_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0862f6e54b675f5fc8b59a7454931fb58df45d58 --- /dev/null +++ b/ref-avs.code/loss/training/sam2_training_loss.py @@ -0,0 +1,221 @@ +"""SAM2 multi-step mask + IoU + objectness loss (Hydra `_target_`: `MultiStepMultiMasksAndIous`).""" +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +CORE_LOSS_KEY = "core_loss" + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + inputs = inputs.sigmoid() + if loss_on_multimask: + assert inputs.dim() == 4 and targets.dim() == 4 + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + gpu_num=1, + ): + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.world_size = gpu_num + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + targets_batch.shape[1], device=targets_batch.device, dtype=torch.float + ) + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / self.world_size, min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 + + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + return reduced_loss + diff --git a/ref-avs.code/main.py b/ref-avs.code/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e61f125fb4c8d787142732df7e0a82985d60dfe0 --- /dev/null +++ b/ref-avs.code/main.py @@ -0,0 +1,157 @@ +"""DDP training: frozen SAM2 + text, trainable AuralFuser (Ref-AVS).""" +import os +import argparse +import random + +import numpy +import torch +from easydict import EasyDict + + +def seed_it(seed): + os.environ["PYTHONSEED"] = str(seed) + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def main(local_rank, ngpus_per_node, hyp_param): + hyp_param.local_rank = local_rank + torch.distributed.init_process_group( + backend='nccl', init_method='env://', + rank=local_rank, world_size=hyp_param.gpus, + ) + seed_it(local_rank + hyp_param.seed) + torch.cuda.set_device(local_rank) + + import model.visual.sam2 # noqa: F401 — registers Hydra config store + + from hydra import compose + from hydra.utils import instantiate + from omegaconf import OmegaConf + + cfg = compose(config_name='configs/training/sam2_training_config.yaml') + OmegaConf.resolve(cfg) + hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True) + + arch_h = compose(config_name='configs/auralfuser/architecture.yaml') + OmegaConf.resolve(arch_h) + hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True) + + hyp_param.image_size = 1024 + hyp_param.image_embedding_size = int(hyp_param.image_size / 16) + + from model.mymodel import AVmodel + av_model = AVmodel(hyp_param).cuda(local_rank) + av_model = torch.nn.parallel.DistributedDataParallel( + av_model, device_ids=[local_rank], find_unused_parameters=True, + ) + + from utils.utils import manipulate_params + optimiser = torch.optim.AdamW(manipulate_params(hyp_param, av_model.module.aural_fuser), betas=(0.9, 0.999)) + + from dataloader.dataset import AV + from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation + from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation + from torch.utils.data.distributed import DistributedSampler + + compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0] + audio_aug = AudioAugmentation(mono=True) + train_dataset = AV( + split='train', + augmentation={"visual": compose_api, "audio": audio_aug}, + param=hyp_param, + root_path=hyp_param.data_root_path, + ) + + visual_aug = VisualAugmentation( + hyp_param.image_mean, hyp_param.image_std, + hyp_param.image_size, hyp_param.image_size, + hyp_param.scale_list, ignore_index=hyp_param.ignore_index, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=hyp_param.batch_size, + sampler=DistributedSampler(train_dataset, shuffle=True), + num_workers=hyp_param.num_workers, + drop_last=True, + ) + + def _test_loader(split): + ds = AV(split=split, augmentation={"visual": visual_aug, "audio": audio_aug}, + param=hyp_param, root_path=hyp_param.data_root_path) + return torch.utils.data.DataLoader( + ds, batch_size=4, + sampler=DistributedSampler(ds, shuffle=False), + num_workers=hyp_param.num_workers, + ) + + test_s_loader = _test_loader('test_s') + test_u_loader = _test_loader('test_u') + test_n_loader = _test_loader('test_n') + + criterion = instantiate(cfg.loss, _recursive_=True)['all'] + + from utils.tensorboard import Tensorboard + tensorboard = Tensorboard(config=hyp_param) if local_rank <= 0 else None + + from trainer.train import Trainer + from utils.foreground_iou import ForegroundIoU + from utils.foreground_fscore import ForegroundFScore + from utils.foreground_s import ForegroundS + metrics = { + "foreground_iou": ForegroundIoU(), + "foreground_f-score": ForegroundFScore(0 if local_rank <= 0 else local_rank), + "foreground_s": ForegroundS(), + } + trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics) + + test_s_best, test_u_best = 0.2, 0.2 + for epoch in range(hyp_param.epochs + 1): + av_model.train() + av_model.module.freeze_sam_parameters() + train_loader.sampler.set_epoch(epoch) + trainer.train(epoch=epoch, dataloader=train_loader, model=av_model, optimiser=optimiser) + + torch.distributed.barrier() + torch.cuda.empty_cache() + + av_model.eval() + test_s, _ = trainer.valid(epoch=epoch, dataloader=test_s_loader, model=av_model, process='test_s') + test_u, _ = trainer.valid(epoch=epoch, dataloader=test_u_loader, model=av_model, process='test_u') + trainer.valid_null(epoch=epoch, dataloader=test_n_loader, model=av_model, process='test_n') + + if local_rank <= 0 and (test_s > test_s_best or test_u > test_u_best): + test_s_best = max(test_s, test_s_best) + test_u_best = max(test_u, test_u_best) + torch.save( + av_model.module.aural_fuser.state_dict(), + os.path.join( + hyp_param.saved_dir, + f's({float(test_s)})_u({float(test_u)}).pth', + ), + ) + + torch.distributed.barrier() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Ref-AVS training') + parser.add_argument('--local_rank', type=int, default=-1) + parser.add_argument('-g', '--gpus', default=1, type=int) + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--epochs', default=80, type=int) + parser.add_argument('--lr', default=5e-4, type=float) + args = parser.parse_args() + + from configs.config import C + args = EasyDict({**C, **vars(args)}) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '9901' + torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args)) diff --git a/ref-avs.code/model/audio/torchvggish/mel_features.py b/ref-avs.code/model/audio/torchvggish/mel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ac58fb5427f772fcced9cbd3cec3373ffbe5908c --- /dev/null +++ b/ref-avs.code/model/audio/torchvggish/mel_features.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Defines routines to compute mel spectrogram features from audio waveform.""" + +import numpy as np + + +def frame(data, window_length, hop_length): + """Convert array into a sequence of successive possibly overlapping frames. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete frames at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + window_length: Number of samples in each frame. + hop_length: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete frames that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) + shape = (num_frames, window_length) + data.shape[1:] + strides = (data.strides[0] * hop_length,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + +def periodic_hann(window_length): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + window_length: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * + np.arange(window_length))) + + +def stft_magnitude(signal, fft_length, + hop_length=None, + window_length=None): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal. + fft_length: Size of the FFT to apply. + hop_length: Advance (in samples) between each frame passed to FFT. + window_length: Length of each block of samples to pass to FFT. + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = frame(signal, window_length, hop_length) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = periodic_hann(window_length) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) + + +# Mel spectrum constants and functions. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +def hertz_to_mel(frequencies_hertz): + """Convert frequencies to mel scale using HTK formula. + + Args: + frequencies_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * np.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) + + +def spectrogram_to_mel_matrix(num_mel_bins=20, + num_spectrogram_bins=129, + audio_sample_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0): + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as frames x bins to generate a + "mel spectrogram" M of frames x num_mel_bins. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + num_mel_bins: How many bands in the resulting mel spectrum. This is + the number of columns in the output matrix. + num_spectrogram_bins: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins. + audio_sample_rate: Samples per second of the audio at the input to the + spectrogram. We need this to figure out the actual frequencies for + each spectrogram bin, which dictates how they are mapped into mel. + lower_edge_hertz: Lower bound on the frequencies to be included in the mel + spectrum. This corresponds to the lower edge of the lowest triangular + band. + upper_edge_hertz: The desired top edge of the highest frequency band. + + Returns: + An np.array with shape (num_spectrogram_bins, num_mel_bins). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hertz = audio_sample_rate / 2. + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % + (lower_edge_hertz, upper_edge_hertz)) + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % + (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), + hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def log_mel_spectrogram(data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data. + audio_sample_rate: The sampling rate of data. + log_offset: Add this to values when taking log to avoid -Infs. + window_length_secs: Duration of each window to analyze. + hop_length_secs: Advance between successive analysis windows. + **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive frames. + """ + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) + spectrogram = stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples) + mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, **kwargs)) + return np.log(mel_spectrogram + log_offset) diff --git a/ref-avs.code/model/audio/torchvggish/vggish.py b/ref-avs.code/model/audio/torchvggish/vggish.py new file mode 100644 index 0000000000000000000000000000000000000000..f01c22867c713bfd8713eee5665120b92602761d --- /dev/null +++ b/ref-avs.code/model/audio/torchvggish/vggish.py @@ -0,0 +1,193 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import hub + +from . import vggish_input, vggish_params + + +class VGG(nn.Module): + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.embeddings = nn.Sequential( + nn.Linear(512 * 4 * 6, 4096), + nn.ReLU(True), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Linear(4096, 128), + nn.ReLU(True)) + + def forward(self, x): + x = self.features(x) + + # Transpose the output from features to + # remain compatible with vggish embeddings + x = torch.transpose(x, 1, 3) + x = torch.transpose(x, 1, 2) + x = x.contiguous() + x = x.view(x.size(0), -1) + + return self.embeddings(x) + + +class Postprocessor(nn.Module): + """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a + numpy array in order to preserve the gradient. + + "The initial release of AudioSet included 128-D VGGish embeddings for each + segment of AudioSet. These released embeddings were produced by applying + a PCA transformation (technically, a whitening transform is included as well) + and 8-bit quantization to the raw embedding output from VGGish, in order to + stay compatible with the YouTube-8M project which provides visual embeddings + in the same format for a large set of YouTube videos. This class implements + the same PCA (with whitening) and quantization transformations." + """ + + def __init__(self): + """Constructs a postprocessor.""" + super(Postprocessor, self).__init__() + # Create empty matrix, for user's state_dict to load + self.pca_eigen_vectors = torch.empty( + (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), + dtype=torch.float, + ) + self.pca_means = torch.empty( + (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float + ) + + self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) + self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) + + def postprocess(self, embeddings_batch): + """Applies tensor postprocessing to a batch of embeddings. + + Args: + embeddings_batch: An tensor of shape [batch_size, embedding_size] + containing output from the embedding layer of VGGish. + + Returns: + A tensor of the same shape as the input, containing the PCA-transformed, + quantized, and clipped version of the input. + """ + assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( + embeddings_batch.shape, + ) + assert ( + embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE + ), "Bad batch shape: %r" % (embeddings_batch.shape,) + + # Apply PCA. + # - Embeddings come in as [batch_size, embedding_size]. + # - Transpose to [embedding_size, batch_size]. + # - Subtract pca_means column vector from each column. + # - Premultiply by PCA matrix of shape [output_dims, input_dims] + # where both are are equal to embedding_size in our case. + # - Transpose result back to [batch_size, embedding_size]. + pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() + + # Quantize by: + # - clipping to [min, max] range + clipped_embeddings = torch.clamp( + pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL + ) + # - convert to 8-bit in range [0.0, 255.0] + quantized_embeddings = torch.round( + (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) + * ( + 255.0 + / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) + ) + ) + return torch.squeeze(quantized_embeddings) + + def forward(self, x): + return self.postprocess(x) + + +def make_layers(): + layers = [] + in_channels = 1 + for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +def _vgg(): + return VGG(make_layers()) + + +# def _spectrogram(): +# config = dict( +# sr=16000, +# n_fft=400, +# n_mels=64, +# hop_length=160, +# window="hann", +# center=False, +# pad_mode="reflect", +# htk=True, +# fmin=125, +# fmax=7500, +# output_format='Magnitude', +# # device=device, +# ) +# return Spectrogram.MelSpectrogram(**config) + + +class VGGish(VGG): + def __init__(self, cfg, device=None): + super().__init__(make_layers()) + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH) + super().load_state_dict(state_dict) + print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}') + + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("device: ", device) + self.device = device + + self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL + self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA + if self.postprocess: + self.pproc = Postprocessor() + if cfg.FREEZE_AUDIO_EXTRACTOR: + state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH) + # TODO: Convert the state_dict to torch + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float + ) + state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( + state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float + ) + self.pproc.load_state_dict(state_dict) + self.to(self.device) + + def forward(self, x): + if self.preprocess: + print(">>> pre processing...") + x = self._preprocess(x) + x = x.to(self.device) + x = VGG.forward(self, x) + if self.postprocess: + print(">>> post processing...") + x = self._postprocess(x) + return x + + def _preprocess(self, x): + # if isinstance(x, np.ndarray): + # x = vggish_input.waveform_to_examples(x, fs) + if isinstance(x, str): + x = vggish_input.wavfile_to_examples(x) + else: + raise AttributeError + return x + + def _postprocess(self, x): + return self.pproc(x) diff --git a/ref-avs.code/model/audio/torchvggish/vggish_input.py b/ref-avs.code/model/audio/torchvggish/vggish_input.py new file mode 100644 index 0000000000000000000000000000000000000000..ede228b1fb630180f1f49244355d373fb3300f03 --- /dev/null +++ b/ref-avs.code/model/audio/torchvggish/vggish_input.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Compute input examples for VGGish from audio waveform.""" + +# Modification: Return torch tensors rather than numpy arrays +import torch + +import numpy as np +import resampy + +from . import mel_features +from . import vggish_params + +import soundfile as sf + + +def waveform_to_examples(data, sample_rate, return_tensor=True): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. + sample_rate: Sample rate of data. + return_tensor: Return data as a Pytorch tensor ready for VGGish + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames frames of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + + """ + # Convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + # Resample to the rate assumed by VGGish. + if sample_rate != vggish_params.SAMPLE_RATE: + data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) + + # Compute log mel spectrogram features. + log_mel = mel_features.log_mel_spectrogram( + data, + audio_sample_rate=vggish_params.SAMPLE_RATE, + log_offset=vggish_params.LOG_OFFSET, + window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, + num_mel_bins=vggish_params.NUM_MEL_BINS, + lower_edge_hertz=vggish_params.MEL_MIN_HZ, + upper_edge_hertz=vggish_params.MEL_MAX_HZ) + + # Frame features into examples. + features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS + example_window_length = int(round( + vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round( + vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame( + log_mel, + window_length=example_window_length, + hop_length=example_hop_length) + + if return_tensor: + log_mel_examples = torch.tensor( + log_mel_examples, requires_grad=True)[:, None, :, :].float() + + return log_mel_examples + + +def wavfile_to_examples(wav_file, return_tensor=True): + """Convenience wrapper around waveform_to_examples() for a common WAV format. + + Args: + wav_file: String path to a file, or a file-like object. The file + is assumed to contain WAV audio data with signed 16-bit PCM samples. + torch: Return data as a Pytorch tensor ready for VGGish + + Returns: + See waveform_to_examples. + """ + wav_data, sr = sf.read(wav_file, dtype='int16') + assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype + samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] + return waveform_to_examples(samples, sr, return_tensor) diff --git a/ref-avs.code/model/audio/torchvggish/vggish_params.py b/ref-avs.code/model/audio/torchvggish/vggish_params.py new file mode 100644 index 0000000000000000000000000000000000000000..526784bceaa4c9c8b8dc2b8f82e0f3d395d4bec2 --- /dev/null +++ b/ref-avs.code/model/audio/torchvggish/vggish_params.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Global parameters for the VGGish model. + +See vggish_slim.py for more information. +""" + +# Architectural constants. +NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. +NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. +EMBEDDING_SIZE = 128 # Size of embedding layer. + +# Hyperparameters used in feature and example generation. +SAMPLE_RATE = 16000 +STFT_WINDOW_LENGTH_SECONDS = 0.025 +STFT_HOP_LENGTH_SECONDS = 0.010 +NUM_MEL_BINS = NUM_BANDS +MEL_MIN_HZ = 125 +MEL_MAX_HZ = 7500 +LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. +EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. + +# Parameters used for embedding postprocessing. +PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' +PCA_MEANS_NAME = 'pca_means' +QUANTIZE_MIN_VAL = -2.0 +QUANTIZE_MAX_VAL = +2.0 + +# Hyperparameters used in training. +INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. +LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. +ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. + +# Names of ops, tensors, and features. +INPUT_OP_NAME = 'vggish/input_features' +INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' +OUTPUT_OP_NAME = 'vggish/embedding' +OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' +AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' diff --git a/ref-avs.code/model/aural_fuser.py b/ref-avs.code/model/aural_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..893b868c2ae451a725f45497b6db5abc3a610c82 --- /dev/null +++ b/ref-avs.code/model/aural_fuser.py @@ -0,0 +1,742 @@ +import torch +import torch.nn as nn +from model.audio.torchvggish import vggish +from timm.models.layers import DropPath, trunc_normal_ +import math + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine + + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d): + super(ProjectionHead, self).__init__() + self.proj = nn.Sequential( + nn.Linear(dim_in, proj_dim), + nn.GELU(), + nn.LayerNorm(proj_dim), + nn.Linear(proj_dim, proj_dim), + ) + + def forward(self, x): + return torch.nn.functional.normalize(self.proj(x), p=2, dim=1) + + +class AuralFuser(torch.nn.Module): + """Fuses VGGish audio, RoBERTa text, and SAM2 FPN maps via patch embeds, fusion blocks, and projection heads.""" + + def __init__(self, hyp_param): + self.hyp_param = hyp_param + super().__init__() + self.vgg = vggish.VGGish(self.hyp_param.audio) + if not getattr(self.hyp_param, "train_vggish", False): + for p in self.vgg.parameters(): + p.requires_grad = False + + self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, + temperature=10000) + + if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None: + raise ValueError( + "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser." + ) + arch_cfg = self.hyp_param.aural_fuser + + _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]] + _f_depths = arch_cfg["f_depths"] + _block_kw = dict(arch_cfg["block_kw"]) + _block_kw["norm_layer"] = nn.LayerNorm + _one_d_kw = dict(arch_cfg["one_d_kw"]) + _one_d_kw["norm_layer"] = nn.LayerNorm + + self.patch_embeds = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs + ) + + self.f_blocks = nn.ModuleList( + nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths + ) + + self.a_blocks = nn.ModuleList( + nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3) + ) + + self.fusion_modules = nn.ModuleList( + TPAVIModuleDIY(in_channels=256, mode='dot') for _ in range(3) + ) + self.smooth_convs = nn.ModuleList( + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2) + ) + + self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128) + self.train_proj_a1 = ProjectionHead(dim_in=256, proj_dim=128) + + self.text_proj = nn.Sequential( + nn.Linear(768, 1024), + nn.GELU(), + nn.Linear(1024, 256), + ) + + @staticmethod + def positionalencoding1d(d_model, length): + if d_model % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model)) + pe = torch.zeros(length, d_model) + position = torch.arange(0, length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * + -(math.log(10000.0) / d_model))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe + + def forward(self, feature_dicts, spect=None, text=None): + image_embed_shape = [self.hyp_param.image_embedding_size] * 2 + H, W = image_embed_shape[0], image_embed_shape[1] + d = torch.cat( + [ + self.vgg(spect[:, 0, ...].unsqueeze(1)), + self.vgg(spect[:, 1, ...].unsqueeze(1)), + ], + dim=-1, + ) + text = self.text_proj(text) + d = torch.cat([d, text.squeeze()]) + + length = d.shape[-1] + fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device) + fpn = list(feature_dicts["backbone_fpn"]) + patch_embeds = list(self.patch_embeds) + f_blocks = list(self.f_blocks) + a_blocks = list(self.a_blocks) + tpavi = list(self.fusion_modules) + smooths = [None, self.smooth_convs[0], self.smooth_convs[1]] + + feats = [None, None, None] + d_outputs = [] + vis_attn_feats = [] + + for i in range(3): + x = fpn[i] + x = patch_embeds[i](x) + x_pos = self.position_encoding_func(x) + x = x.flatten(2).permute(0, 2, 1) + x_pos = x_pos.flatten(2).permute(0, 2, 1) + + if i == 0: + x = x + x_pos + d = d + fix_audio_pos + else: + x = x + feats[i - 1] + x = smooths[i]( + x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W) + ).flatten(2).permute(0, 2, 1) + x = x + x_pos + d = d + fix_audio_pos + + for blks in f_blocks[i]: + x = blks(x, H, W, x_pos) + for blks in a_blocks[i]: + d = blks(d, fix_audio_pos) + + x = x + x_pos + d = d + fix_audio_pos + x, d_out, x_attn, _ = tpavi[i](x, H, W, x_pos, d, length) + d = d_out + feats[i] = x + d_outputs.append(d_out) + vis_attn_feats.append(x_attn) + + a, b, c = feats + d1, d2, d3 = d_outputs + a_attn, b_attn, c_attn = vis_attn_feats + + feature_residual = [a, b, c] + audio_out = [d1, d2, d3] + + proj_feature_out = [ + [ + self.train_proj_v1(a_attn.flatten(start_dim=2).permute(0, 2, 1)).reshape( + -1, *image_embed_shape, 128 + ).permute(0, 3, 1, 2), + self.train_proj_v1(b_attn.flatten(start_dim=2).permute(0, 2, 1)).reshape( + -1, *image_embed_shape, 128 + ).permute(0, 3, 1, 2), + self.train_proj_v1(c_attn.flatten(start_dim=2).permute(0, 2, 1)).reshape( + -1, *image_embed_shape, 128 + ).permute(0, 3, 1, 2), + ], + [ + self.train_proj_a1(d1[:10]).unsqueeze(-1), + self.train_proj_a1(d2[:10]).unsqueeze(-1), + self.train_proj_a1(d3[:10]).unsqueeze(-1), + ], + ] + + return feature_residual, audio_out, proj_feature_out + + +class TPAVIModuleDIY(nn.Module): + def __init__(self, in_channels, inter_channels=None, mode='dot', + dimension=3): + """ + args: + in_channels: original channel size (1024 in the paper) + inter_channels: channel size inside the block if not specifed reduced to half (512 in the paper) + mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation + dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal) + bn_layer: whether to add batch norm + """ + super(TPAVIModuleDIY, self).__init__() + assert mode == 'dot', print('... following original paper.') + self.mode = mode + self.dimension = dimension + + self.in_channels = in_channels + self.inter_channels = inter_channels + + self.inter_channels = in_channels // 2 + + self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1) + self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1) + + self.norm_layer = nn.LayerNorm(in_channels) + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.W_z = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + bn(self.in_channels) + ) + nn.init.constant_(self.W_z[1].weight, 0) + nn.init.constant_(self.W_z[1].bias, 0) + + self.W_z2 = nn.Sequential( + nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), + nn.BatchNorm1d(self.in_channels) + ) + nn.init.constant_(self.W_z2[1].weight, 0) + nn.init.constant_(self.W_z2[1].bias, 0) + self.norm_layer2 = nn.LayerNorm(self.in_channels) + + self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) + + self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1) + + + def forward(self, frame, H_x, W_x, tmp1, audio, tmp2): + """ + args: + x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1 + audio: (N, T, C) + """ + frame = frame.permute(0, 2, 1) + frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x) + + frame = frame.unsqueeze(2) + audio = self.align_channel(audio.unsqueeze(-1)) + + + batch_size = frame.size(0) + audio_batch_size = audio.size(0) + q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels) # [bs, 4096, 128] + k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels) # [bs, 4096, 128] + v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels) # [bs, 4096, 128] + + q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels) # [bs, 1, 128] + k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels) # [bs, 1, 128] + v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels) # [bs, 1, 128] + + f = torch.matmul(q_frame, k_audio.mT) # [bs, 4096, 1] + f_normalise = f / f.size(1) # [bs, THW, THW] + + frame_attn = torch.matmul(f_normalise, v_audio) # [bs, THW, C] + + frame_attn = frame_attn.permute(0, 2, 1).contiguous() # [bs, C, THW] + frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:]) # + frame_attn = self.W_z(frame_attn) # [bs, C, T, H, W] + frame = frame_attn + frame # # [bs, C, T, H, W] + + frame = frame.permute(0, 2, 3, 4, 1) # [bs, T, H, W, C] + frame = self.norm_layer(frame) + frame = frame.permute(0, 4, 1, 2, 3) # [bs, C, T, H, W] + frame = frame.squeeze().flatten(start_dim=2).permute(0, 2, 1) + + a = torch.matmul(q_audio, k_frame.mT) # [bs, THW, THW] + a_normalise = a / a.size(-1) + + audio_attn = torch.matmul(a_normalise, v_frame) + audio_attn = audio_attn.permute(0, 2, 1).contiguous() # [bs, C, THW] + + audio_attn = audio_attn.view(audio_batch_size, self.inter_channels).unsqueeze(-1) + audio_attn = self.W_z2(audio_attn) # [bs, C, T, H, W] + + audio = audio_attn + audio + + audio = self.norm_layer2(audio.squeeze()).squeeze() + + return frame, audio, frame_attn, audio_attn + + +class OneDBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OneDAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = OneDMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, pos): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class OneDAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = x.unsqueeze(0) + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.squeeze() + return x + + +class OneDMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class TwoWayBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1_f = norm_layer(dim) + self.norm1_a = norm_layer(dim) + self.attn = TwoWayCrossAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2_f = norm_layer(dim) + self.norm2_a = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_f = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + self.mlp_a = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x_f, H_f, W_f, x_f_pos, x_a, H_a, W_a, x_a_pos): + x_f_1, x_a_1 = self.attn(self.norm1_f(x_f + x_f_pos), H_f, W_f, self.norm1_a(x_a + x_a_pos), H_a, W_a) + x_f, x_a = x_f + self.drop_path(x_f_1), x_a + self.drop_path(x_a_1) + + x_f_2, x_a_2 = self.mlp_f(self.norm2_f(x_f), H_f, W_f), self.mlp_a(self.norm2_a(x_a), H_a, W_a) + x_f, x_a = x_f + self.drop_path(x_f_2), x_a + self.drop_path(x_a_2) + return x_f, x_a + + +class TwoWayCrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.linear = linear + self.sr_ratio = sr_ratio + for i in ['frame', 'audio']: + setattr(self, i + '_q', nn.Linear(dim, dim, bias=qkv_bias)) + setattr(self, i + '_kv', nn.Linear(dim, dim, bias=qkv_bias)) + setattr(self, i + '_attn_drop', nn.Dropout(attn_drop)) + setattr(self, i + '_proj', nn.Linear(dim, dim, bias=qkv_bias)) + setattr(self, i + '_proj_drop', nn.Dropout(proj_drop)) + if not linear: + if sr_ratio > 1: + setattr(self, i + '_sr', nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + setattr(self, i + '_norm', nn.LayerNorm(dim)) + else: + setattr(self, i + '_pool', nn.AdaptiveAvgPool2d(7)) + setattr(self, i + '_sr', nn.Conv2d(dim, dim, kernel_size=1, stride=1)) + setattr(self, i + '_norm', nn.LayerNorm(dim)) + setattr(self, i + '_act', nn.GELU()) + + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x_f, H_f, W_f, x_a, H_a, W_a): + B_f, N_f, C_f = x_f.shape + B_a, N_a, C_a = x_a.shape + q_f = self.frame_q(x_f).reshape(B_f, N_f, self.num_heads, C_f // self.num_heads).permute(0, 2, 1, 3) + q_a = self.audio_q(x_a).reshape(B_a, N_a, self.num_heads, C_a // self.num_heads).permute(0, 2, 1, 3) + + if not self.linear: + if self.sr_ratio > 1: + x_f = x_f.permute(0, 2, 1).reshape(B_f, C_f, H_f, W_f) + x_f = self.frame_sr(x_f).reshape(B_f, C_f, -1).permute(0, 2, 1) + x_f = self.frame_norm(x_f) + kv_f = self.frame_kv(x_f).reshape(B_f, -1, 2, self.num_heads, C_f // self.num_heads).permute(2, 0, 3, 1, + 4) + + x_a = x_a.permute(0, 2, 1).reshape(B_a, C_a, H_a, W_a) + x_a = self.audio_sr(x_a).reshape(B_a, C_a, -1).permute(0, 2, 1) + x_a = self.audio_norm(x_a) + kv_a = self.audio_kv(x_a).reshape(B_a, -1, 2, self.num_heads, C_f // self.num_heads).permute(2, 0, 3, 1, + 4) + + else: + kv_f = self.frame_kv(x_f).reshape(B_f, -1, 2, self.num_heads, C_f // self.num_heads).permute(2, 0, 3, 1, + 4) + kv_a = self.kv(x_a).reshape(B_a, -1, 2, self.num_heads, C_a // self.num_heads).permute(2, 0, 3, 1, 4) + else: + raise NotImplementedError + + k_f, v_f = kv_f[0], kv_f[1] + k_a, v_a = kv_a[0], kv_a[1] + + attn_a = (q_a @ k_f.transpose(-2, -1)) * self.scale + attn_a = attn_a.softmax(dim=-1) + attn_a = self.audio_attn_drop(attn_a) + x_a = (attn_a @ v_f).transpose(1, 2).reshape(B_a, N_a, C_a) + x_a = self.audio_proj(x_a) + x_a = self.audio_proj_drop(x_a) + + attn_f = (q_f @ k_a.transpose(-2, -1)) * self.scale + attn_f = attn_f.softmax(dim=-1) + attn_f = self.frame_attn_drop(attn_f) + x_f = (attn_f @ v_a).transpose(1, 2).reshape(B_f, N_f, C_f) + x_f = self.frame_proj(x_f) + x_f = self.frame_proj_drop(x_f) + + return x_f, x_a + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, pos): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, + linear=False): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + if not self.linear: + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x diff --git a/ref-avs.code/model/mymodel.py b/ref-avs.code/model/mymodel.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae9338b84242637a60e868fd8ab3c27f7f65991 --- /dev/null +++ b/ref-avs.code/model/mymodel.py @@ -0,0 +1,94 @@ +"""End-to-end Ref-AVS: SAM2 visual backbone + AuralFuser fusion + tracking head. + +Orchestration follows ``avs.code/v1m.code/model/mymodel.py``. +""" +import torch + +from model.visual.sam2.build_sam import build_sam2_visual_predictor +from model.visual.sam2.utils.transforms import SAM2Transforms +from model.aural_fuser import AuralFuser +from transformers import AutoTokenizer, AutoModel + + +class AVmodel(torch.nn.Module): + """SAM2 + audio/text fusion (``aural_fuser``) + SAM2 tracking decoder.""" + + def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0): + super().__init__() + self.param = param + self.mask_threshold = mask_threshold + self._bb_feat_sizes = [ + (int(self.param.image_size / 4), int(self.param.image_size / 4)), + (int(self.param.image_size / 8), int(self.param.image_size / 8)), + (int(self.param.image_size / 16), int(self.param.image_size / 16)), + ] + + self.v_model = build_sam2_visual_predictor( + self.param.sam_config_path, + self.param.backbone_weight, + apply_postprocessing=True, + mode='train', + hydra_overrides_extra=["++model.image_size={}".format(self.param.image_size)], + ) + self._transforms = SAM2Transforms( + resolution=self.v_model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + self.aural_fuser = AuralFuser(hyp_param=self.param) + self.text_tokenizer = AutoTokenizer.from_pretrained('distilbert/distilroberta-base') + self.t_model = AutoModel.from_pretrained('distilbert/distilroberta-base') + + def _encode_text(self, prompts): + """RoBERTa embeddings for referring expressions (frozen at train time).""" + enc = self.text_tokenizer( + *prompts, + max_length=25, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + enc['input_ids'] = enc['input_ids'].cuda(self.param.local_rank, non_blocking=True) + enc['attention_mask'] = enc['attention_mask'].cuda(self.param.local_rank, non_blocking=True) + with torch.no_grad(): + return self.t_model(**enc).last_hidden_state + + def forward_frame(self, frame_): + """Single-frame SAM2 image encoder pass (same helper pattern as v1m).""" + frame = torch.nn.functional.interpolate( + frame_, (self.param.image_size, self.param.image_size), + antialias=True, align_corners=False, mode='bilinear', + ) + return self.v_model.image_encoder(frame) + + def forward(self, frames, spect, prompts, sam_process=False): + """Fuse audio+text into FPN, then run SAM2 tracking without box/mask prompts.""" + text_feats = self._encode_text(prompts) + backbone_feats = self.v_model.forward_image(frames, pre_compute=False) + audio_residual_feats = self.aural_fuser(backbone_feats, spect, text_feats) + visual_resfeats, audio_resfeats, proj_feats = audio_residual_feats + + map_res = visual_resfeats[::-1] + vec_res = audio_resfeats[::-1] + av_feats = (map_res, vec_res) + + backbone_feats = self.v_model.precompute_high_res_features(backbone_feats) + backbone_feats = self.v_model.dont_prepare_prompt_inputs( + backbone_feats, + num_frames=frames.shape[0], + condition_frame=int(frames.shape[0] / 2), + ) + outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats) + return outputs, proj_feats + + @property + def device(self) -> torch.device: + return self.v_model.device + + def freeze_sam_parameters(self): + """Freeze SAM2 and text backbone; only ``aural_fuser`` is trained.""" + self.v_model.eval() + self.t_model.eval() + for _, parameter in self.v_model.named_parameters(): + parameter.requires_grad = False diff --git a/ref-avs.code/model/visual/sam2/__init__.py b/ref-avs.code/model/visual/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec9293b35f206027249490ec25cd9a4dd326332 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("model.visual.sam2", version_base="1.2") diff --git a/ref-avs.code/model/visual/sam2/automatic_mask_generator.py b/ref-avs.code/model/visual/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..065e469e27c2d3af40d51d072031e828692c799b --- /dev/null +++ b/ref-avs.code/model/visual/sam2/automatic_mask_generator.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + **kwargs, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/ref-avs.code/model/visual/sam2/build_sam.py b/ref-avs.code/model/visual/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc2589e3d6f19ea831bb0825f1e1521ae9aa9e6 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/build_sam.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf +''' +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) +''' + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "configs/sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "configs/sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "configs/sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "configs/sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "configs/sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "configs/sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "configs/sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "configs/sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_visual_predictor( + config_file, + ckpt_path=None, + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + # visual + hydra_overrides = [] + # "++model._target_=model.visual.sam2.organised_sam2_train.SAM2Train", + # ] + # hydra_overrides = [ + # "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + # ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + + # dynamically fall back to multi-mask if the single mask is not stable + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + # "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +# def build_sam2_video_predictor_hf(model_id, **kwargs): +# config_name, ckpt_path = _hf_download(model_id) +# return build_sam2_video_predictor( +# config_file=config_name, ckpt_path=ckpt_path, **kwargs +# ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/ref-avs.code/model/visual/sam2/configs/auralfuser/architecture.yaml b/ref-avs.code/model/visual/sam2/configs/auralfuser/architecture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c3d06ca42335ce6bfc8064bbd5cfd44c8080a --- /dev/null +++ b/ref-avs.code/model/visual/sam2/configs/auralfuser/architecture.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +aural_fuser: + patch_cfgs: + - [4, 4] + - [2, 2] + - [1, 1] + f_depths: [3, 6, 12] + block_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false + one_d_kw: + dim: 256 + num_heads: 4 + mlp_ratio: 4 + qkv_bias: true + qk_scale: null + drop: 0.1 + attn_drop: 0.1 + drop_path: 0.0 + sr_ratio: 4 + linear: false diff --git a/ref-avs.code/model/visual/sam2/configs/sam2/sam2_hiera_l.yaml b/ref-avs.code/model/visual/sam2/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8478b3d4b8b16d8b22f6555cf7b1f00231d7fd59 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: model.visual.sam2.organised_sam2_train.SAM2Train + image_encoder: + _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: model.visual.sam2.modeling.memory_encoder.Fuser + layer: + _target_: model.visual.sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/ref-avs.code/model/visual/sam2/configs/training/sam2_training_config.yaml b/ref-avs.code/model/visual/sam2/configs/training/sam2_training_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29df1199d79c6a9031b82e23aa4b40df99064650 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/configs/training/sam2_training_config.yaml @@ -0,0 +1,60 @@ +# @package _global_ + +# Video transforms +train_transforms: + - _target_: dataloader.sam2_dataset.transforms.ComposeAPI + transforms: + - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI + sizes: 1024 + square: true + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: dataloader.sam2_dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI + - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +loss: + all: + _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + gpu_num: 4. + +contrastive_learning: + temperature: 0.10 + ignore_idx: 255 + ood_idx: 254 + max_views: 512 + proj_dim: 512 + sample_limits: 64 + total_limits: 15240 diff --git a/ref-avs.code/model/visual/sam2/modeling/backbones/hieradet.py b/ref-avs.code/model/visual/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb6633c9c752cbefe2fc6043c81fb79bc659465 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from model.visual.sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from model.visual.sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/ref-avs.code/model/visual/sam2/modeling/backbones/image_encoder.py b/ref-avs.code/model/visual/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/ref-avs.code/model/visual/sam2/modeling/backbones/utils.py b/ref-avs.code/model/visual/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/ref-avs.code/model/visual/sam2/modeling/memory_attention.py b/ref-avs.code/model/visual/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..11f4ccb1904f022c18f8a02b9590a66bd57bb8f1 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from model.visual.sam2.modeling.sam.transformer import RoPEAttention + +from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/ref-avs.code/model/visual/sam2/modeling/memory_encoder.py b/ref-avs.code/model/visual/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e1143cc0d5774ff96108203e404f678f14b0a23 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/ref-avs.code/model/visual/sam2/modeling/position_encoding.py b/ref-avs.code/model/visual/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..52ac22674d5d4fdd9e83b6bdf034bff56d04bc0d --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/ref-avs.code/model/visual/sam2/modeling/sam/mask_decoder.py b/ref-avs.code/model/visual/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..078f21cc2ec41805eebec677e6e27771335deaa4 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + audio_res_features_=audio_res_features + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + audio_res_features_: Optional[List[torch.Tensor]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens, audio_res_features_) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/ref-avs.code/model/visual/sam2/modeling/sam/prompt_encoder.py b/ref-avs.code/model/visual/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..038cebcc072ae7c0f3f83061061be3edba04d0f8 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom + +from model.visual.sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + # we only utilise sounding as prompt. + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + ''' + return sparse_embeddings, dense_embeddings + diff --git a/ref-avs.code/model/visual/sam2/modeling/sam/transformer.py b/ref-avs.code/model/visual/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..31916550afeccb66f4427cee7ec4a7a2d66913a5 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/sam/transformer.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from model.visual.sam2.modeling.sam2_utils import MLP +from model.visual.sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + audio_res: [], + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + visual_res, audio_res = audio_res + + # Prepare queries + queries = point_embedding + keys = image_embedding + # Apply transformer blocks and final layernorm + for i, layer in enumerate(self.layers): + keys = keys + visual_res[i] + queries[:, 2:6] = queries[:, 2:6] + audio_res[i] + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + queries[:, 2:6] = queries[:, 2:6] + audio_res[-1] + keys = keys + visual_res[-1] + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/ref-avs.code/model/visual/sam2/modeling/sam2_base.py b/ref-avs.code/model/visual/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1f740376b8cd48b18ab2988de9e51c6b36b429 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/sam2_base.py @@ -0,0 +1,943 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder +from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder +from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer +from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + + #### this is for Version 2.0 + # self.hidden_dim = memory_attention.d_model + #### this is for Version 2.1 + # self.hidden_dim = image_encoder.neck.d_model + self.hidden_dim = 256 # well, it is always 256 anyway. + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + ### we fix the use_mask_input_as_output_without_sam to be turned off. + self.use_mask_input_as_output_without_sam = False + + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + audio_res=None + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + ''' + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + raise NotImplementedError + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + raise NotImplementedError + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ''' + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=None, + boxes=None, + masks=None, + ) + + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + audio_res_features=audio_res + ) + ''' + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + ''' + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # comment this line temporarily. + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + ''' + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + ''' + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + + # don't train occlusion at the moment, command temporarily. + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def precompute_high_res_features(self, backbone_out): + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def forward_image(self, img_batch: torch.Tensor, pre_compute=True): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + return backbone_out if not pre_compute else self.precompute_high_res_features(backbone_out) + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # for t_pos in range(1, min(self.num_maskmem, frame_idx)): + # out = output_dict["non_cond_frame_outputs"].get(t_pos, None) + # t_pos_and_prevs.append((t_pos, out)) + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + # default false. + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + # the Following lines will never be triggered. + raise NotImplementedError + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + raise NotImplementedError + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + # it will be used in sam2.1 + # raise NotImplementedError + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/ref-avs.code/model/visual/sam2/modeling/sam2_utils.py b/ref-avs.code/model/visual/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19133558dd657bbcf67f851011d45bd4999cab0a --- /dev/null +++ b/ref-avs.code/model/visual/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.visual.sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/ref-avs.code/model/visual/sam2/organised_sam2_train.py b/ref-avs.code/model/visual/sam2/organised_sam2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..49814159b96732aadacf0d04bed9b346cc663678 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/organised_sam2_train.py @@ -0,0 +1,812 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from model.visual.sam2.modeling.sam2_base import SAM2Base +from model.visual.sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from utils.misc import concat_points + +from utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + ''' + vision_feats + torch.Size([65536, 5, 32]) + torch.Size([16384, 5, 64]) + torch.Size([4096, 5, 256]) + ''' + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + @staticmethod + def dont_prepare_prompt_inputs(backbone_out, num_frames=5, condition_frame=0): + backbone_out["gt_masks_per_frame"] = {} + backbone_out["num_frames"] = num_frames + backbone_out["use_pt_input"] = False + # always start from the first frame. + backbone_out["init_cond_frames"] = [condition_frame] + backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != condition_frame] + # backbone_out["init_cond_frames"] = [] + # backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)] + + backbone_out["mask_inputs_per_frame"] = {} + backbone_out["point_inputs_per_frame"] = {} + backbone_out["frames_to_add_correction_pt"] = [] + return backbone_out + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False): + # img_feats_already_computed = True. + """Forward video tracking on each frame (and sample correction clicks).""" + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + + av_v_feats, av_a_feats = audio_res + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = stage_id + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats] # add unsqueeze to maintain single sample. + current_vision_pos_embeds = [x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds] # add unsqueeze to maintain single sample. + current_av_v_feats = [x[img_ids] for x in av_v_feats] + current_av_a_feats = [x[img_ids] for x in av_a_feats] + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step_wo_prompt( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, # backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=None, # backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=None, # backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=None, # frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + audio_res=(current_av_v_feats, current_av_a_feats), + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + + return all_frame_outputs + + def track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + audio_res=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + + current_out, sam_outputs, high_res_features, pix_feat = self._track_step_wo_prompt( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + ''' + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + ''' + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + 666., # point_inputs, + run_mem_encoder, + # we follow SAM2 predictor, if we have multiple masks output, we only utilise the first one to perform + # the memory rope attention. + high_res_masks, #[:, 0:1, ...], + # high_res_masks if high_res_masks.shape[1] <= 1 else high_res_masks[:, 0:1, ...], + object_score_logits, + current_out, + ) + return current_out + + def _track_step_wo_prompt( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + audio_res=None + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # False + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # current_vision_feats[-1] = current_vision_feats[-1] + self.no_mem_embed + # pix_feat = current_vision_feats[-1].permute(1, 2, 0) + # pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + + # we do not apply any prompts except audio. + ''' + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + # if prev_sam_mask_logits is not None: + # assert point_inputs is not None and mask_inputs is None + # mask_inputs = prev_sam_mask_logits + + ## comment this line, as we don't use points as prompts. + # multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + ''' + + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=True, + audio_res=audio_res + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/ref-avs.code/model/visual/sam2/sam2/sam2_hiera_l.yaml b/ref-avs.code/model/visual/sam2/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7256f72aa72da25f04c7d25d3ab884f6a29b7cf7 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/sam2/sam2_hiera_l.yaml @@ -0,0 +1,3 @@ +# @package _global_ +defaults: + - /configs/sam2/sam2_hiera_l diff --git a/ref-avs.code/model/visual/sam2/sam2_image_predictor.py b/ref-avs.code/model/visual/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..41ce53af5924504c07216df52b2d2eefaeec7ae9 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/sam2_image_predictor.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/ref-avs.code/model/visual/sam2/sam2_video_predictor.py b/ref-avs.code/model/visual/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e01ccf972491904b013526333826b337354db1 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/sam2_video_predictor.py @@ -0,0 +1,1172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/ref-avs.code/model/visual/sam2/utils/amg.py b/ref-avs.code/model/visual/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e --- /dev/null +++ b/ref-avs.code/model/visual/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/ref-avs.code/model/visual/sam2/utils/misc.py b/ref-avs.code/model/visual/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/ref-avs.code/model/visual/sam2/utils/transforms.py b/ref-avs.code/model/visual/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4fa6a3e4d2e2a0dde7f87e4991daff338467c4 --- /dev/null +++ b/ref-avs.code/model/visual/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from model.visual.sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/ref-avs.code/requirements.txt b/ref-avs.code/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f016b081293b4373ec8f982984f99f828375d0b --- /dev/null +++ b/ref-avs.code/requirements.txt @@ -0,0 +1,4 @@ +# Incremental updates only (newly installed in this session) +transformers==5.6.2 +audiomentations==0.39.0 +wandb==0.26.1 diff --git a/ref-avs.code/tools/remap_aural_ckpt_keys.py b/ref-avs.code/tools/remap_aural_ckpt_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..8f96ba921661cc71eabec797d9b7cceceab482fc --- /dev/null +++ b/ref-avs.code/tools/remap_aural_ckpt_keys.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Remap legacy Ref-AVS / AVS checkpoints to the current AuralFuser key layout. + +Supports: + - Full-model ckpt with ``aural_fuser.training_layers.*`` / ``finetuning_layers.*`` + (old PromptAudio ModuleList layout) + - ``audio_prompter.*`` + ``train_*`` names (older AVS exports) + - Already-remapped ``patch_embeds.*``, ``f_blocks.*``, etc. (passed through) + +Usage (from repo root or ref-avs.code): + + python ref-avs.code/tools/remap_aural_ckpt_keys.py \\ + ckpts/exp/ref-hiera-l/s\\(0.59\\)_u\\(0.68\\).pth \\ + -o ckpts/exp/ref-hiera-l/remapped.pth + +Then inference: + + python ref-avs.code/inference.py --gpus 1 \\ + --inference_ckpt ckpts/exp/ref-hiera-l/remapped.pth +""" +from __future__ import annotations + +import argparse +import re +import shutil +from pathlib import Path + +import torch + +# Old ``training_layers`` append order in legacy PromptAudio.__init__ +_TRAINING_LAYERS_INDEX_MAP: dict[int, str] = { + 0: "patch_embeds.0", + 1: "patch_embeds.1", + 2: "patch_embeds.2", + 3: "f_blocks.0", + 4: "a_blocks.0", + 5: "fusion_modules.0", + 6: "f_blocks.1", + 7: "a_blocks.1", + 8: "fusion_modules.1", + 9: "f_blocks.2", + 10: "a_blocks.2", + 11: "fusion_modules.2", + 12: "smooth_convs.0", + 13: "smooth_convs.1", + 14: "train_proj_v1", + 15: "train_proj_a1", + 16: "text_proj", +} + +# Flat ``train_*`` renames (audio_prompter / some aural_fuser exports) +_FLAT_REPLACEMENTS: list[tuple[str, str]] = [ + ("train_f_patch_embed1", "patch_embeds.0"), + ("train_f_patch_embed2", "patch_embeds.1"), + ("train_f_patch_embed3", "patch_embeds.2"), + ("train_f_a_block1", "fusion_modules.0"), + ("train_f_a_block2", "fusion_modules.1"), + ("train_f_a_block3", "fusion_modules.2"), + ("train_f_block1", "f_blocks.0"), + ("train_f_block2", "f_blocks.1"), + ("train_f_block3", "f_blocks.2"), + ("train_a_block1", "a_blocks.0"), + ("train_a_block2", "a_blocks.1"), + ("train_a_block3", "a_blocks.2"), + ("train_smooth1", "smooth_convs.0"), + ("train_smooth2", "smooth_convs.1"), +] + +_RE_TRAINING_LAYER = re.compile(r"^(?P(?:aural_fuser|audio_prompter))\.training_layers\.(\d+)\.(?P.+)$") +_RE_FINETUNING_LAYER = re.compile( + r"^(?P(?:aural_fuser|audio_prompter))\.finetuning_layers\.0\.(?P.+)$" +) + + +def _apply_flat_renames(key: str) -> str: + for old, new in _FLAT_REPLACEMENTS: + key = key.replace(old, new) + return key + + +def _remap_key(key: str) -> str | None: + """Return new key, or None to drop the entry.""" + m = _RE_FINETUNING_LAYER.match(key) + if m: + prefix = "aural_fuser" if m.group("prefix") == "audio_prompter" else m.group("prefix") + return f"{prefix}.vgg.{m.group('rest')}" + + m = _RE_TRAINING_LAYER.match(key) + if m: + prefix = "aural_fuser" if m.group("prefix") == "audio_prompter" else m.group("prefix") + idx = int(m.group(2)) + rest = m.group("rest") + target = _TRAINING_LAYERS_INDEX_MAP.get(idx) + if target is None: + return None + return f"{prefix}.{target}.{rest}" + + if key.startswith("audio_prompter."): + if ".training_layers." in key or ".finetuning_layers." in key: + return None + key = key.replace("audio_prompter.", "aural_fuser.", 1) + return _apply_flat_renames(key) + + if ".training_layers." in key or ".finetuning_layers." in key: + return None + + if key.startswith("aural_fuser."): + return _apply_flat_renames(key) + + return key + + +def remap_state_dict(sd: dict) -> dict: + out: dict = {} + dropped = 0 + remapped = 0 + skip_finetuning = any(k.startswith("aural_fuser.vgg.") for k in sd) + for k, v in sd.items(): + if skip_finetuning and "finetuning_layers." in k: + dropped += 1 + continue + nk = _remap_key(k) + if nk is None: + dropped += 1 + continue + if nk != k: + remapped += 1 + if nk in out: + dropped += 1 + continue + out[nk] = v + print(f"Remapped keys: {remapped}, dropped: {dropped}") + return out + + +def _summarize(sd: dict) -> None: + prefixes = ( + "v_model.", + "aural_fuser.patch_embeds", + "aural_fuser.f_blocks", + "aural_fuser.vgg", + "aural_fuser.text_proj", + "t_model.", + ) + for p in prefixes: + n = sum(1 for k in sd if k.startswith(p)) + if n: + print(f" {p}* -> {n} keys") + legacy = sum( + 1 for k in sd + if "training_layers" in k or "finetuning_layers" in k or "train_f_patch" in k + ) + if legacy: + print(f" WARNING: {legacy} legacy keys remain") + + +def main() -> None: + ap = argparse.ArgumentParser(description="Remap legacy AuralFuser / full-model checkpoint keys") + ap.add_argument("ckpt", type=Path, help="Input .pth state_dict") + ap.add_argument("-o", "--output", type=Path, default=None, help="Output .pth (default: _remapped.pth)") + ap.add_argument("--in-place", action="store_true", help="Overwrite input (creates .bak unless --no-backup)") + ap.add_argument("--no-backup", action="store_true") + ap.add_argument( + "--aural-fuser-only", action="store_true", + help="Keep only aural_fuser.* (for aural_fuser-only inference ckpt)", + ) + args = ap.parse_args() + + ckpt_path = args.ckpt.resolve() + if not ckpt_path.is_file(): + raise SystemExit(f"File not found: {ckpt_path}") + + print(f"Loading: {ckpt_path}") + sd = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if not isinstance(sd, dict): + raise SystemExit("Expected top-level checkpoint to be a state_dict dict") + + n_legacy = sum( + 1 for k in sd + if "training_layers." in k or "finetuning_layers." in k + ) + if n_legacy == 0: + print("Note: no training_layers / finetuning_layers keys; file may already be remapped.") + + new_sd = remap_state_dict(sd) + if args.aural_fuser_only: + stripped = {} + for k, v in new_sd.items(): + if not k.startswith("aural_fuser."): + continue + stripped[k[len("aural_fuser."):]] = v + new_sd = stripped + print(f"aural-fuser-only (no prefix, for inference.py): {len(new_sd)} keys") + + print("Summary:") + _summarize(new_sd) + + if args.in_place: + out = ckpt_path + if not args.no_backup: + bak = ckpt_path.with_suffix(ckpt_path.suffix + ".bak") + print(f"Backup -> {bak}") + shutil.copy2(ckpt_path, bak) + else: + out = args.output or ckpt_path.with_name(ckpt_path.suffix.replace(".pth", "") + "_remapped.pth") + + torch.save(new_sd, out) + print(f"Saved: {out} ({len(new_sd)} tensor keys)") + + +if __name__ == "__main__": + main() diff --git a/ref-avs.code/trainer/train.py b/ref-avs.code/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c044c352671931106e2aa7db6a040d4bb4496c26 --- /dev/null +++ b/ref-avs.code/trainer/train.py @@ -0,0 +1,224 @@ +"""Training and validation for Ref-AVS (text + audio + SAM2 multimask decoding).""" +import numpy +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +_DECODE_MODES = frozenset({'', 'iou_select', 'iou_occ_select'}) + + +def _decode_mode_and_wandb_tag(process): + """Match tmp.code: `process` is decode mode for known strings; else Ref split tag + default decode.""" + if process in _DECODE_MODES: + return process, process + return 'iou_select', process + + +class Trainer: + """Train / valid / null-valid steps with composite loss, contrastive term, and metrics.""" + + def __init__(self, hyp_param, loss, tensorboard, metrics): + self.param = hyp_param + self.loss = loss + self.tensorboard = tensorboard + self.metrics = metrics + from loss.training.contrastive_learning import ContrastLoss + self.cl = ContrastLoss(self.param) + + @torch.no_grad() + def valid_null(self, epoch, dataloader, model, process='test_n'): + if not isinstance(dataloader, DataLoader): + raise TypeError("valid_null() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).") + decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process) + self.metrics['foreground_s'].reset() + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + p_pool = [None] * self.param.gpus + n_pool = [None] * self.param.gpus + + data_iter = iter(dataloader) + for _ in tbar: + items = next(data_iter) + frame, spect, prompt_dicts = items['frame'], items['spectrogram'], items['text'] + logits = [] + for frame_, spect_, prompt_dicts_ in zip(frame, spect, prompt_dicts): + frame_ = frame_.cuda(self.param.local_rank, non_blocking=True) + spect_ = spect_.cuda(self.param.local_rank, non_blocking=True) + prompt_dicts_ = [prompt_dicts_] + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, _ = model.module(frame_, spect_, prompt_dicts_, sam_process=False) + + logits_ = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) + ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) + occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) + if decode_mode == 'iou_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits_ = logits_[torch.arange(0, frame_.shape[0]), ious_scores, ...] + elif decode_mode == 'iou_occ_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits_ = logits_[torch.arange(0, frame_.shape[0]), ious_scores, ...] + logits_[occ_scores.squeeze() < 0, ...] = 0. + else: + logits_ = logits_[:, 0, ...] + logits.append(logits_) + + logits = torch.cat(logits).reshape(frame.shape[0], -1, self.param.image_size, self.param.image_size) + if len(logits.shape) == 3: + logits = logits.unsqueeze(1) + + foreground_s = self.metrics['foreground_s'].metric_s_for_null(logits, get_entire_list=True) + torch.distributed.all_gather_object(p_pool, foreground_s['foreground_p']) + torch.distributed.all_gather_object(n_pool, foreground_s['foreground_n']) + foreground_s = sum([i[0].cpu() for i in p_pool]) / sum([i[0] for i in n_pool]) + + if self.param.local_rank <= 0: + tbar.set_description( + 'epoch {} | valid.null_s {}'.format(epoch, numpy.round(foreground_s, 5)), + ) + torch.cuda.empty_cache() + + final_s = foreground_s + if self.param.local_rank <= 0 and self.tensorboard is not None: + self.tensorboard.upload_wandb_info({"valid.f_s/{}".format(wandb_tag): final_s}) + + return numpy.round(final_s, 5) + + @torch.no_grad() + def valid(self, epoch, dataloader, model, process='iou_select'): + """Evaluate IoU / F-score; `process` is decode mode (tmp) or split tag (test_s / test_u). Wandb keys like tmp.""" + if not isinstance(dataloader, DataLoader): + raise TypeError("valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).") + decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process) + self.metrics['foreground_iou'].reset() + self.metrics['foreground_f-score'].reset() + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + iou_pool = [None] * self.param.gpus + fscore_pool = [None] * self.param.gpus + + data_iter = iter(dataloader) + for _ in tbar: + items = next(data_iter) + frame, spect, label, prompt_dicts = ( + items['frame'], items['spectrogram'], items['label'], items['text'] + ) + logits = [] + labels = [] + for frame_, spect_, label_, prompt_dicts_ in zip(frame, spect, label, prompt_dicts): + frame_ = frame_.cuda(self.param.local_rank, non_blocking=True) + spect_ = spect_.cuda(self.param.local_rank, non_blocking=True) + label_ = label_.cuda(self.param.local_rank, non_blocking=True) + prompt_dicts_ = [prompt_dicts_] + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, _ = model.module(frame_, spect_, prompt_dicts_, sam_process=False) + + logits_ = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) + ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) + occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) + if decode_mode == 'iou_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits_ = logits_[torch.arange(0, frame_.shape[0]), ious_scores, ...] + elif decode_mode == 'iou_occ_select': + ious_scores = torch.argmax(ious_scores, dim=1) + logits_ = logits_[torch.arange(0, frame_.shape[0]), ious_scores, ...] + logits_[occ_scores.squeeze() < 0, ...] = 0. + else: + logits_ = logits_[:, 0, ...] + logits.append(logits_) + labels.append(label_) + + logits = torch.cat(logits) + labels = torch.cat(labels) + foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou( + (logits > 0.).squeeze().long(), labels.squeeze().long(), get_entire_list=True, + ) + foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score( + logits.squeeze(), labels.squeeze().long(), get_entire_list=True, + ) + torch.distributed.all_gather_object(iou_pool, foreground_iou_rank) + torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank) + foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum( + [i['foreground_iou'][1] for i in iou_pool]) + foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum( + [i['foreground_f-score'][1] for i in fscore_pool]) + + if self.param.local_rank <= 0: + tbar.set_description( + 'epoch {} | valid.f_iou {}, valid.f_f-score {}'.format( + epoch, + numpy.round(foreground_iou.cpu().numpy(), 5), + numpy.round(foreground_f_score, 5), + ), + ) + torch.cuda.empty_cache() + + final_iou = foreground_iou + final_fscore = foreground_f_score + if self.param.local_rank <= 0 and self.tensorboard is not None: + self.tensorboard.upload_wandb_info({ + "valid.f_iou/{}".format(wandb_tag): final_iou, + "valid.f_f-score/{}".format(wandb_tag): final_fscore, + }) + + def _to_float(x): + if isinstance(x, torch.Tensor): + return float(x.detach().cpu().item()) + return float(x) + + return numpy.round(_to_float(final_iou), 5), numpy.round(_to_float(final_fscore), 5) + + def train(self, epoch, dataloader, model, optimiser): + if not isinstance(dataloader, DataLoader): + raise TypeError("train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).") + dataloader_length = len(dataloader) + tbar = range(dataloader_length) + tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar + + data_iter = iter(dataloader) + for batch_index in tbar: + current_index = dataloader_length * epoch + batch_index + items = next(data_iter) + frame, spect, label, prompt_dicts = ( + items['frame'], items['spectrogram'], items['label'], items['text'], + ) + frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) + + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False) + loss_dict = self.loss(outputs, label.unsqueeze(1)) + cl_loss = self.cl(proj_feats, outputs, label) + + optimiser.zero_grad() + (loss_dict['core_loss'] + cl_loss).backward() + optimiser.step() + + current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9 + for params_lr in optimiser.param_groups: + names = params_lr.get("name", []) + if names and any("vgg" in n for n in names): + params_lr['lr'] = current_lr * 0.1 + else: + params_lr['lr'] = current_lr + + if self.param.local_rank <= 0 and self.tensorboard is not None: + logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs]) + foreground_iou = self.metrics['foreground_iou'].calculate_iou( + (logits > 0)[:, 0, ...].long(), label.long(), + ) + self.tensorboard.upload_wandb_info({ + "loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(), + "lr": optimiser.param_groups[0]['lr'], + "loss_dice": loss_dict['loss_dice'], + "loss_focal": loss_dict['loss_mask'], + "loss_contras": cl_loss.item(), + }) + tbar.set_description( + 'epoch {} | loss {}, f_iou {}'.format( + epoch, loss_dict['core_loss'].item(), foreground_iou.item(), + ), + ) + return diff --git a/ref-avs.code/utils/data_utils.py b/ref-avs.code/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7a98f8ec73e6e5dafd1e395b48a98575e5afb1 --- /dev/null +++ b/ref-avs.code/utils/data_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage + + +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/ref-avs.code/utils/foreground_fscore.py b/ref-avs.code/utils/foreground_fscore.py new file mode 100644 index 0000000000000000000000000000000000000000..fd53cb0f7e8efe63e9491747f9ecefed3051ed6e --- /dev/null +++ b/ref-avs.code/utils/foreground_fscore.py @@ -0,0 +1,90 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundFScore(AverageMeter): + def __init__(self, rank): + self.local_rank = rank + super(ForegroundFScore, self).__init__('foreground_f-score') + + def _eval_pr(self, y_pred, y, num, cuda_flag=True): + if cuda_flag: + prec, recall = torch.zeros(num).cuda(self.local_rank), torch.zeros(num).cuda(self.local_rank) + thlist = torch.linspace(0, 1 - 1e-10, num).cuda(self.local_rank) + else: + prec, recall = torch.zeros(num), torch.zeros(num) + thlist = torch.linspace(0, 1 - 1e-10, num) + for i in range(num): + y_temp = (y_pred >= thlist[i]).float() + tp = (y_temp * y).sum() + prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) + return prec, recall + + def calculate_f_score(self, pred, gt, pr_num=255, get_entire_list=False): + + r""" + param: + pred: size [N x H x W] + gt: size [N x H x W] + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + # print('=> eval [FMeasure]..') + pred = torch.sigmoid(pred) # =======================================[important] + N = pred.size(0) + beta2 = 0.3 + avg_f, img_num = 0.0, 0 + score = torch.zeros(pr_num) + # fLog = open(os.path.join(measure_path, 'FMeasure.txt'), 'w') + # print("{} videos in this batch".format(N)) + + for img_id in range(N): + # examples with totally black GTs are out of consideration + if torch.mean(gt[img_id].float()) == 0.0: + continue + prec, recall = self._eval_pr(pred[img_id], gt[img_id], pr_num) + f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) + f_score[f_score != f_score] = 0 # for Nan + avg_f += f_score + img_num += 1 + score = avg_f / img_num + # print('score: ', score) + # fLog.close() + self.add({'foreground_f-score': score.max().item()}) + return self.get('foreground_iou') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundFScore, self).__init__('foreground_f-score') + + diff --git a/ref-avs.code/utils/foreground_iou.py b/ref-avs.code/utils/foreground_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..e01eeb081eee8ebfa1fcb6618d05b9d57c02f817 --- /dev/null +++ b/ref-avs.code/utils/foreground_iou.py @@ -0,0 +1,69 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundIoU(AverageMeter): + def __init__(self): + super(ForegroundIoU, self).__init__('foreground_iou') + + def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False): + r""" + param (both hard mask): + pred: size [N x H x W], type: int + target: size [N x H x W], type: int + output: + iou: size [1] (size_average=True) or [N] (size_average=False) + """ + assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.' + assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.' + + N = pred.size(0) + num_pixels = pred.size(-1) * pred.size(-2) + no_obj_flag = (target.sum(2).sum(1) == 0) + + inter = (pred * target).sum(2).sum(1) + union = torch.max(pred, target).sum(2).sum(1) + + inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) + inter[no_obj_flag] = inter_no_obj[no_obj_flag] + union[no_obj_flag] = num_pixels + + iou = torch.sum(inter / (union+eps)) / N + + self.add({'foreground_iou': iou}) + return self.get('foreground_iou') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self,): + super(ForegroundIoU, self).__init__('foreground_iou') + diff --git a/ref-avs.code/utils/foreground_s.py b/ref-avs.code/utils/foreground_s.py new file mode 100644 index 0000000000000000000000000000000000000000..44770e26da3a87455182315345df76ad8ecda897 --- /dev/null +++ b/ref-avs.code/utils/foreground_s.py @@ -0,0 +1,62 @@ +import numpy +import torch + + +class AverageMeter: + def __init__(self, *keys): + self.__data = dict() + for k in keys: + self.__data[k] = [0.0, 0] + + def add(self, dict): + for k, v in dict.items(): + self.__data[k][0] += v + self.__data[k][1] += 1 + + def get(self, *keys): + if len(keys) == 1: + return self.__data[keys[0]][0] / self.__data[keys[0]][1] + else: + v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] + return tuple(v_list) + + def get_entire_dict_for_ddp_calculation(self): + return self.__data + + def pop(self, key=None): + if key is None: + for k in self.__data.keys(): + self.__data[k] = [0.0, 0] + else: + v = self.get(key) + self.__data[key] = [0.0, 0] + return v + + +class ForegroundS(AverageMeter): + def __init__(self): + super(ForegroundS, self).__init__('foreground_p', 'foreground_n') + + def metric_s_for_null(self, pred, get_entire_list=False): + NF, bsz, H, W = pred.shape + pred = pred.view(NF * bsz, H, W) + assert len(pred.shape) == 3 + + N = pred.size(0) + num_pixels = pred.view(-1).shape[0] + + temp_pred = torch.sigmoid(pred) + pred = (temp_pred > 0.5).int() + + x = torch.sum(pred.view(-1)) + s = torch.sqrt(x / num_pixels) + + self.add({'foreground_p': x}) + self.add({'foreground_n': num_pixels}) + # self.add({'foreground_s': s}) + return self.get('foreground_p')/self.get('foreground_n') if not get_entire_list else self.get_entire_dict_for_ddp_calculation() + + def reset(self, ): + super(ForegroundS, self).__init__('foreground_p', 'foreground_n') + + diff --git a/ref-avs.code/utils/iou.py b/ref-avs.code/utils/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..211488b780887a8efd84361bafc6b09bfad4c345 --- /dev/null +++ b/ref-avs.code/utils/iou.py @@ -0,0 +1,76 @@ +import torch +import numpy + + +class BinaryMIoU(object): + def __init__(self, ignore_index): + self.num_classes = 2 + self.ignore_index = ignore_index + self.inter, self.union = 0, 0 + self.correct, self.label = 0, 0 + self.iou = numpy.array([0 for _ in range(self.num_classes)]) + self.acc = 0.0 + + def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_): + # calculates the overall miou and acc + self.correct = self.correct + curr_correct_ + self.label = self.label + curr_label_ + self.inter = self.inter + curr_inter_ + self.union = self.union + curr_union_ + self.acc = 1.0 * self.correct / (numpy.spacing(1) + self.label) + self.iou = 1.0 * self.inter / (numpy.spacing(1) + self.union) + return numpy.round(self.iou, 4), numpy.round(self.acc, 4) + # if class_list is None: + # return numpy.round(self.iou.mean().item(), 4), \ + # numpy.round(self.acc, 4) + # else: + # return numpy.round(self.iou[class_list].mean().item(), 4), \ + # numpy.round(self.acc, 4) + + @staticmethod + def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_): + curr_acc = 1.0 * curr_correct_ / (numpy.spacing(1) + curr_label_) + curr_iou = 1.0 * curr_inter_ / (numpy.spacing(1) + curr_union_) + return curr_iou, curr_acc + + def __call__(self, x, y): + curr_correct, curr_label, curr_inter, curr_union = self.calculate_current_sample(x, y) + return (self.get_metric_results(curr_correct, curr_label, curr_inter, curr_union), + self.get_current_image_results(curr_correct, curr_label, curr_inter, curr_union)) + + def calculate_current_sample(self, output, target): + # output => BxCxHxW (logits) + # target => Bx1xHxW + target[target == self.ignore_index] = -1 + correct, labeled = self.batch_pix_accuracy(output.data, target) + inter, union = self.batch_intersection_union(output.data, target, self.num_classes) + return [numpy.round(correct, 5), numpy.round(labeled, 5), numpy.round(inter, 5), numpy.round(union, 5)] + + @ staticmethod + def batch_pix_accuracy(predict, target): + # _, predict = torch.max(output, 1) + + predict = predict.int() + 1 + target = target.int() + 1 + + pixel_labeled = (target > 0).sum() + pixel_correct = ((predict == target) * (target > 0)).sum() + assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" + return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy() + + @ staticmethod + def batch_intersection_union(predict, target, num_class): + # _, predict = torch.max(output, 1) + predict = predict + 1 + target = target + 1 + + predict = predict * (target > 0).long() + intersection = predict * (predict == target).long() + + area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1) + area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1) + area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1) + area_union = area_pred + area_lab - area_inter + assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" + return area_inter.cpu().numpy(), area_union.cpu().numpy() + diff --git a/ref-avs.code/utils/misc.py b/ref-avs.code/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9d66c31a4b9209b81a5b615386d29f246135c --- /dev/null +++ b/ref-avs.code/utils/misc.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} + diff --git a/ref-avs.code/utils/tensorboard.py b/ref-avs.code/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..e896bde094ef9a11bb684e69b63964cf7abc7407 --- /dev/null +++ b/ref-avs.code/utils/tensorboard.py @@ -0,0 +1,55 @@ +"""Optional Weights & Biases logging for Ref-AVS training.""" +import os + +import torchvision +import wandb + + +class Tensorboard: + def __init__(self, config): + key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '') + if key: + os.environ['WANDB_API_KEY'] = key + mode = 'online' if config.get('wandb_online', False) else 'disabled' + self.tensor_board = wandb.init( + project=config['proj_name'], + name=config['experiment_name'], + config=config, + mode=mode, + settings=wandb.Settings(code_dir='.'), + ) + self.restore_transform = torchvision.transforms.Compose([ + DeNormalize(config['image_mean'], config['image_std']), + torchvision.transforms.ToPILImage(), + ]) + + def upload_wandb_info(self, info_dict): + for key, value in info_dict.items(): + self.tensor_board.log({key: value}) + + def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4): + n = min(pseudo_label_from_pred.shape[0], img_number) + frames = frames[:n] + pseudo_label_from_sam = pseudo_label_from_sam[:n].float() + pseudo_label_from_pred = pseudo_label_from_pred[:n].float() + pseudo_label_from_sam[pseudo_label_from_sam == 255.] = 0.5 + pseudo_label_from_pred[pseudo_label_from_pred == 255.] = 0.5 + self.tensor_board.log({ + 'image': [wandb.Image(j, caption=f'id {i}') for i, j in enumerate(frames)], + 'label': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(pseudo_label_from_sam)], + 'logits': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(pseudo_label_from_pred)], + }) + + def finish(self): + self.tensor_board.finish() + + +class DeNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + return tensor diff --git a/ref-avs.code/utils/utils.py b/ref-avs.code/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac864d0fa1d5a64cdc3b0876b4c8f23f1ce6fc57 --- /dev/null +++ b/ref-avs.code/utils/utils.py @@ -0,0 +1,73 @@ +import torch +import copy +from typing import List, Dict, Set, Any +import itertools + +def manipulate_params(cfg, model): + weight_decay_norm = 0 + weight_decay_embed = 0 + defaults = {} + defaults["lr"] = cfg.lr + defaults["weight_decay"] = cfg.weight_decay + + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + # NaiveSyncBatchNorm inherits from BatchNorm2d + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + params_training: List[Dict[str, Any]] = [] + params_finetuning: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + + train_prefixes = ( + "patch_embeds", + "f_blocks", + "a_blocks", + "fusion_modules", + "smooth_convs", + "train_proj_v1", + "train_proj_a1", + "text_proj", + ) + + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + if value in memo: + continue + memo.add(value) + hyperparams = copy.copy(defaults) + if 'vgg' in module_name or 'vgg' in module_param_name: + hyperparams['lr'] *= 0.1 + params_finetuning.append({"params": [value], "name": [module_name], **hyperparams}) + elif ( + 'train' in module_name + or 'train' in module_param_name + or module_name.startswith(train_prefixes) + ): + if ( + "relative_position_bias_table" in module_param_name + or "pos_embed" in module_param_name + ): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, norm_module_types): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, torch.nn.Embedding): + hyperparams["weight_decay"] = 0.0 + params_training.append({"params": [value], "name": [module_name], **hyperparams}) + else: + print('undefined layer type.') + raise NotImplementedError + final_list = params_training + params_finetuning + assert len([p for p in model.parameters() if p.requires_grad]) == len(final_list), 'checksum confirmed not pass.' + return final_list \ No newline at end of file diff --git a/scripts/run_avs_train.sh b/scripts/run_avs_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..a0478eb76f8f21aed2c7f68a8ce30bca320ff4a1 --- /dev/null +++ b/scripts/run_avs_train.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +CODE_BASE="${REPO_ROOT}/avs.code" +cd "${SCRIPT_DIR}" + +DEFAULT_GPUS=4 +OMP_THREADS=8 + +# Reference hyper-parameter table (for quick view) +EPOCH_V1S=140 +EPOCH_V1M=140 +EPOCH_V2=90 + +WEIGHT_V1S=3.0 +WEIGHT_V1M=3.0 +WEIGHT_V2=3.0 + +print_table() { + echo "+-------------+------------+------------+------------+" + echo "| hyper-param | v1s | v1m | v2 |" + echo "+-------------+------------+------------+------------+" + printf "| %-11s | %-10s | %-10s | %-10s |\n" "epoch" "${EPOCH_V1S}" "${EPOCH_V1M}" "${EPOCH_V2}" + printf "| %-11s | %-10s | %-10s | %-10s |\n" "weight" "${WEIGHT_V1S}" "${WEIGHT_V1M}" "${WEIGHT_V2}" + printf "| %-11s | %-10s | %-10s | %-10s |\n" "gpus(def)" "${DEFAULT_GPUS}" "${DEFAULT_GPUS}" "${DEFAULT_GPUS}" + echo "+-------------+------------+------------+------------+" +} + +usage() { + echo "Usage: $0 [gpus]" + echo "Example: $0 v1s" + echo "Example: $0 v2 8" +} + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage + print_table + exit 1 +fi + +DATASET="$1" +GPUS="${2:-${DEFAULT_GPUS}}" + +case "${DATASET}" in + v1s) + CODE_DIR="v1s.code" + EPOCHS="${EPOCH_V1S}" + ;; + v1m) + CODE_DIR="v1m.code" + EPOCHS="${EPOCH_V1M}" + ;; + v2) + CODE_DIR="v2.code" + EPOCHS="${EPOCH_V2}" + ;; + *) + echo "Error: dataset must be one of v1s / v1m / v2, got: ${DATASET}" + echo + print_table + exit 1 + ;; +esac + +if ! [[ "${GPUS}" =~ ^[0-9]+$ ]] || [[ "${GPUS}" -le 0 ]]; then + echo "Error: gpus must be a positive integer, got: ${GPUS}" + exit 1 +fi + +if [[ ! -f "${CODE_BASE}/${CODE_DIR}/main.py" ]]; then + echo "Error: training entry not found: ${CODE_BASE}/${CODE_DIR}/main.py" + exit 1 +fi + +export OMP_NUM_THREADS="${OMP_THREADS}" + +LOG_FILE="train_${DATASET}.log" +CMD=(python3 "${CODE_BASE}/${CODE_DIR}/main.py" --epochs="${EPOCHS}" --gpus="${GPUS}") + +echo "Training job is about to start:" +echo " dataset: ${DATASET}" +echo " code: ${CODE_BASE}/${CODE_DIR}/main.py" +echo " epochs: ${EPOCHS}" +echo " gpus: ${GPUS}" +echo " log: ${SCRIPT_DIR}/${LOG_FILE}" +echo +print_table +echo +echo "Command: nohup ${CMD[*]} > ${LOG_FILE} 2>&1 &" + +nohup "${CMD[@]}" > "${LOG_FILE}" 2>&1 & +echo "Training started in background, PID: $!" diff --git a/scripts/run_ref_train.sh b/scripts/run_ref_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..6973b2a97688f1c1820e1fdbe7d5aa171e3b5f49 --- /dev/null +++ b/scripts/run_ref_train.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +CODE_DIR="${REPO_ROOT}/ref-avs.code" +cd "${SCRIPT_DIR}" + +DEFAULT_GPUS=4 +DEFAULT_EPOCHS=50 +DEFAULT_LR=1e-4 +OMP_THREADS=8 + +print_table() { + echo "+-------------+----------------+" + echo "| hyper-param | ref-avs |" + echo "+-------------+----------------+" + printf "| %-11s | %-14s |\n" "epoch" "${DEFAULT_EPOCHS}" + printf "| %-11s | %-14s |\n" "lr" "${DEFAULT_LR}" + printf "| %-11s | %-14s |\n" "gpus(def)" "${DEFAULT_GPUS}" + echo "+-------------+----------------+" +} + +usage() { + echo "Usage: $0 [gpus]" + echo "Example: $0" + echo "Example: $0 8" +} + +if [[ $# -gt 1 ]]; then + usage + print_table + exit 1 +fi + +GPUS="${1:-${DEFAULT_GPUS}}" + +if ! [[ "${GPUS}" =~ ^[0-9]+$ ]] || [[ "${GPUS}" -le 0 ]]; then + echo "Error: gpus must be a positive integer, got: ${GPUS}" + exit 1 +fi + +if [[ ! -f "${CODE_DIR}/main.py" ]]; then + echo "Error: training entry not found: ${CODE_DIR}/main.py" + exit 1 +fi + +export OMP_NUM_THREADS="${OMP_THREADS}" + +LOG_FILE="train_ref_avs.log" +CMD=( + python3 "${CODE_DIR}/main.py" + --epochs="${DEFAULT_EPOCHS}" + --gpus="${GPUS}" + --lr="${DEFAULT_LR}" +) + +echo "Training job is about to start:" +echo " dataset: ref-avs (REFAVS)" +echo " code: ${CODE_DIR}/main.py" +echo " epochs: ${DEFAULT_EPOCHS}" +echo " lr: ${DEFAULT_LR}" +echo " gpus: ${GPUS}" +echo " log: ${SCRIPT_DIR}/${LOG_FILE}" +echo +print_table +echo +echo "Command: nohup ${CMD[*]} > ${LOG_FILE} 2>&1 &" + +nohup "${CMD[@]}" > "${LOG_FILE}" 2>&1 & +echo "Training started in background, PID: $!"