yyliu01 commited on
Commit
c6dfc69
·
verified ·
1 Parent(s): 5b9ba72

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. LICENSE +21 -0
  3. README.md +26 -3
  4. avs.code/v1m.code/configs/__init__.py +0 -0
  5. avs.code/v1m.code/configs/auralfuser/architecture.yaml +30 -0
  6. avs.code/v1m.code/configs/config.py +85 -0
  7. avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml +114 -0
  8. avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml +117 -0
  9. avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml +116 -0
  10. avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml +118 -0
  11. avs.code/v1m.code/configs/training/sam2_training_config.yaml +62 -0
  12. avs.code/v1m.code/dataloader/audio/audio_augmentation.py +23 -0
  13. avs.code/v1m.code/dataloader/audio/audio_dataset.py +38 -0
  14. avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py +223 -0
  15. avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py +98 -0
  16. avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py +53 -0
  17. avs.code/v1m.code/dataloader/dataset.py +67 -0
  18. avs.code/v1m.code/dataloader/sam2_dataset/__init__.py +5 -0
  19. avs.code/v1m.code/dataloader/sam2_dataset/transforms.py +528 -0
  20. avs.code/v1m.code/dataloader/visual/visual_augmentation.py +140 -0
  21. avs.code/v1m.code/dataloader/visual/visual_dataset.py +127 -0
  22. avs.code/v1m.code/inference.py +193 -0
  23. avs.code/v1m.code/loss/training/__init__.py +2 -0
  24. avs.code/v1m.code/loss/training/contrastive_learning.py +201 -0
  25. avs.code/v1m.code/loss/training/sam2_training_loss.py +220 -0
  26. avs.code/v1m.code/main.py +166 -0
  27. avs.code/v1m.code/model/audio/torchvggish/mel_features.py +223 -0
  28. avs.code/v1m.code/model/audio/torchvggish/vggish.py +193 -0
  29. avs.code/v1m.code/model/audio/torchvggish/vggish_input.py +98 -0
  30. avs.code/v1m.code/model/audio/torchvggish/vggish_params.py +53 -0
  31. avs.code/v1m.code/model/aural_fuser.py +567 -0
  32. avs.code/v1m.code/model/mymodel.py +102 -0
  33. avs.code/v1m.code/model/visual/sam2/__init__.py +11 -0
  34. avs.code/v1m.code/model/visual/sam2/build_sam.py +171 -0
  35. avs.code/v1m.code/model/visual/sam2/modeling/__init__.py +5 -0
  36. avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py +5 -0
  37. avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py +317 -0
  38. avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py +134 -0
  39. avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py +95 -0
  40. avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py +169 -0
  41. avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py +181 -0
  42. avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py +221 -0
  43. avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py +5 -0
  44. avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py +300 -0
  45. avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py +188 -0
  46. avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py +367 -0
  47. avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py +940 -0
  48. avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py +323 -0
  49. avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py +811 -0
  50. avs.code/v1m.code/model/visual/sam2/utils/__init__.py +5 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ckpts/avs/v1s/nohup.out filter=lfs diff=lfs merge=lfs -text
37
+ ckpts/avs/v2/nohup.out filter=lfs diff=lfs merge=lfs -text
38
+ ckpts/ref-avs/nohup.out filter=lfs diff=lfs merge=lfs -text
39
+ docs/overview.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Yuyuan Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,26 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AuralSAM2
2
+ > **[CVPRF'26]** [AuralSAM2: Enabling SAM2 Hear
3
+ Through Pyramid Audio-Visual Feature Prompting](#)
4
+ >
5
+ > by Yuyuan Liu, Yuanhong Chen, Chong Wang, Junlin Han, Junde Wu, Can Peng, Jingkun Chen, Yu Tian and Gustavo Carneiro
6
+ >
7
+ <img src="./docs/overview.png" width="850" height="300" />
8
+
9
+ ## Installation
10
+ please install the dependencies and dataset based on this [***installation***](./docs/installation.md) document.
11
+
12
+ ## Getting start
13
+ please follow this [***instruction***](./docs/before_start.md) document to reproduce our results.
14
+
15
+ ## Citation
16
+ please consider citing our work in your publications if it helps your research.
17
+
18
+ ```bibtex
19
+ @article{liu2025auralsam2,
20
+ title={AuralSAM2: Enabling SAM2 Hear Through Pyramid Audio-Visual Feature Prompting},
21
+ author={Liu, Yuyuan and Chen, Yuanhong and Wang, Chong and Han, Junlin and Wu, Junde and Peng, Can and Chen, Jingkun and Tian, Yu and Carneiro, Gustavo},
22
+ journal={arXiv preprint arXiv:2506.01015},
23
+ year={2025}
24
+ }
25
+ ```
26
+
avs.code/v1m.code/configs/__init__.py ADDED
File without changes
avs.code/v1m.code/configs/auralfuser/architecture.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ aural_fuser:
4
+ patch_cfgs:
5
+ - [4, 4]
6
+ - [2, 2]
7
+ - [1, 1]
8
+ f_depths: [3, 6, 12]
9
+ block_kw:
10
+ dim: 256
11
+ num_heads: 4
12
+ mlp_ratio: 4
13
+ qkv_bias: true
14
+ qk_scale: null
15
+ drop: 0.1
16
+ attn_drop: 0.1
17
+ drop_path: 0.0
18
+ sr_ratio: 4
19
+ linear: false
20
+ one_d_kw:
21
+ dim: 256
22
+ num_heads: 4
23
+ mlp_ratio: 4
24
+ qkv_bias: true
25
+ qk_scale: null
26
+ drop: 0.1
27
+ attn_drop: 0.1
28
+ drop_path: 0.0
29
+ sr_ratio: 4
30
+ linear: false
avs.code/v1m.code/configs/config.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy
3
+ from easydict import EasyDict
4
+
5
+ # v1m.code package root (parent of this `configs/` directory)
6
+ _CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7
+ _WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT))
8
+
9
+ C = EasyDict()
10
+ config = C
11
+ cfg = C
12
+
13
+ C.seed = 666
14
+
15
+ C.audio = EasyDict()
16
+ C.audio.FREEZE_AUDIO_EXTRACTOR = True
17
+ C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth')
18
+ C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False
19
+ C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False
20
+ C.train_vggish = False
21
+
22
+ """Root Directory Config"""
23
+ C.repo_name = 'AV'
24
+ C.root_dir = _CODE_ROOT
25
+
26
+ """Data Dir and Weight Dir"""
27
+ C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'AVSBench')
28
+ C.data_name = 'v1m'
29
+
30
+ C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt')
31
+ C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml')
32
+
33
+ """Network Config"""
34
+ C.fix_bias = True
35
+ C.bn_eps = 1e-5
36
+ C.bn_momentum = 0.1
37
+
38
+ """Image Config"""
39
+ C.num_classes = 2
40
+
41
+ C.image_mean = numpy.array([0.485, 0.456, 0.406])
42
+ C.image_std = numpy.array([0.229, 0.224, 0.225])
43
+
44
+
45
+ C.image_size = 1024
46
+ C.image_embedding_size = int(C.image_size / 16)
47
+ C.avsbench_size = (224, 224)
48
+
49
+ C.scale_list = [.5, .75, 1., 1.25, 1.5]
50
+ C.ignore_index = 255
51
+
52
+ """Train Config"""
53
+ C.lr = 7.5e-5
54
+ C.batch_size = 8
55
+ C.energy_weight = .05
56
+
57
+ C.lr_power = 0.9
58
+ C.momentum = 0.9
59
+ C.weight_decay = 0.05
60
+
61
+ C.num_workers = 4
62
+
63
+ """Display Config"""
64
+ C.record_info_iter = 20
65
+ C.display_iter = 50
66
+
67
+ """Wandb Config"""
68
+ # Paste your W&B API key here, or set the WANDB_API_KEY environment variable instead.
69
+ C.wandb_key = ""
70
+
71
+ # Your project [work_space] name
72
+ C.proj_name = "AVS-final-report"
73
+
74
+ C.experiment_name = "v1s-hiera-l"
75
+
76
+
77
+ # False = no wandb logging (see utils/tensorboard.py)
78
+ C.wandb_online = False
79
+
80
+ """Save Config"""
81
+ C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', C.experiment_name)
82
+
83
+ import pathlib
84
+
85
+ pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True)
avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: model.visual.sam2.organised_sam2_train.SAM2Train
6
+ image_encoder:
7
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: model.visual.sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: model.visual.sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
114
+
avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: model.visual.sam2.organised_sam2_train.SAM2Train
6
+ image_encoder:
7
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: model.visual.sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: model.visual.sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: model.visual.sam2.organised_sam2_train.SAM2Train
6
+ image_encoder:
7
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: model.visual.sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: model.visual.sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 224 # 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: false
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
avs.code/v1m.code/configs/training/sam2_training_config.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Video transforms
4
+
5
+ train_transforms:
6
+ - _target_: dataloader.sam2_dataset.transforms.ComposeAPI
7
+ transforms:
8
+ - _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip
9
+ consistent_transform: True
10
+ - _target_: dataloader.sam2_dataset.transforms.RandomAffine
11
+ degrees: 25
12
+ shear: 20
13
+ image_interpolation: bilinear
14
+ consistent_transform: True
15
+ - _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI
16
+ sizes: 1024 # ${scratch.resolution}
17
+ square: true
18
+ consistent_transform: True
19
+ - _target_: dataloader.sam2_dataset.transforms.ColorJitter
20
+ consistent_transform: True
21
+ brightness: 0.1
22
+ contrast: 0.03
23
+ saturation: 0.03
24
+ hue: null
25
+ - _target_: dataloader.sam2_dataset.transforms.RandomGrayscale
26
+ p: 0.05
27
+ consistent_transform: True
28
+ - _target_: dataloader.sam2_dataset.transforms.ColorJitter
29
+ consistent_transform: False
30
+ brightness: 0.1
31
+ contrast: 0.05
32
+ saturation: 0.05
33
+ hue: null
34
+ - _target_: dataloader.sam2_dataset.transforms.ToTensorAPI
35
+ - _target_: dataloader.sam2_dataset.transforms.NormalizeAPI
36
+ mean: [0.485, 0.456, 0.406]
37
+ std: [0.229, 0.224, 0.225]
38
+
39
+ loss:
40
+ all:
41
+ _target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous
42
+ weight_dict:
43
+ loss_mask: 20 # 20
44
+ loss_dice: 1
45
+ loss_iou: 1
46
+ loss_class: 1
47
+ supervise_all_iou: true
48
+ iou_use_l1_loss: true
49
+ pred_obj_scores: true
50
+ focal_gamma_obj_score: 0.0
51
+ focal_alpha_obj_score: -1.0
52
+ gpu_num: 4.
53
+
54
+ # Contrastive loss (ContrastLoss); loaded in main.py / inference.py → hyp_param.contrastive_learning
55
+ contrastive_learning:
56
+ temperature: 0.10
57
+ ignore_idx: 255
58
+ ood_idx: 254
59
+ max_views: 512
60
+ proj_dim: 512
61
+ sample_limits: 128
62
+ total_limits: 15240
avs.code/v1m.code/dataloader/audio/audio_augmentation.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+
3
+
4
+ class Augmentation(object):
5
+ """Audio pre-step used by training/inference: int16 waveform -> float in [-1, 1].
6
+
7
+ The previous audiomentations-based transforms were commented out and never applied;
8
+ behavior is unchanged: only scaling by 1/32768.
9
+ """
10
+
11
+ def __init__(self, mono=True):
12
+ self.mono = mono
13
+
14
+ def train_aug(self, x_, sr_):
15
+ x_ = x_ / 32768.0
16
+ return x_
17
+
18
+ def test_process(self, x_):
19
+ x_ = x_ / 32768.0
20
+ return x_
21
+
22
+ def __call__(self, x, sr, split):
23
+ return self.train_aug(x, sr) if split == "train" else self.test_process(x)
avs.code/v1m.code/dataloader/audio/audio_dataset.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy
3
+ import os
4
+ from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples
5
+ import soundfile
6
+
7
+
8
+ class Audio(torch.utils.data.Dataset):
9
+ def __init__(self, augmentation, directory_path, split):
10
+ # temporarily set no augmentation.
11
+ self.augmentation = augmentation
12
+ self.directory_path = directory_path
13
+ self.split = split
14
+
15
+ def load_audio_wave(self, file_index, file_index_mix):
16
+ audio_path = os.path.join(file_index, 'audio.wav')
17
+ wav_data, sample_rate = soundfile.read(audio_path, dtype='int16')
18
+ assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype
19
+
20
+ if file_index_mix is not None:
21
+ audio_path2 = os.path.join(file_index_mix, 'audio.wav')
22
+ wav_data2, _ = soundfile.read(audio_path2, dtype='int16')
23
+ mix_lambda = numpy.random.beta(10, 10)
24
+ min_length = min(wav_data.shape[0], wav_data2.shape[0])
25
+ wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda)
26
+
27
+ wav_data = self.augmentation(wav_data, sample_rate, self.split)
28
+ audio_log_mel = torch.cat([waveform_to_examples(wav_data[:, 0], sample_rate, True).detach(),
29
+ waveform_to_examples(wav_data[:, 1], sample_rate, True).detach()], dim=1)
30
+
31
+ # for the vgg preprocess, we will need 5 seconds audio log.
32
+ if audio_log_mel.shape[0] < 5:
33
+ audio_log_mel = torch.cat([audio_log_mel,
34
+ audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)])
35
+ return audio_log_mel
36
+
37
+ def __len__(self):
38
+ return len(self.audio_list)
avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Defines routines to compute mel spectrogram features from audio waveform."""
17
+
18
+ import numpy as np
19
+
20
+
21
+ def frame(data, window_length, hop_length):
22
+ """Convert array into a sequence of successive possibly overlapping frames.
23
+
24
+ An n-dimensional array of shape (num_samples, ...) is converted into an
25
+ (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26
+ starts hop_length points after the preceding one.
27
+
28
+ This is accomplished using stride_tricks, so the original data is not
29
+ copied. However, there is no zero-padding, so any incomplete frames at the
30
+ end are not included.
31
+
32
+ Args:
33
+ data: np.array of dimension N >= 1.
34
+ window_length: Number of samples in each frame.
35
+ hop_length: Advance (in samples) between each window.
36
+
37
+ Returns:
38
+ (N+1)-D np.array with as many rows as there are complete frames that can be
39
+ extracted.
40
+ """
41
+ num_samples = data.shape[0]
42
+ num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
43
+ shape = (num_frames, window_length) + data.shape[1:]
44
+ strides = (data.strides[0] * hop_length,) + data.strides
45
+ return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
46
+
47
+
48
+ def periodic_hann(window_length):
49
+ """Calculate a "periodic" Hann window.
50
+
51
+ The classic Hann window is defined as a raised cosine that starts and
52
+ ends on zero, and where every value appears twice, except the middle
53
+ point for an odd-length window. Matlab calls this a "symmetric" window
54
+ and np.hanning() returns it. However, for Fourier analysis, this
55
+ actually represents just over one cycle of a period N-1 cosine, and
56
+ thus is not compactly expressed on a length-N Fourier basis. Instead,
57
+ it's better to use a raised cosine that ends just before the final
58
+ zero value - i.e. a complete cycle of a period-N cosine. Matlab
59
+ calls this a "periodic" window. This routine calculates it.
60
+
61
+ Args:
62
+ window_length: The number of points in the returned window.
63
+
64
+ Returns:
65
+ A 1D np.array containing the periodic hann window.
66
+ """
67
+ return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
68
+ np.arange(window_length)))
69
+
70
+
71
+ def stft_magnitude(signal, fft_length,
72
+ hop_length=None,
73
+ window_length=None):
74
+ """Calculate the short-time Fourier transform magnitude.
75
+
76
+ Args:
77
+ signal: 1D np.array of the input time-domain signal.
78
+ fft_length: Size of the FFT to apply.
79
+ hop_length: Advance (in samples) between each frame passed to FFT.
80
+ window_length: Length of each block of samples to pass to FFT.
81
+
82
+ Returns:
83
+ 2D np.array where each row contains the magnitudes of the fft_length/2+1
84
+ unique values of the FFT for the corresponding frame of input samples.
85
+ """
86
+ frames = frame(signal, window_length, hop_length)
87
+ # Apply frame window to each frame. We use a periodic Hann (cosine of period
88
+ # window_length) instead of the symmetric Hann of np.hanning (period
89
+ # window_length-1).
90
+ window = periodic_hann(window_length)
91
+ windowed_frames = frames * window
92
+ return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
93
+
94
+
95
+ # Mel spectrum constants and functions.
96
+ _MEL_BREAK_FREQUENCY_HERTZ = 700.0
97
+ _MEL_HIGH_FREQUENCY_Q = 1127.0
98
+
99
+
100
+ def hertz_to_mel(frequencies_hertz):
101
+ """Convert frequencies to mel scale using HTK formula.
102
+
103
+ Args:
104
+ frequencies_hertz: Scalar or np.array of frequencies in hertz.
105
+
106
+ Returns:
107
+ Object of same size as frequencies_hertz containing corresponding values
108
+ on the mel scale.
109
+ """
110
+ return _MEL_HIGH_FREQUENCY_Q * np.log(
111
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
112
+
113
+
114
+ def spectrogram_to_mel_matrix(num_mel_bins=20,
115
+ num_spectrogram_bins=129,
116
+ audio_sample_rate=8000,
117
+ lower_edge_hertz=125.0,
118
+ upper_edge_hertz=3800.0):
119
+ """Return a matrix that can post-multiply spectrogram rows to make mel.
120
+
121
+ Returns a np.array matrix A that can be used to post-multiply a matrix S of
122
+ spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
123
+ "mel spectrogram" M of frames x num_mel_bins. M = S A.
124
+
125
+ The classic HTK algorithm exploits the complementarity of adjacent mel bands
126
+ to multiply each FFT bin by only one mel weight, then add it, with positive
127
+ and negative signs, to the two adjacent mel bands to which that bin
128
+ contributes. Here, by expressing this operation as a matrix multiply, we go
129
+ from num_fft multiplies per frame (plus around 2*num_fft adds) to around
130
+ num_fft^2 multiplies and adds. However, because these are all presumably
131
+ accomplished in a single call to np.dot(), it's not clear which approach is
132
+ faster in Python. The matrix multiplication has the attraction of being more
133
+ general and flexible, and much easier to read.
134
+
135
+ Args:
136
+ num_mel_bins: How many bands in the resulting mel spectrum. This is
137
+ the number of columns in the output matrix.
138
+ num_spectrogram_bins: How many bins there are in the source spectrogram
139
+ data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
140
+ only contains the nonredundant FFT bins.
141
+ audio_sample_rate: Samples per second of the audio at the input to the
142
+ spectrogram. We need this to figure out the actual frequencies for
143
+ each spectrogram bin, which dictates how they are mapped into mel.
144
+ lower_edge_hertz: Lower bound on the frequencies to be included in the mel
145
+ spectrum. This corresponds to the lower edge of the lowest triangular
146
+ band.
147
+ upper_edge_hertz: The desired top edge of the highest frequency band.
148
+
149
+ Returns:
150
+ An np.array with shape (num_spectrogram_bins, num_mel_bins).
151
+
152
+ Raises:
153
+ ValueError: if frequency edges are incorrectly ordered or out of range.
154
+ """
155
+ nyquist_hertz = audio_sample_rate / 2.
156
+ if lower_edge_hertz < 0.0:
157
+ raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
158
+ if lower_edge_hertz >= upper_edge_hertz:
159
+ raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
160
+ (lower_edge_hertz, upper_edge_hertz))
161
+ if upper_edge_hertz > nyquist_hertz:
162
+ raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
163
+ (upper_edge_hertz, nyquist_hertz))
164
+ spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
165
+ spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
166
+ # The i'th mel band (starting from i=1) has center frequency
167
+ # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
168
+ # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
169
+ # the band_edges_mel arrays.
170
+ band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
171
+ hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
172
+ # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
173
+ # of spectrogram values.
174
+ mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
175
+ for i in range(num_mel_bins):
176
+ lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
177
+ # Calculate lower and upper slopes for every spectrogram bin.
178
+ # Line segments are linear in the *mel* domain, not hertz.
179
+ lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
180
+ (center_mel - lower_edge_mel))
181
+ upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
182
+ (upper_edge_mel - center_mel))
183
+ # .. then intersect them with each other and zero.
184
+ mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
185
+ upper_slope))
186
+ # HTK excludes the spectrogram DC bin; make sure it always gets a zero
187
+ # coefficient.
188
+ mel_weights_matrix[0, :] = 0.0
189
+ return mel_weights_matrix
190
+
191
+
192
+ def log_mel_spectrogram(data,
193
+ audio_sample_rate=8000,
194
+ log_offset=0.0,
195
+ window_length_secs=0.025,
196
+ hop_length_secs=0.010,
197
+ **kwargs):
198
+ """Convert waveform to a log magnitude mel-frequency spectrogram.
199
+
200
+ Args:
201
+ data: 1D np.array of waveform data.
202
+ audio_sample_rate: The sampling rate of data.
203
+ log_offset: Add this to values when taking log to avoid -Infs.
204
+ window_length_secs: Duration of each window to analyze.
205
+ hop_length_secs: Advance between successive analysis windows.
206
+ **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
207
+
208
+ Returns:
209
+ 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
210
+ magnitudes for successive frames.
211
+ """
212
+ window_length_samples = int(round(audio_sample_rate * window_length_secs))
213
+ hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
214
+ fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
215
+ spectrogram = stft_magnitude(
216
+ data,
217
+ fft_length=fft_length,
218
+ hop_length=hop_length_samples,
219
+ window_length=window_length_samples)
220
+ mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
221
+ num_spectrogram_bins=spectrogram.shape[1],
222
+ audio_sample_rate=audio_sample_rate, **kwargs))
223
+ return np.log(mel_spectrogram + log_offset)
avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Compute input examples for VGGish from audio waveform."""
17
+
18
+ # Modification: Return torch tensors rather than numpy arrays
19
+ import torch
20
+
21
+ import numpy as np
22
+ import resampy
23
+
24
+ from dataloader.audio.preprocess_vgg import mel_features
25
+ from dataloader.audio.preprocess_vgg import vggish_params
26
+
27
+ import soundfile as sf
28
+
29
+
30
+ def waveform_to_examples(data, sample_rate, return_tensor=True):
31
+ """Converts audio waveform into an array of examples for VGGish.
32
+
33
+ Args:
34
+ data: np.array of either one dimension (mono) or two dimensions
35
+ (multi-channel, with the outer dimension representing channels).
36
+ Each sample is generally expected to lie in the range [-1.0, +1.0],
37
+ although this is not required.
38
+ sample_rate: Sample rate of data.
39
+ return_tensor: Return data as a Pytorch tensor ready for VGGish
40
+
41
+ Returns:
42
+ 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
43
+ a sequence of examples, each of which contains a patch of log mel
44
+ spectrogram, covering num_frames frames of audio and num_bands mel frequency
45
+ bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
46
+
47
+ """
48
+ # Convert to mono.
49
+ if len(data.shape) > 1:
50
+ data = np.mean(data, axis=1)
51
+ # Resample to the rate assumed by VGGish.
52
+ if sample_rate != vggish_params.SAMPLE_RATE:
53
+ data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
54
+
55
+ # Compute log mel spectrogram features.
56
+ log_mel = mel_features.log_mel_spectrogram(
57
+ data,
58
+ audio_sample_rate=vggish_params.SAMPLE_RATE,
59
+ log_offset=vggish_params.LOG_OFFSET,
60
+ window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
61
+ hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
62
+ num_mel_bins=vggish_params.NUM_MEL_BINS,
63
+ lower_edge_hertz=vggish_params.MEL_MIN_HZ,
64
+ upper_edge_hertz=vggish_params.MEL_MAX_HZ)
65
+
66
+ # Frame features into examples.
67
+ features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
68
+ example_window_length = int(round(
69
+ vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
70
+ example_hop_length = int(round(
71
+ vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
72
+ log_mel_examples = mel_features.frame(
73
+ log_mel,
74
+ window_length=example_window_length,
75
+ hop_length=example_hop_length)
76
+
77
+ if return_tensor:
78
+ log_mel_examples = torch.tensor(
79
+ log_mel_examples, requires_grad=True)[:, None, :, :].float()
80
+
81
+ return log_mel_examples
82
+
83
+
84
+ def wavfile_to_examples(wav_file, return_tensor=True):
85
+ """Convenience wrapper around waveform_to_examples() for a common WAV format.
86
+
87
+ Args:
88
+ wav_file: String path to a file, or a file-like object. The file
89
+ is assumed to contain WAV audio data with signed 16-bit PCM samples.
90
+ torch: Return data as a Pytorch tensor ready for VGGish
91
+
92
+ Returns:
93
+ See waveform_to_examples.
94
+ """
95
+ wav_data, sr = sf.read(wav_file, dtype='int16')
96
+ assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
97
+ samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
98
+ return waveform_to_examples(samples, sr, return_tensor)
avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Global parameters for the VGGish model.
17
+
18
+ See vggish_slim.py for more information.
19
+ """
20
+
21
+ # Architectural constants.
22
+ NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23
+ NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24
+ EMBEDDING_SIZE = 128 # Size of embedding layer.
25
+
26
+ # Hyperparameters used in feature and example generation.
27
+ SAMPLE_RATE = 16000
28
+ STFT_WINDOW_LENGTH_SECONDS = 0.025
29
+ STFT_HOP_LENGTH_SECONDS = 0.010
30
+ NUM_MEL_BINS = NUM_BANDS
31
+ MEL_MIN_HZ = 125
32
+ MEL_MAX_HZ = 7500
33
+ LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34
+ EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35
+ EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36
+
37
+ # Parameters used for embedding postprocessing.
38
+ PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39
+ PCA_MEANS_NAME = 'pca_means'
40
+ QUANTIZE_MIN_VAL = -2.0
41
+ QUANTIZE_MAX_VAL = +2.0
42
+
43
+ # Hyperparameters used in training.
44
+ INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45
+ LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46
+ ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47
+
48
+ # Names of ops, tensors, and features.
49
+ INPUT_OP_NAME = 'vggish/input_features'
50
+ INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51
+ OUTPUT_OP_NAME = 'vggish/embedding'
52
+ OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53
+ AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
avs.code/v1m.code/dataloader/dataset.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fused audio-visual dataset for AVSBench-style indexing."""
2
+ import os
3
+ import random
4
+ import PIL.Image
5
+ import numpy
6
+ import torch
7
+ from dataloader.visual.visual_dataset import Visual
8
+ from dataloader.audio.audio_dataset import Audio
9
+ import pandas
10
+
11
+
12
+ class AV(torch.utils.data.Dataset):
13
+ """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`."""
14
+
15
+ def __init__(self, split, augmentation, param, root_path='', data_name='find'):
16
+ self.visual_dataset = Visual(augmentation['visual'], os.path.join(root_path, data_name), split, param.image_size, param.image_embedding_size)
17
+ self.audio_dataset = Audio(augmentation['audio'], os.path.join(root_path, data_name), split)
18
+ self.augment = augmentation
19
+ self.split = split
20
+ self.file_path = self.organise_files(self.split, root_path, data_name, csv_name_='avss_index/metadata.csv')
21
+
22
+ def __getitem__(self, index):
23
+ mixing_prob = 0. # we omit this option.
24
+ other_index = random.randint(1, self.__len__()) - 1 if random.random() < mixing_prob and self.split == 'train' else None
25
+ frame, label, prompts = self.visual_dataset.load_data(self.file_path[index])
26
+ if other_index is not None:
27
+ other_frame, other_label, other_prompts = self.visual_dataset.load_data(self.file_path[other_index])
28
+ frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts)
29
+ audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], self.file_path[other_index])
30
+ else:
31
+ audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], None)
32
+
33
+ assert other_index is None if self.split == 'test' else 1, print('no mix in validation.')
34
+
35
+ return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': self.file_path[index],
36
+ 'prompts': prompts}
37
+
38
+ def __len__(self):
39
+ return len(self.file_path)
40
+
41
+ @staticmethod
42
+ def organise_files(split_, root_path_, data_name_, csv_name_):
43
+ """Read rows from `csv_name_` under `root_path_` matching split and dataset label."""
44
+ total_files = pandas.read_csv(os.path.join(root_path_, csv_name_))
45
+ files_info = total_files[(total_files["split"] == split_) & (total_files["label"] == data_name_)]['uid']
46
+
47
+ files_path = [os.path.join(root_path_, data_name_, files_name) for files_name in files_info]
48
+ del total_files, files_info
49
+ return files_path
50
+
51
+ @staticmethod
52
+ def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2):
53
+ mix_frame = frame1.clone()
54
+ mix_label = label1.clone()
55
+ bbx1, bby1, bbx2, bby2 = 0, 0, mix_label.shape[1] - 1, mix_label.shape[2] - 1
56
+
57
+ for i in range(0, mix_frame.shape[0]):
58
+ label_canvas_foreground = label2[i, bbx1:bbx2, bby1:bby2] > 0.
59
+ mix_frame[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground] = (
60
+ frame2[i, :, bbx1:bbx2, bby1:bby2][:, label_canvas_foreground])
61
+ mix_label[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground] = (
62
+ label2[i, bbx1:bbx2, bby1:bby2][label_canvas_foreground])
63
+
64
+ return mix_frame, mix_label, prompts1
65
+
66
+
67
+
avs.code/v1m.code/dataloader/sam2_dataset/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
avs.code/v1m.code/dataloader/sam2_dataset/transforms.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Transforms and data augmentation for both image + bbox.
9
+ """
10
+
11
+ import logging
12
+
13
+ import random
14
+ from typing import Iterable
15
+
16
+ import torch
17
+ import torchvision.transforms as T
18
+ import torchvision.transforms.functional as F
19
+ import torchvision.transforms.v2.functional as Fv2
20
+ from PIL import Image as PILImage
21
+ # from docutils.nodes import label
22
+ import numpy
23
+ from torchvision.transforms import InterpolationMode
24
+
25
+ # from utils.data_utils import VideoDatapoint
26
+
27
+
28
+ def hflip(frames, labels, index):
29
+ # print(index)
30
+ # print(len(frames), frames[index].size, type(frames[index]))
31
+ # print(len(labels), labels[index].size, type(labels[index]))
32
+ frames[index] = F.hflip(frames[index])
33
+ labels[index] = F.hflip(labels[index])
34
+ # for obj in frames[index].objects:
35
+ # if obj.segment is not None:
36
+ # obj.segment = F.hflip(obj.segment)
37
+
38
+ return frames, labels
39
+
40
+
41
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
42
+ w, h = image_size
43
+ if max_size is not None:
44
+ min_original_size = float(min((w, h)))
45
+ max_original_size = float(max((w, h)))
46
+ if max_original_size / min_original_size * size > max_size:
47
+ size = max_size * min_original_size / max_original_size
48
+
49
+ if (w <= h and w == size) or (h <= w and h == size):
50
+ return (h, w)
51
+
52
+ if w < h:
53
+ ow = int(round(size))
54
+ oh = int(round(size * h / w))
55
+ else:
56
+ oh = int(round(size))
57
+ ow = int(round(size * w / h))
58
+
59
+ return (oh, ow)
60
+
61
+
62
+ def resize(frames, labels, index, size, max_size=None, square=False, v2=False):
63
+ # size can be min_size (scalar) or (w, h) tuple
64
+ def get_size(image_size, size, max_size=None):
65
+ if isinstance(size, (list, tuple)):
66
+ return size[::-1]
67
+ else:
68
+ return get_size_with_aspect_ratio(image_size, size, max_size)
69
+
70
+ if square:
71
+ size = size, size
72
+ else:
73
+ raise NotImplementedError
74
+ # cur_size = (
75
+ # frames[index].data.size()[-2:][::-1]
76
+ # if v2
77
+ # else frames[index].data.size
78
+ # )
79
+ # size = get_size(cur_size, size, max_size)
80
+
81
+ # old_size = (
82
+ # frames[index].data.size()[-2:][::-1]
83
+ # if v2
84
+ # else frames[index].data.size
85
+ # )
86
+ if v2:
87
+ frames[index].data = Fv2.resize(
88
+ frames[index].data, size, antialias=True
89
+ )
90
+ else:
91
+ frames[index] = F.resize(frames[index], size)
92
+ labels[index] = F.resize(labels[index], size)
93
+ # new_size = (
94
+ # frames[index].data.size()[-2:][::-1]
95
+ # if v2
96
+ # else frames[index].data.size
97
+ # )
98
+
99
+ # for obj in frames[index].objects:
100
+ # if obj.segment is not None:
101
+ # obj.segment = F.resize(obj.segment[None, None], size).squeeze()
102
+
103
+ # h, w = size
104
+ # frames[index].size = (h, w)
105
+ return frames, labels
106
+
107
+
108
+ def pad(frames, index, padding, v2=False):
109
+ old_h, old_w = frames[index].size
110
+ h, w = old_h, old_w
111
+ if len(padding) == 2:
112
+ # assumes that we only pad on the bottom right corners
113
+ frames[index].data = F.pad(
114
+ frames[index].data, (0, 0, padding[0], padding[1])
115
+ )
116
+ h += padding[1]
117
+ w += padding[0]
118
+ else:
119
+ # left, top, right, bottom
120
+ frames[index].data = F.pad(
121
+ frames[index].data,
122
+ (padding[0], padding[1], padding[2], padding[3]),
123
+ )
124
+ h += padding[1] + padding[3]
125
+ w += padding[0] + padding[2]
126
+
127
+ frames[index].size = (h, w)
128
+
129
+ for obj in frames[index].objects:
130
+ if obj.segment is not None:
131
+ if v2:
132
+ if len(padding) == 2:
133
+ obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
134
+ else:
135
+ obj.segment = Fv2.pad(obj.segment, tuple(padding))
136
+ else:
137
+ if len(padding) == 2:
138
+ obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
139
+ else:
140
+ obj.segment = F.pad(obj.segment, tuple(padding))
141
+ return frames
142
+
143
+
144
+ class RandomHorizontalFlip:
145
+ def __init__(self, consistent_transform, p=0.5):
146
+ self.p = p
147
+ self.consistent_transform = consistent_transform
148
+
149
+ def __call__(self, frames, labels, **kwargs):
150
+ if self.consistent_transform:
151
+ if random.random() < self.p:
152
+ for i in range(len(frames)):
153
+ frames, labels = hflip(frames, labels, i)
154
+ return frames, labels
155
+ for i in range(len(frames)):
156
+ if random.random() < self.p:
157
+ frames, labels = hflip(frames, labels, i)
158
+ return frames, labels
159
+
160
+
161
+ class RandomResizeAPI:
162
+ def __init__(
163
+ self, sizes, consistent_transform, max_size=None, square=False, v2=False
164
+ ):
165
+ if isinstance(sizes, int):
166
+ sizes = (sizes,)
167
+ assert isinstance(sizes, Iterable)
168
+ self.sizes = list(sizes)
169
+ self.max_size = max_size
170
+ self.square = square
171
+ self.consistent_transform = consistent_transform
172
+ self.v2 = v2
173
+
174
+ def __call__(self, frames, labels):
175
+ if self.consistent_transform:
176
+ size = random.choice(self.sizes)
177
+ for i in range(len(frames)):
178
+ frames, labels = resize(
179
+ frames, labels, i, size, self.max_size, square=self.square, v2=self.v2
180
+ )
181
+ return frames, labels
182
+ for i in range(len(frames)):
183
+ size = random.choice(self.sizes)
184
+ frames, labels = resize(
185
+ frames, labels, i, size, self.max_size, square=self.square, v2=self.v2
186
+ )
187
+ return frames, labels
188
+
189
+
190
+ class ToTensorAPI:
191
+ def __init__(self, v2=False):
192
+ self.v2 = v2
193
+
194
+ def __call__(self, frames, labels, **kwargs):
195
+ for img_idx in range(len(frames)):
196
+ if self.v2:
197
+ raise NotImplementedError
198
+ # frames[img_idx] = Fv2.to_tensor(frames[img_idx])
199
+ else:
200
+ frames[img_idx] = F.to_tensor(frames[img_idx])
201
+ labels[img_idx] = torch.tensor(numpy.array(labels[img_idx]), dtype=torch.float)
202
+ return frames, labels
203
+
204
+
205
+ class NormalizeAPI:
206
+ def __init__(self, mean, std, v2=False):
207
+ self.mean = mean
208
+ self.std = std
209
+ self.v2 = v2
210
+
211
+ def __call__(self, frames, labels, **kwargs):
212
+ for img_idx in range(len(frames)):
213
+ # if self.v2:
214
+ # img.data = Fv2.convert_image_dtype(img.data, torch.float32)
215
+ # img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
216
+ # else:
217
+ frames[img_idx] = F.normalize(frames[img_idx], mean=self.mean, std=self.std)
218
+
219
+ return frames, labels
220
+
221
+ '''
222
+ <dataloader.sam2_dataset.transforms.RandomHorizontalFlip object at 0x75c815561b40>
223
+ <dataloader.sam2_dataset.transforms.RandomAffine object at 0x75c815561bd0>
224
+ <dataloader.sam2_dataset.transforms.RandomResizeAPI object at 0x75c815561c60>
225
+ <dataloader.sam2_dataset.transforms.ColorJitter object at 0x75c815561cc0>
226
+ <dataloader.sam2_dataset.transforms.RandomGrayscale object at 0x75c815561cf0>
227
+ <dataloader.sam2_dataset.transforms.ColorJitter object at 0x75c815561de0>
228
+ <dataloader.sam2_dataset.transforms.ToTensorAPI object at 0x75c815507280>
229
+ <dataloader.sam2_dataset.transforms.NormalizeAPI object at 0x75c815507490>
230
+ '''
231
+ class ComposeAPI:
232
+ def __init__(self, transforms):
233
+ self.transforms = transforms
234
+
235
+ def __call__(self, frames, labels, **kwargs):
236
+ for t in self.transforms:
237
+ frames, labels = t(frames, labels, **kwargs)
238
+ return frames, labels
239
+
240
+ def __repr__(self):
241
+ format_string = self.__class__.__name__ + "("
242
+ for t in self.transforms:
243
+ format_string += "\n"
244
+ format_string += " {0}".format(t)
245
+ format_string += "\n)"
246
+ return format_string
247
+
248
+
249
+ class RandomGrayscale:
250
+ def __init__(self, consistent_transform, p=0.5):
251
+ self.p = p
252
+ self.consistent_transform = consistent_transform
253
+ self.Grayscale = T.Grayscale(num_output_channels=3)
254
+
255
+ def __call__(self, frames, labels, **kwargs):
256
+ if self.consistent_transform:
257
+ if random.random() < self.p:
258
+ for img_idx in range(len(frames)):
259
+ frames[img_idx] = self.Grayscale(frames[img_idx])
260
+ return frames, labels
261
+ for img_idx in range(len(frames)):
262
+ if random.random() < self.p:
263
+ frames[img_idx] = self.Grayscale(frames[img_idx])
264
+ return frames, labels
265
+
266
+
267
+ class ColorJitter:
268
+ def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
269
+ self.consistent_transform = consistent_transform
270
+ self.brightness = (
271
+ brightness
272
+ if isinstance(brightness, list)
273
+ else [max(0, 1 - brightness), 1 + brightness]
274
+ )
275
+ self.contrast = (
276
+ contrast
277
+ if isinstance(contrast, list)
278
+ else [max(0, 1 - contrast), 1 + contrast]
279
+ )
280
+ self.saturation = (
281
+ saturation
282
+ if isinstance(saturation, list)
283
+ else [max(0, 1 - saturation), 1 + saturation]
284
+ )
285
+ self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
286
+
287
+ def __call__(self, frames, labels, **kwargs):
288
+ if self.consistent_transform:
289
+ # Create a color jitter transformation params
290
+ (
291
+ fn_idx,
292
+ brightness_factor,
293
+ contrast_factor,
294
+ saturation_factor,
295
+ hue_factor,
296
+ ) = T.ColorJitter.get_params(
297
+ self.brightness, self.contrast, self.saturation, self.hue
298
+ )
299
+ for img in frames:
300
+ if not self.consistent_transform:
301
+ (
302
+ fn_idx,
303
+ brightness_factor,
304
+ contrast_factor,
305
+ saturation_factor,
306
+ hue_factor,
307
+ ) = T.ColorJitter.get_params(
308
+ self.brightness, self.contrast, self.saturation, self.hue
309
+ )
310
+ for fn_id in fn_idx:
311
+ if fn_id == 0 and brightness_factor is not None:
312
+ img = F.adjust_brightness(img, brightness_factor)
313
+ elif fn_id == 1 and contrast_factor is not None:
314
+ img = F.adjust_contrast(img, contrast_factor)
315
+ elif fn_id == 2 and saturation_factor is not None:
316
+ img = F.adjust_saturation(img, saturation_factor)
317
+ elif fn_id == 3 and hue_factor is not None:
318
+ img = F.adjust_hue(img, hue_factor)
319
+ return frames, labels
320
+
321
+
322
+ class RandomAffine:
323
+ def __init__(
324
+ self,
325
+ degrees,
326
+ consistent_transform,
327
+ scale=None,
328
+ translate=None,
329
+ shear=None,
330
+ image_mean=(123, 116, 103),
331
+ label_fill_value=0.,
332
+ log_warning=True,
333
+ num_tentatives=1,
334
+ image_interpolation="bicubic",
335
+ ):
336
+ """
337
+ The mask is required for this transform.
338
+ if consistent_transform if True, then the same random affine is applied to all frames and masks.
339
+ """
340
+ self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
341
+ self.scale = scale
342
+ self.shear = (
343
+ shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
344
+ )
345
+ self.translate = translate
346
+ self.fill_img = image_mean
347
+ self.fill_label = label_fill_value
348
+ self.consistent_transform = consistent_transform
349
+ self.log_warning = log_warning
350
+ self.num_tentatives = num_tentatives
351
+ assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.'
352
+
353
+ if image_interpolation == "bicubic":
354
+ self.image_interpolation = InterpolationMode.BICUBIC
355
+ elif image_interpolation == "bilinear":
356
+ self.image_interpolation = InterpolationMode.BILINEAR
357
+ else:
358
+ raise NotImplementedError
359
+
360
+ def __call__(self, frames, labels, **kwargs):
361
+ for _tentative in range(self.num_tentatives):
362
+ res_img, res_labels = self.transform_frames(frames, labels)
363
+ # if res is not None:
364
+ return res_img, res_labels
365
+
366
+ # raise NotImplementedError
367
+ # if self.log_warning:
368
+ # logging.warning(
369
+ # f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
370
+ # )
371
+ # return frames
372
+
373
+ def transform_frames(self, frames, labels):
374
+ _, height, width = F.get_dimensions(frames[0])
375
+ img_size = [width, height]
376
+
377
+ if self.consistent_transform:
378
+ # Create a random affine transformation
379
+ affine_params = T.RandomAffine.get_params(
380
+ degrees=self.degrees,
381
+ translate=self.translate,
382
+ scale_ranges=self.scale,
383
+ shears=self.shear,
384
+ img_size=img_size,
385
+ )
386
+
387
+ for img_idx, img in enumerate(frames):
388
+ if not self.consistent_transform:
389
+ # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation
390
+ affine_params = T.RandomAffine.get_params(
391
+ degrees=self.degrees,
392
+ translate=self.translate,
393
+ scale_ranges=self.scale,
394
+ shears=self.shear,
395
+ img_size=img_size,
396
+ )
397
+ frames[img_idx] = F.affine(
398
+ img,
399
+ *affine_params,
400
+ interpolation=self.image_interpolation,
401
+ fill=self.fill_img,
402
+ )
403
+ labels[img_idx] = F.affine(
404
+ labels[img_idx],
405
+ *affine_params,
406
+ # default: interpolation='nearest',
407
+ fill=self.fill_label,
408
+ )
409
+ return frames, labels
410
+
411
+
412
+ '''
413
+ def random_mosaic_frame(
414
+ datapoint,
415
+ index,
416
+ grid_h,
417
+ grid_w,
418
+ target_grid_y,
419
+ target_grid_x,
420
+ should_hflip,
421
+ ):
422
+ # Step 1: downsize the images and paste them into a mosaic
423
+ image_data = datapoint.frames[index].data
424
+ is_pil = isinstance(image_data, PILImage.Image)
425
+ if is_pil:
426
+ H_im = image_data.height
427
+ W_im = image_data.width
428
+ image_data_output = PILImage.new("RGB", (W_im, H_im))
429
+ else:
430
+ H_im = image_data.size(-2)
431
+ W_im = image_data.size(-1)
432
+ image_data_output = torch.zeros_like(image_data)
433
+
434
+ downsize_cache = {}
435
+ for grid_y in range(grid_h):
436
+ for grid_x in range(grid_w):
437
+ y_offset_b = grid_y * H_im // grid_h
438
+ x_offset_b = grid_x * W_im // grid_w
439
+ y_offset_e = (grid_y + 1) * H_im // grid_h
440
+ x_offset_e = (grid_x + 1) * W_im // grid_w
441
+ H_im_downsize = y_offset_e - y_offset_b
442
+ W_im_downsize = x_offset_e - x_offset_b
443
+
444
+ if (H_im_downsize, W_im_downsize) in downsize_cache:
445
+ image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
446
+ else:
447
+ image_data_downsize = F.resize(
448
+ image_data,
449
+ size=(H_im_downsize, W_im_downsize),
450
+ interpolation=InterpolationMode.BILINEAR,
451
+ antialias=True, # antialiasing for downsizing
452
+ )
453
+ downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
454
+ if should_hflip[grid_y, grid_x].item():
455
+ image_data_downsize = F.hflip(image_data_downsize)
456
+
457
+ if is_pil:
458
+ image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
459
+ else:
460
+ image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
461
+ image_data_downsize
462
+ )
463
+
464
+ datapoint.frames[index].data = image_data_output
465
+
466
+ # Step 2: downsize the masks and paste them into the target grid of the mosaic
467
+ for obj in datapoint.frames[index].objects:
468
+ if obj.segment is None:
469
+ continue
470
+ assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
471
+ segment_output = torch.zeros_like(obj.segment)
472
+
473
+ target_y_offset_b = target_grid_y * H_im // grid_h
474
+ target_x_offset_b = target_grid_x * W_im // grid_w
475
+ target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
476
+ target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
477
+ target_H_im_downsize = target_y_offset_e - target_y_offset_b
478
+ target_W_im_downsize = target_x_offset_e - target_x_offset_b
479
+
480
+ segment_downsize = F.resize(
481
+ obj.segment[None, None],
482
+ size=(target_H_im_downsize, target_W_im_downsize),
483
+ interpolation=InterpolationMode.BILINEAR,
484
+ antialias=True, # antialiasing for downsizing
485
+ )[0, 0]
486
+ if should_hflip[target_grid_y, target_grid_x].item():
487
+ segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
488
+
489
+ segment_output[
490
+ target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
491
+ ] = segment_downsize
492
+ obj.segment = segment_output
493
+
494
+ return datapoint
495
+
496
+
497
+ class RandomMosaicVideoAPI:
498
+ def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
499
+ self.prob = prob
500
+ self.grid_h = grid_h
501
+ self.grid_w = grid_w
502
+ self.use_random_hflip = use_random_hflip
503
+
504
+ def __call__(self, frames, **kwargs):
505
+ if random.random() > self.prob:
506
+ return datapoint
507
+
508
+ # select a random location to place the target mask in the mosaic
509
+ target_grid_y = random.randint(0, self.grid_h - 1)
510
+ target_grid_x = random.randint(0, self.grid_w - 1)
511
+ # whether to flip each grid in the mosaic horizontally
512
+ if self.use_random_hflip:
513
+ should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
514
+ else:
515
+ should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
516
+ for i in range(len(datapoint.frames)):
517
+ datapoint = random_mosaic_frame(
518
+ datapoint,
519
+ i,
520
+ grid_h=self.grid_h,
521
+ grid_w=self.grid_w,
522
+ target_grid_y=target_grid_y,
523
+ target_grid_x=target_grid_x,
524
+ should_hflip=should_hflip,
525
+ )
526
+
527
+ return datapoint
528
+ '''
avs.code/v1m.code/dataloader/visual/visual_augmentation.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ import torchvision.transforms as transforms
8
+
9
+
10
+ class Augmentation(object):
11
+ def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255):
12
+ self.image_size = (image_height, image_width)
13
+ # self.image_norm = (image_mean, image_std)
14
+ # self.get_crop_pos = transforms.RandomCrop(self.image_size)
15
+ self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25)
16
+ self.gaussian_blurring = transforms.GaussianBlur((3, 3))
17
+ self.scale_list = scale_list
18
+
19
+ self.normalise = transforms.Normalize(mean=image_mean, std=image_std)
20
+ self.to_tensor = transforms.ToTensor()
21
+
22
+ self.ignore_index = ignore_index
23
+
24
+ # self.normalise = transforms.Normalize(mean=image_mean, std=image_std)
25
+
26
+ # if setup == "avs" or setup == "avss" or setup == "avss_binary":
27
+ # # AVS
28
+ # self.scale_list = [.5, .75, 1.]
29
+ # self.color_jitter = None
30
+ # else:
31
+ # # COCO
32
+ # # self.scale_list = [.75, 1., 1.25, 1.5, 1.75, 2.]
33
+ # self.scale_list = [0.5,0.75,1.0,1.25,1.5,1.75,2.0]
34
+
35
+ # def normalise(self, image):
36
+ # image = image / 255.0
37
+ # image = image - self.image_norm[0]
38
+ # image = image / self.image_norm[1]
39
+ # return image
40
+
41
+ def resize(self, image_, label_, size=None):
42
+ h_, w_ = self.image_size if size is None else size
43
+ image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC)
44
+ label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST)
45
+ return image_, label_
46
+
47
+ def random_crop_with_padding(self, image_, label_):
48
+ w_, h_ = image_.size
49
+ if min(h_, w_) < min(self.image_size):
50
+ res_w_ = max(self.image_size[0] - w_, 0)
51
+ res_h_ = max(self.image_size[1] - h_, 0)
52
+ image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=(numpy.array(self.image_norm[0]) * 255.).tolist())
53
+ # image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=self.ignore_index) # if error, define the padding value.
54
+ label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index)
55
+
56
+ pos_ = self.get_crop_pos.get_params(image_, self.image_size)
57
+ image_ = F.crop(image_, *pos_)
58
+ label_ = F.crop(label_, *pos_)
59
+
60
+ return image_, label_
61
+
62
+ # @staticmethod
63
+ def random_scales(self, image_, label_):
64
+ w_, h_ = image_.size
65
+ chosen_scale = random.choice(self.scale_list)
66
+ w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale)
67
+ image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC)
68
+ label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST)
69
+ return image_, label_
70
+
71
+ @staticmethod
72
+ def random_flip_h(image_, label_):
73
+ chosen_flip = random.random() > 0.5
74
+ image_ = F.hflip(image_) if chosen_flip else image_
75
+ label_ = F.hflip(label_) if chosen_flip else label_
76
+ return image_, label_
77
+
78
+ def augment_entire_clip(self, x_list, y_list):
79
+ degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item())
80
+ shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()),
81
+ torch.empty(1).uniform_(float(-20.), float(20.)).item()]
82
+ dice = random.random()
83
+ for index, single_x in enumerate(x_list):
84
+ if dice <= 0.1:
85
+ single_x = F.rgb_to_grayscale(single_x, num_output_channels=3)
86
+
87
+ single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0,0], scale=1.,
88
+ interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.])
89
+ single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0,0], scale=1.,
90
+ interpolation=transforms.InterpolationMode.NEAREST, fill=[0.])
91
+ x_list[index] = single_x
92
+ y_list[index] = single_y
93
+
94
+ return x_list, y_list
95
+
96
+
97
+
98
+
99
+ def train_aug(self, x_, y_):
100
+ x_, y_ = self.random_flip_h(x_, y_)
101
+ # # x, y = self.random_scales(x, y)
102
+ x_, y_ = self.resize(x_, y_)
103
+
104
+ if self.color_jitter is not None and random.random() < 0.5:
105
+ x_ = self.color_jitter(x_)
106
+ if self.gaussian_blurring is not None and random.random() < 0.5:
107
+ x_ = self.gaussian_blurring(x_)
108
+
109
+ # x, y = self.random_crop_with_padding(x, y)
110
+
111
+ x_ = self.normalise(self.to_tensor(x_)).type(torch.float32)
112
+ # receive pseudo labels.
113
+ y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float)
114
+ return x_, y_
115
+
116
+ def test_process(self, x_, y_):
117
+ # x = self.to_tensor(x)
118
+ # y = torch.tensor(numpy.asarray(y)).long()
119
+
120
+ # following AVSbench setup, we fix image size (224, 224)
121
+ x_, y_ = self.resize(x_, y_)
122
+
123
+ x_ = self.normalise(self.to_tensor(x_)).type(torch.float32)
124
+ y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float)
125
+ return x_, y_
126
+
127
+ def __call__(self, x, y, split):
128
+ return self.train_aug(x, y) if split == "train" \
129
+ else self.test_process(x, y)
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
avs.code/v1m.code/dataloader/visual/visual_dataset.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import PIL.Image
4
+ import matplotlib.pyplot as plt
5
+ import numpy
6
+ import torch
7
+ import pandas
8
+ import torchvision
9
+
10
+
11
+ class Visual(torch.utils.data.Dataset):
12
+ def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size):
13
+ self.augment = augmentation
14
+ self.directory_path = directory_path
15
+ self.split = split
16
+ self.image_size = image_size
17
+ self.embedding_size = image_embedding_size
18
+
19
+ def load_data(self, file_prefix):
20
+ frame_path = os.path.join(file_prefix, 'frames')
21
+ frame_path = [os.path.join(frame_path, i) for i in os.listdir(frame_path)]
22
+ label_path = os.path.join(file_prefix, 'labels_rgb')
23
+ label_path = [os.path.join(label_path, i) for i in os.listdir(label_path)]
24
+
25
+ # if self.split == 'train':
26
+ # label_path += [os.path.join(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'), i) for i in
27
+ # os.listdir(file_prefix.replace('v1s', 'v1s_sam2_pseudo_labels'))]
28
+
29
+ frame_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.jpg')[0])))
30
+ label_path.sort(key=lambda x: tuple(map(int, x.split('/')[-1].split("_")[-1].split('.png')[0])))
31
+
32
+ frame = [PIL.Image.open(i) for i in frame_path]
33
+ label = [PIL.Image.open(i).convert('L') for i in label_path]
34
+
35
+ # if self.split == 'train':
36
+ # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label))
37
+
38
+ label_idx = torch.tensor(list([1] + [0] * 4), dtype=torch.bool)
39
+ # fulfill the empty page.
40
+ # we utilise pseudo-labels now.
41
+ # label_idx = torch.tensor(list([1] + [0] * (len(frame) - len(label))), dtype=torch.bool)
42
+ # label += [PIL.Image.new('L', frame[0].size)] * (len(frame)-len(label))
43
+
44
+ # receive the prompts from the ground truth.
45
+ # prompts = {"point_coords": torch.nan, "point_labels": torch.nan,
46
+ # "masks": [None]*len(frame), "box_coords": [None]*len(frame)}
47
+
48
+ prompts = {}
49
+ image_batch = [None]*len(frame)
50
+ label_batch = [None]*len(frame)
51
+
52
+ if self.split == 'train':
53
+ # frame, label = self.augment.augment_entire_clip(frame, label)
54
+ frame, label = self.augment(frame, label)
55
+
56
+
57
+ for i in range(len(frame)):
58
+ if self.split == 'test':
59
+ curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split)
60
+ else:
61
+ curr_frame, curr_label = frame[i], label[i]
62
+ # if self.split == 'train' and i > 0:
63
+ # curr_label = curr_label / 255.
64
+ # curr_label[curr_label > 0.5] = 1
65
+ # curr_label[curr_label < 0.5] = 0
66
+ # # curr_label[(0.05 < curr_label) & (curr_label < 0.95)] = 255
67
+ # # we temporarily make it to be hard mask;
68
+ # # curr_label = ((curr_label / 255.) - 0.5) * 2
69
+ # # curr_label[curr_label >= 0.] = 1.
70
+ # # curr_label[curr_label < 0.] = 0.
71
+ # else:
72
+ curr_label[curr_label > 0.] = 1.
73
+ image_batch[i], label_batch[i] = curr_frame, curr_label
74
+
75
+ # image_batch[i], label_batch[i] = self.augment(frame[i], label[i], split=self.split)
76
+ # note: we simply convert the code to binary mask in v1s, v1m;
77
+ # to some reason, we failed to load the label in `L' format and had to hardcoding here.
78
+ # label_batch[i][label_batch[i] > 0.] = 1.
79
+
80
+ # prompts['box_coords'][i], prompts['masks'][i] = self.receive_other_prompts(label_batch[i])
81
+
82
+ # organise the prompts
83
+ # prompts.update({'masks': torch.stack(prompts['masks'], dim=0)})
84
+ # prompts.update({'box_coords': torch.stack(prompts['box_coords'], dim=0)})
85
+ # prompts.update({'point_labels': torch.stack(prompts['point_labels'], dim=0)})
86
+ prompts.update({'label_index': label_idx})
87
+ return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts
88
+
89
+ def receive_other_prompts(self, y_):
90
+ # y_ = torch.zeros_like(y_)
91
+ if len(torch.unique(y_)) > 1:
92
+ # foreground point
93
+ points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0)
94
+
95
+ # bbox prompt (left-top corner & right-bottom corner)
96
+ bbox_one = torch.min(points_foreground[:, 0]), torch.min(points_foreground[:, 1])
97
+ bbox_fou = torch.max(points_foreground[:, 0]), torch.max(points_foreground[:, 1])
98
+ bbox_coord = torch.tensor(bbox_one + bbox_fou, dtype=torch.float)
99
+ bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.squeeze().shape)
100
+ # mask prompt
101
+ low_mask = torchvision.transforms.functional.resize(y_.clone(), [self.embedding_size*4, self.embedding_size*4],
102
+ torchvision.transforms.InterpolationMode.NEAREST)
103
+ else:
104
+ # for the pure background situation.
105
+ bbox_coord = torch.zeros([4], dtype=torch.float).fill_(float('nan'))
106
+ low_mask = torch.zeros([1, self.embedding_size*4, self.embedding_size*4], dtype=torch.float).fill_(float('nan'))
107
+
108
+ return bbox_coord, low_mask
109
+
110
+ # we transfer the coords to SAM's input resolution (1024, 1024).
111
+ def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor:
112
+ """
113
+ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
114
+ If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
115
+
116
+ Returns
117
+ Un-normalized coordinates in the range of [0, 1] which is expected by the sam2 model.
118
+ """
119
+ h, w = orig_hw
120
+ coords = coords.clone().reshape(-1, 2, 2)
121
+ coords[..., 0] = coords[..., 0] / w
122
+ coords[..., 1] = coords[..., 1] / h
123
+ coords = coords * self.image_size # unnormalize coords
124
+ return coords.reshape(4)
125
+
126
+
127
+
avs.code/v1m.code/inference.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Distributed inference on the test set; runs the same three `process` modes as training validation."""
2
+ import os
3
+ import pathlib
4
+ import torch
5
+ import numpy
6
+ import random
7
+ import argparse
8
+ from easydict import EasyDict
9
+
10
+ # Avoid import failure when configs.config creates saved_dir without write permission.
11
+ _real_mkdir = pathlib.Path.mkdir
12
+
13
+
14
+ def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
15
+ try:
16
+ return _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok)
17
+ except PermissionError:
18
+ pass
19
+
20
+
21
+ pathlib.Path.mkdir = _safe_mkdir
22
+
23
+
24
+ def seed_it(seed):
25
+ random.seed(seed)
26
+ os.environ["PYTHONSEED"] = str(seed)
27
+ numpy.random.seed(seed)
28
+ torch.cuda.manual_seed(seed)
29
+ torch.cuda.manual_seed_all(seed)
30
+ torch.backends.cudnn.deterministic = True
31
+ torch.backends.cudnn.benchmark = True
32
+ torch.backends.cudnn.enabled = True
33
+ torch.manual_seed(seed)
34
+
35
+
36
+ class _DummyTensorboard:
37
+ """Minimal Tensorboard stub so Trainer.valid runs without wandb logging."""
38
+
39
+ def upload_wandb_info(self, info_dict):
40
+ pass
41
+
42
+ def upload_wandb_image(self, *args, **kwargs):
43
+ pass
44
+
45
+
46
+ def main(local_rank, ngpus_per_node, hyp_param):
47
+ hyp_param.local_rank = local_rank
48
+ torch.distributed.init_process_group(
49
+ backend='nccl',
50
+ init_method='env://',
51
+ rank=hyp_param.local_rank,
52
+ world_size=hyp_param.gpus * 1
53
+ )
54
+ seed_it(local_rank + hyp_param.seed)
55
+
56
+ import model.visual.sam2 # noqa: F401 — registers Hydra `configs`
57
+ from hydra import compose
58
+ from omegaconf import OmegaConf
59
+
60
+ arch_h = compose(config_name='auralfuser/architecture.yaml')
61
+ OmegaConf.resolve(arch_h)
62
+ hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)
63
+
64
+ train_cfg = compose(config_name='training/sam2_training_config.yaml')
65
+ OmegaConf.resolve(train_cfg)
66
+ hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True)
67
+
68
+ from model.mymodel import AVmodel
69
+ av_model = AVmodel(hyp_param).cuda()
70
+ torch.cuda.set_device(hyp_param.local_rank)
71
+ ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu")
72
+ if not isinstance(ckpt_sd, dict):
73
+ raise TypeError("Checkpoint must be a state_dict dictionary.")
74
+ # Same as v1s/v2: full-model ckpt vs train-only aural_fuser ckpt (e.g. keys vgg.*, f_blocks.*).
75
+ if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd.keys()):
76
+ av_model.load_state_dict(ckpt_sd, strict=True)
77
+ else:
78
+ av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True)
79
+
80
+ av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank],
81
+ find_unused_parameters=False)
82
+ av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model)
83
+ av_model.eval()
84
+
85
+ from dataloader.dataset import AV
86
+ from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
87
+ from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
88
+ from torch.utils.data import DataLoader, Subset
89
+ from torch.utils.data.distributed import DistributedSampler
90
+
91
+ visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std,
92
+ hyp_param.image_size, hyp_param.image_size,
93
+ hyp_param.scale_list, ignore_index=hyp_param.ignore_index)
94
+ audio_augmentation = AudioAugmentation(mono=True)
95
+
96
+ dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
97
+ param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.inference_data_name)
98
+
99
+ max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0
100
+ if max_batches > 0:
101
+ n_samples = min(max_batches * hyp_param.batch_size, len(dataset))
102
+ dataset = Subset(dataset, range(n_samples))
103
+
104
+ sampler = DistributedSampler(dataset, shuffle=False)
105
+ test_dataloader = DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler,
106
+ num_workers=hyp_param.num_workers)
107
+
108
+ from trainer.train import Trainer
109
+ from utils.foreground_iou import ForegroundIoU
110
+ from utils.foreground_fscore import ForegroundFScore
111
+
112
+ metrics = {
113
+ "foreground_iou": ForegroundIoU(),
114
+ "foreground_f-score": ForegroundFScore(hyp_param.local_rank),
115
+ }
116
+ trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics)
117
+
118
+ # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select
119
+ runs = [
120
+ ("", "default (logits[:,0])"),
121
+ ("iou_select", "iou_select"),
122
+ ("iou_occ_select", "iou_occ_select"),
123
+ ]
124
+ results = []
125
+ for process, label in runs:
126
+ fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process)
127
+ results.append((label, fiou, ffscore))
128
+ torch.cuda.empty_cache()
129
+
130
+ if hyp_param.local_rank <= 0:
131
+ print("\n========== inference (same three process flags as training valid) ==========")
132
+ for label, fiou, ffscore in results:
133
+ print(" {:32s} f_iou={} f_f-score={}".format(label, fiou, ffscore))
134
+ print("=======================================================\n")
135
+
136
+
137
+ if __name__ == '__main__':
138
+ parser = argparse.ArgumentParser(description='Inference: full test set + three process modes')
139
+
140
+ parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
141
+
142
+ parser.add_argument("--local_rank", type=int, default=-1,
143
+ help='multi-process training for DDP')
144
+
145
+ parser.add_argument('-g', '--gpus', default=1, type=int,
146
+ help='number of gpus per node')
147
+
148
+ parser.add_argument('--batch_size', default=1, type=int,
149
+ help='Batch size (match training if needed)')
150
+
151
+ parser.add_argument('--epochs', default=80, type=int,
152
+ help="unused")
153
+
154
+ parser.add_argument('--lr', default=1e-5, type=float,
155
+ help="unused")
156
+
157
+ parser.add_argument('--online', action="store_true",
158
+ help='unused')
159
+
160
+ parser.add_argument(
161
+ '--inference_ckpt', type=str, default=None,
162
+ help='Trained AuralSAM2 checkpoint (.pth state_dict). '
163
+ 'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). '
164
+ 'Default if unset: avs.code/training_details/.../hiera_l.pth',
165
+ )
166
+ parser.add_argument('--inference_data_name', type=str, default='v1m',
167
+ help='AVSBench subset folder label (v1s|v1m|v2); must match training test split')
168
+ parser.add_argument('--inference_max_batches', type=int, default=0,
169
+ help='0 = full test; >0 = first N batches only (debug)')
170
+
171
+ args = parser.parse_args()
172
+
173
+ from configs.config import C
174
+
175
+ args = EasyDict({**C, **vars(args)})
176
+
177
+ _repo = pathlib.Path(__file__).resolve().parent
178
+ # Repo root: .../AuralSAM2 (parent of avs.code)
179
+ _workspace = _repo.parent.parent
180
+ args.data_root_path = str(_workspace / 'AVSBench')
181
+ args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt')
182
+ args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth')
183
+ args.saved_dir = '/tmp/v1m_infer_ckpt'
184
+ pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True)
185
+ if args.inference_ckpt is None:
186
+ args.inference_ckpt = str(
187
+ _repo.parent / 'training_details' / 'v1m' / 'hiera_l' / 'hiera_l.pth'
188
+ )
189
+
190
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
191
+ os.environ['MASTER_PORT'] = '9901'
192
+
193
+ torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
avs.code/v1m.code/loss/training/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Training loss modules."""
2
+
avs.code/v1m.code/loss/training/contrastive_learning.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ContrastLoss(nn.Module, ABC):
8
+ def __init__(self, hyp_param):
9
+ super().__init__()
10
+ self.param = hyp_param
11
+ _defaults = {
12
+ "temperature": 0.10,
13
+ "ignore_idx": 255,
14
+ "ood_idx": 254,
15
+ "max_views": 512,
16
+ "proj_dim": 512,
17
+ "sample_limits": 128,
18
+ "total_limits": 15240,
19
+ }
20
+ _raw = getattr(hyp_param, "contrastive_learning", None) or {}
21
+ _cfg = {**_defaults, **_raw}
22
+ self.temperature = _cfg["temperature"]
23
+ self.ignore_idx = _cfg["ignore_idx"]
24
+ self.ood_idx = _cfg["ood_idx"]
25
+ self.max_views = _cfg["max_views"]
26
+ self.proj_dim = _cfg["proj_dim"]
27
+ self.sample_limits = _cfg["sample_limits"]
28
+ self.total_limits = _cfg["total_limits"]
29
+
30
+ def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx):
31
+ embedding_sample_list = []
32
+ label_list = []
33
+ embedding_sample_list_a = []
34
+ label_list_a = []
35
+ class_index_list = torch.unique(masks)
36
+
37
+ if len(class_index_list) > 1:
38
+ for class_index in class_index_list[1:]:
39
+ embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
40
+ label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3)
41
+ else:
42
+ embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
43
+ label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3)
44
+
45
+ sample_limits = self.sample_limits
46
+ embeddings = embeddings.permute(1, 0)
47
+ for class_index in class_index_list:
48
+ hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()]
49
+ easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()]
50
+
51
+ hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0]
52
+ selective_num_hard = min(sample_limits, hard_indices_num)
53
+ selective_num_easy = min(sample_limits, easy_indices_num)
54
+
55
+ if (selective_num_hard + selective_num_easy) < sample_limits * 2:
56
+ if selective_num_hard > selective_num_easy:
57
+ selective_num_hard += sample_limits * 2 - selective_num_easy
58
+ else:
59
+ selective_num_easy += sample_limits * 2 - selective_num_hard
60
+
61
+ hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard]
62
+ embedding_sample_list.append(hard_indices[hard_chosen_indices])
63
+ label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3)
64
+
65
+ easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy]
66
+ embedding_sample_list.append(easy_indices[easy_chosen_indices])
67
+ label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3)
68
+ return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a
69
+
70
+ def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions):
71
+ masks = masks.flatten(start_dim=1)
72
+ predictions = predictions.flatten(start_dim=1)
73
+ visual_embeddings = visual_embeddings.flatten(start_dim=-2)
74
+
75
+ visual_embedding_sample_list = []
76
+ visual_label_list = []
77
+ audio_embedding_sample_list = []
78
+ audio_label_list = []
79
+
80
+ for frame_idx in range(masks.shape[0]):
81
+ current_vision_feats = visual_embeddings[frame_idx]
82
+ current_masks = masks[frame_idx]
83
+ current_predictions = predictions[frame_idx]
84
+ current_audio_feats = audio_embeddings[frame_idx]
85
+ for layer_idx in range(3):
86
+ (
87
+ selected_vision_embeddings,
88
+ selected_vision_labels,
89
+ selected_audio_embeddings,
90
+ selected_audio_labels,
91
+ ) = self.select_class_wise_samples(
92
+ current_vision_feats[layer_idx],
93
+ current_audio_feats[layer_idx],
94
+ current_predictions,
95
+ current_masks,
96
+ 0,
97
+ )
98
+ visual_embedding_sample_list += selected_vision_embeddings
99
+ visual_label_list += selected_vision_labels
100
+ audio_embedding_sample_list += selected_audio_embeddings
101
+ audio_label_list += selected_audio_labels
102
+
103
+ if len(visual_embedding_sample_list) == 0:
104
+ return 0.0
105
+
106
+ visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze()
107
+ visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1)
108
+ audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze()
109
+ audio_label_list = torch.cat(audio_label_list).unsqueeze(1)
110
+
111
+ total_limits = self.total_limits
112
+ if visual_embedding_sample_list.shape[0] > total_limits:
113
+ rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits]
114
+ visual_embedding_sample_list = visual_embedding_sample_list[:rand_index]
115
+ visual_label_list = visual_label_list[:rand_index]
116
+ loss = self.info_nce(
117
+ visual_embedding_sample_list,
118
+ visual_label_list,
119
+ audio_embedding_sample_list,
120
+ audio_label_list,
121
+ )
122
+ return loss
123
+
124
+ def forward(self, embeddings, output_dicts, masks):
125
+ predictions = torch.cat([i["multistep_pred_masks"] for i in output_dicts])
126
+ predictions = torch.nn.functional.interpolate(
127
+ predictions,
128
+ size=(int(self.param.image_size / 16), int(self.param.image_size / 16)),
129
+ mode="bilinear",
130
+ align_corners=False,
131
+ ).squeeze()
132
+ masks = torch.nn.functional.interpolate(
133
+ masks.unsqueeze(1),
134
+ size=(int(self.param.image_size / 16), int(self.param.image_size / 16)),
135
+ mode="nearest",
136
+ ).squeeze()
137
+ visual_embeddings, audio_embeddings = embeddings
138
+ visual_embeddings = torch.cat(
139
+ [
140
+ torch.cat(
141
+ [
142
+ visual_embeddings[0][i].unsqueeze(0),
143
+ visual_embeddings[1][i].unsqueeze(0),
144
+ visual_embeddings[2][i].unsqueeze(0),
145
+ ]
146
+ ).unsqueeze(0)
147
+ for i in range(masks.shape[0])
148
+ ]
149
+ )
150
+ audio_embeddings = torch.cat(
151
+ [
152
+ torch.cat(
153
+ [
154
+ audio_embeddings[0][i].unsqueeze(0),
155
+ audio_embeddings[1][i].unsqueeze(0),
156
+ audio_embeddings[2][i].unsqueeze(0),
157
+ ]
158
+ ).unsqueeze(0)
159
+ for i in range(masks.shape[0])
160
+ ]
161
+ )
162
+ return self.forward_audio_visual(
163
+ visual_embeddings, audio_embeddings.squeeze(), masks, predictions
164
+ )
165
+
166
+ @staticmethod
167
+ def manipulate_cover_mask(a_label, current_mask):
168
+ a_label = a_label + 1
169
+ visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1))
170
+ current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0
171
+ current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0
172
+ return current_mask
173
+
174
+ def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
175
+ c_labels_ = torch.cat([a_labels_, c_labels_])
176
+ contras_ = torch.cat([anchors_, contras_])
177
+ mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()
178
+
179
+ anchor_dot_contrast = torch.div(
180
+ torch.matmul(anchors_, torch.transpose(contras_, 0, 1)),
181
+ self.temperature,
182
+ )
183
+
184
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
185
+ logits = anchor_dot_contrast - logits_max.detach()
186
+ neg_mask = 1 - mask
187
+
188
+ mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask)
189
+ mask = mask.fill_diagonal_(0.0)
190
+
191
+ neg_logits = torch.exp(logits) * neg_mask
192
+ neg_logits = neg_logits.sum(1, keepdim=True)
193
+ exp_logits = torch.exp(logits)
194
+ log_prob = logits - torch.log(exp_logits + neg_logits)
195
+
196
+ mask_pos_pairs = mask.sum(1)
197
+ mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
198
+ mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
199
+ assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any())
200
+ return -mean_log_prob_pos.mean()
201
+
avs.code/v1m.code/loss/training/sam2_training_loss.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import Dict, List
3
+
4
+ import torch
5
+ import torch.distributed
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ CORE_LOSS_KEY = "core_loss"
10
+
11
+
12
+ def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
13
+ inputs = inputs.sigmoid()
14
+ if loss_on_multimask:
15
+ assert inputs.dim() == 4 and targets.dim() == 4
16
+ inputs = inputs.flatten(2)
17
+ targets = targets.flatten(2)
18
+ numerator = 2 * (inputs * targets).sum(-1)
19
+ else:
20
+ inputs = inputs.flatten(1)
21
+ numerator = 2 * (inputs * targets).sum(1)
22
+ denominator = inputs.sum(-1) + targets.sum(-1)
23
+ loss = 1 - (numerator + 1) / (denominator + 1)
24
+ if loss_on_multimask:
25
+ return loss / num_objects
26
+ return loss.sum() / num_objects
27
+
28
+
29
+ def sigmoid_focal_loss(
30
+ inputs,
31
+ targets,
32
+ num_objects,
33
+ alpha: float = 0.25,
34
+ gamma: float = 2,
35
+ loss_on_multimask=False,
36
+ ):
37
+ prob = inputs.sigmoid()
38
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
39
+ p_t = prob * targets + (1 - prob) * (1 - targets)
40
+ loss = ce_loss * ((1 - p_t) ** gamma)
41
+
42
+ if alpha >= 0:
43
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
44
+ loss = alpha_t * loss
45
+
46
+ if loss_on_multimask:
47
+ assert loss.dim() == 4
48
+ return loss.flatten(2).mean(-1) / num_objects
49
+ return loss.mean(1).sum() / num_objects
50
+
51
+
52
+ def iou_loss(
53
+ inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
54
+ ):
55
+ assert inputs.dim() == 4 and targets.dim() == 4
56
+ pred_mask = inputs.flatten(2) > 0
57
+ gt_mask = targets.flatten(2) > 0
58
+ area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
59
+ area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
60
+ actual_ious = area_i / torch.clamp(area_u, min=1.0)
61
+
62
+ if use_l1_loss:
63
+ loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
64
+ else:
65
+ loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
66
+ if loss_on_multimask:
67
+ return loss / num_objects
68
+ return loss.sum() / num_objects
69
+
70
+
71
+ class MultiStepMultiMasksAndIous(nn.Module):
72
+ def __init__(
73
+ self,
74
+ weight_dict,
75
+ focal_alpha=0.25,
76
+ focal_gamma=2,
77
+ supervise_all_iou=False,
78
+ iou_use_l1_loss=False,
79
+ pred_obj_scores=False,
80
+ focal_gamma_obj_score=0.0,
81
+ focal_alpha_obj_score=-1,
82
+ gpu_num=1,
83
+ ):
84
+ super().__init__()
85
+ self.weight_dict = weight_dict
86
+ self.focal_alpha = focal_alpha
87
+ self.focal_gamma = focal_gamma
88
+ self.world_size = gpu_num
89
+ assert "loss_mask" in self.weight_dict
90
+ assert "loss_dice" in self.weight_dict
91
+ assert "loss_iou" in self.weight_dict
92
+ if "loss_class" not in self.weight_dict:
93
+ self.weight_dict["loss_class"] = 0.0
94
+
95
+ self.focal_alpha_obj_score = focal_alpha_obj_score
96
+ self.focal_gamma_obj_score = focal_gamma_obj_score
97
+ self.supervise_all_iou = supervise_all_iou
98
+ self.iou_use_l1_loss = iou_use_l1_loss
99
+ self.pred_obj_scores = pred_obj_scores
100
+
101
+ def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
102
+ assert len(outs_batch) == len(targets_batch)
103
+ num_objects = torch.tensor(
104
+ targets_batch.shape[1], device=targets_batch.device, dtype=torch.float
105
+ )
106
+ torch.distributed.all_reduce(num_objects)
107
+ num_objects = torch.clamp(num_objects / self.world_size, min=1).item()
108
+
109
+ losses = defaultdict(int)
110
+ for outs, targets in zip(outs_batch, targets_batch):
111
+ cur_losses = self._forward(outs, targets, num_objects)
112
+ for k, v in cur_losses.items():
113
+ losses[k] += v
114
+ return losses
115
+
116
+ def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
117
+ target_masks = targets.unsqueeze(1).float()
118
+ assert target_masks.dim() == 4
119
+
120
+ src_masks_list = outputs["multistep_pred_multimasks_high_res"]
121
+ ious_list = outputs["multistep_pred_ious"]
122
+ object_score_logits_list = outputs["multistep_object_score_logits"]
123
+ assert len(src_masks_list) == len(ious_list)
124
+ assert len(object_score_logits_list) == len(ious_list)
125
+
126
+ losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
127
+ for src_masks, ious, object_score_logits in zip(
128
+ src_masks_list, ious_list, object_score_logits_list
129
+ ):
130
+ self._update_losses(
131
+ losses, src_masks, target_masks, ious, num_objects, object_score_logits
132
+ )
133
+ losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
134
+ return losses
135
+
136
+ def _update_losses(
137
+ self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
138
+ ):
139
+ target_masks = target_masks.expand_as(src_masks)
140
+ loss_multimask = sigmoid_focal_loss(
141
+ src_masks,
142
+ target_masks,
143
+ num_objects,
144
+ alpha=self.focal_alpha,
145
+ gamma=self.focal_gamma,
146
+ loss_on_multimask=True,
147
+ )
148
+ loss_multidice = dice_loss(
149
+ src_masks, target_masks, num_objects, loss_on_multimask=True
150
+ )
151
+ if not self.pred_obj_scores:
152
+ loss_class = torch.tensor(
153
+ 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
154
+ )
155
+ target_obj = torch.ones(
156
+ loss_multimask.shape[0],
157
+ 1,
158
+ dtype=loss_multimask.dtype,
159
+ device=loss_multimask.device,
160
+ )
161
+ else:
162
+ target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
163
+ ..., None
164
+ ].float()
165
+ loss_class = sigmoid_focal_loss(
166
+ object_score_logits,
167
+ target_obj,
168
+ num_objects,
169
+ alpha=self.focal_alpha_obj_score,
170
+ gamma=self.focal_gamma_obj_score,
171
+ )
172
+
173
+ loss_multiiou = iou_loss(
174
+ src_masks,
175
+ target_masks,
176
+ ious,
177
+ num_objects,
178
+ loss_on_multimask=True,
179
+ use_l1_loss=self.iou_use_l1_loss,
180
+ )
181
+ assert loss_multimask.dim() == 2
182
+ assert loss_multidice.dim() == 2
183
+ assert loss_multiiou.dim() == 2
184
+ if loss_multimask.size(1) > 1:
185
+ loss_combo = (
186
+ loss_multimask * self.weight_dict["loss_mask"]
187
+ + loss_multidice * self.weight_dict["loss_dice"]
188
+ )
189
+ best_loss_inds = torch.argmin(loss_combo, dim=-1)
190
+ batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
191
+
192
+ loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
193
+ loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
194
+ if self.supervise_all_iou:
195
+ loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
196
+ else:
197
+ loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
198
+ else:
199
+ loss_mask = loss_multimask
200
+ loss_dice = loss_multidice
201
+ loss_iou = loss_multiiou
202
+
203
+ loss_mask = loss_mask * target_obj
204
+ loss_dice = loss_dice * target_obj
205
+ loss_iou = loss_iou * target_obj
206
+
207
+ losses["loss_mask"] += loss_mask.sum()
208
+ losses["loss_dice"] += loss_dice.sum()
209
+ losses["loss_iou"] += loss_iou.sum()
210
+ losses["loss_class"] += loss_class
211
+
212
+ def reduce_loss(self, losses):
213
+ reduced_loss = 0.0
214
+ for loss_key, weight in self.weight_dict.items():
215
+ if loss_key not in losses:
216
+ raise ValueError(f"{type(self)} doesn't compute {loss_key}")
217
+ if weight != 0:
218
+ reduced_loss += losses[loss_key] * weight
219
+ return reduced_loss
220
+
avs.code/v1m.code/main.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss."""
2
+ import os
3
+ import torch
4
+ import numpy
5
+ import random
6
+ import argparse
7
+ from easydict import EasyDict
8
+
9
+
10
+ def seed_it(seed):
11
+ """Fix RNGs and cuDNN for reproducible runs (rank offsets seed in DDP)."""
12
+ os.environ["PYTHONSEED"] = str(seed)
13
+ random.seed(seed)
14
+ numpy.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+ torch.backends.cudnn.enabled = True
19
+ torch.backends.cudnn.deterministic = True
20
+
21
+ torch.backends.cudnn.benchmark = False
22
+
23
+
24
+ def main(local_rank, ngpus_per_node, hyp_param):
25
+ hyp_param.local_rank = local_rank
26
+ # NCCL process group; world size = GPUs on this node
27
+ torch.distributed.init_process_group(
28
+ backend='nccl',
29
+ init_method='env://',
30
+ rank=hyp_param.local_rank,
31
+ world_size=hyp_param.gpus * 1
32
+ )
33
+ seed_it(local_rank + hyp_param.seed)
34
+
35
+ torch.cuda.set_device(hyp_param.local_rank)
36
+
37
+ import model.visual.sam2 # noqa: F401 — registers Hydra `configs` (initialize_config_module)
38
+
39
+ from hydra import compose
40
+ from hydra.utils import instantiate
41
+ from omegaconf import OmegaConf
42
+
43
+ # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml)
44
+ transform_config_path = 'training/sam2_training_config.yaml'
45
+
46
+ if 'hiera_t' in hyp_param.sam_config_path:
47
+ hyp_param.image_size = 224
48
+ hyp_param.image_embedding_size = int(hyp_param.image_size / 16)
49
+ print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True)
50
+
51
+ cfg = compose(config_name=transform_config_path)
52
+ OmegaConf.resolve(cfg)
53
+ hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True)
54
+
55
+ arch_h = compose(config_name='auralfuser/architecture.yaml')
56
+ OmegaConf.resolve(arch_h)
57
+ hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)
58
+
59
+ from model.mymodel import AVmodel
60
+ av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank)
61
+
62
+ av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank],
63
+ find_unused_parameters=True)
64
+
65
+ # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone)
66
+ from utils.utils import manipulate_params
67
+ parameter_list = manipulate_params(hyp_param, av_model.module.aural_fuser)
68
+ optimiser = torch.optim.AdamW(parameter_list, betas=(0.9, 0.999))
69
+
70
+ from dataloader.dataset import AV
71
+ from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
72
+ from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
73
+ from torch.utils.data.distributed import DistributedSampler
74
+
75
+ compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0]
76
+
77
+ audio_augmentation = AudioAugmentation(mono=True)
78
+ train_dataset = AV(split='train', augmentation={"visual": compose_api, "audio": audio_augmentation},
79
+ param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.data_name)
80
+
81
+
82
+ visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std,
83
+ hyp_param.image_size, hyp_param.image_size,
84
+ hyp_param.scale_list, ignore_index=hyp_param.ignore_index)
85
+
86
+ audio_augmentation = AudioAugmentation(mono=True)
87
+
88
+ random_sampler = DistributedSampler(train_dataset, shuffle=True)
89
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hyp_param.batch_size,
90
+ sampler=random_sampler,
91
+ num_workers=hyp_param.num_workers, drop_last=True)
92
+
93
+ test_dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
94
+ param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.data_name)
95
+
96
+ order_sampler = DistributedSampler(test_dataset, shuffle=False)
97
+ test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=order_sampler,
98
+ num_workers=hyp_param.num_workers)
99
+
100
+
101
+ criterion = instantiate(cfg.loss, _recursive_=True)['all']
102
+ from utils.tensorboard import Tensorboard
103
+ tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None
104
+
105
+ from trainer.train import Trainer
106
+ from utils.foreground_iou import ForegroundIoU
107
+ from utils.foreground_fscore import ForegroundFScore
108
+ metrics = {"foreground_iou": ForegroundIoU(), "foreground_f-score": ForegroundFScore(0 if hyp_param.local_rank <= 0 else hyp_param.local_rank)}
109
+
110
+ trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)
111
+
112
+
113
+ curr_best = 0. # checkpoint when IoU (iou_select mode) improves
114
+
115
+ for epoch in range(hyp_param.epochs):
116
+ av_model.train()
117
+ av_model.module.freeze_sam_parameters()
118
+ random_sampler.set_epoch(epoch)
119
+ trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser)
120
+
121
+ torch.distributed.barrier()
122
+ torch.cuda.empty_cache()
123
+
124
+ av_model.eval()
125
+ # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate
126
+ curr_results1, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index')
127
+ curr_results, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select')
128
+ curr_results3, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select')
129
+ if hyp_param.local_rank <= 0 and curr_results > curr_best:
130
+ curr_best = curr_results
131
+ torch.save(av_model.module.aural_fuser.state_dict(), os.path.join(hyp_param.saved_dir, str(curr_results) + ".pth"))
132
+ torch.distributed.barrier()
133
+ torch.cuda.empty_cache()
134
+
135
+
136
+ if __name__ == '__main__':
137
+ parser = argparse.ArgumentParser(description='PyTorch Training')
138
+ parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
139
+
140
+ parser.add_argument("--local_rank", type=int, default=-1,
141
+ help='multi-process training for DDP')
142
+
143
+ parser.add_argument('-g', '--gpus', default=1, type=int,
144
+ help='number of gpus per node')
145
+
146
+ parser.add_argument('--batch_size', default=1, type=int)
147
+
148
+ parser.add_argument('--epochs', default=80, type=int,
149
+ help="total epochs that used for the training")
150
+
151
+ parser.add_argument('--lr', default=1e-4, type=float,
152
+ help='Default HEAD Learning rate is same as others, '
153
+ '*Note: in ddp training, lr will automatically times by n_gpu')
154
+
155
+ parser.add_argument('--online', action="store_true",
156
+ help='switch on for visualization; switch off for debug')
157
+
158
+ args = parser.parse_args()
159
+
160
+ from configs.config import C
161
+
162
+ args = EasyDict({**C, **vars(args)})
163
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
164
+ os.environ['MASTER_PORT'] = '9902'
165
+
166
+ torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
avs.code/v1m.code/model/audio/torchvggish/mel_features.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Defines routines to compute mel spectrogram features from audio waveform."""
17
+
18
+ import numpy as np
19
+
20
+
21
+ def frame(data, window_length, hop_length):
22
+ """Convert array into a sequence of successive possibly overlapping frames.
23
+
24
+ An n-dimensional array of shape (num_samples, ...) is converted into an
25
+ (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26
+ starts hop_length points after the preceding one.
27
+
28
+ This is accomplished using stride_tricks, so the original data is not
29
+ copied. However, there is no zero-padding, so any incomplete frames at the
30
+ end are not included.
31
+
32
+ Args:
33
+ data: np.array of dimension N >= 1.
34
+ window_length: Number of samples in each frame.
35
+ hop_length: Advance (in samples) between each window.
36
+
37
+ Returns:
38
+ (N+1)-D np.array with as many rows as there are complete frames that can be
39
+ extracted.
40
+ """
41
+ num_samples = data.shape[0]
42
+ num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
43
+ shape = (num_frames, window_length) + data.shape[1:]
44
+ strides = (data.strides[0] * hop_length,) + data.strides
45
+ return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
46
+
47
+
48
+ def periodic_hann(window_length):
49
+ """Calculate a "periodic" Hann window.
50
+
51
+ The classic Hann window is defined as a raised cosine that starts and
52
+ ends on zero, and where every value appears twice, except the middle
53
+ point for an odd-length window. Matlab calls this a "symmetric" window
54
+ and np.hanning() returns it. However, for Fourier analysis, this
55
+ actually represents just over one cycle of a period N-1 cosine, and
56
+ thus is not compactly expressed on a length-N Fourier basis. Instead,
57
+ it's better to use a raised cosine that ends just before the final
58
+ zero value - i.e. a complete cycle of a period-N cosine. Matlab
59
+ calls this a "periodic" window. This routine calculates it.
60
+
61
+ Args:
62
+ window_length: The number of points in the returned window.
63
+
64
+ Returns:
65
+ A 1D np.array containing the periodic hann window.
66
+ """
67
+ return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
68
+ np.arange(window_length)))
69
+
70
+
71
+ def stft_magnitude(signal, fft_length,
72
+ hop_length=None,
73
+ window_length=None):
74
+ """Calculate the short-time Fourier transform magnitude.
75
+
76
+ Args:
77
+ signal: 1D np.array of the input time-domain signal.
78
+ fft_length: Size of the FFT to apply.
79
+ hop_length: Advance (in samples) between each frame passed to FFT.
80
+ window_length: Length of each block of samples to pass to FFT.
81
+
82
+ Returns:
83
+ 2D np.array where each row contains the magnitudes of the fft_length/2+1
84
+ unique values of the FFT for the corresponding frame of input samples.
85
+ """
86
+ frames = frame(signal, window_length, hop_length)
87
+ # Apply frame window to each frame. We use a periodic Hann (cosine of period
88
+ # window_length) instead of the symmetric Hann of np.hanning (period
89
+ # window_length-1).
90
+ window = periodic_hann(window_length)
91
+ windowed_frames = frames * window
92
+ return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
93
+
94
+
95
+ # Mel spectrum constants and functions.
96
+ _MEL_BREAK_FREQUENCY_HERTZ = 700.0
97
+ _MEL_HIGH_FREQUENCY_Q = 1127.0
98
+
99
+
100
+ def hertz_to_mel(frequencies_hertz):
101
+ """Convert frequencies to mel scale using HTK formula.
102
+
103
+ Args:
104
+ frequencies_hertz: Scalar or np.array of frequencies in hertz.
105
+
106
+ Returns:
107
+ Object of same size as frequencies_hertz containing corresponding values
108
+ on the mel scale.
109
+ """
110
+ return _MEL_HIGH_FREQUENCY_Q * np.log(
111
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
112
+
113
+
114
+ def spectrogram_to_mel_matrix(num_mel_bins=20,
115
+ num_spectrogram_bins=129,
116
+ audio_sample_rate=8000,
117
+ lower_edge_hertz=125.0,
118
+ upper_edge_hertz=3800.0):
119
+ """Return a matrix that can post-multiply spectrogram rows to make mel.
120
+
121
+ Returns a np.array matrix A that can be used to post-multiply a matrix S of
122
+ spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
123
+ "mel spectrogram" M of frames x num_mel_bins. M = S A.
124
+
125
+ The classic HTK algorithm exploits the complementarity of adjacent mel bands
126
+ to multiply each FFT bin by only one mel weight, then add it, with positive
127
+ and negative signs, to the two adjacent mel bands to which that bin
128
+ contributes. Here, by expressing this operation as a matrix multiply, we go
129
+ from num_fft multiplies per frame (plus around 2*num_fft adds) to around
130
+ num_fft^2 multiplies and adds. However, because these are all presumably
131
+ accomplished in a single call to np.dot(), it's not clear which approach is
132
+ faster in Python. The matrix multiplication has the attraction of being more
133
+ general and flexible, and much easier to read.
134
+
135
+ Args:
136
+ num_mel_bins: How many bands in the resulting mel spectrum. This is
137
+ the number of columns in the output matrix.
138
+ num_spectrogram_bins: How many bins there are in the source spectrogram
139
+ data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
140
+ only contains the nonredundant FFT bins.
141
+ audio_sample_rate: Samples per second of the audio at the input to the
142
+ spectrogram. We need this to figure out the actual frequencies for
143
+ each spectrogram bin, which dictates how they are mapped into mel.
144
+ lower_edge_hertz: Lower bound on the frequencies to be included in the mel
145
+ spectrum. This corresponds to the lower edge of the lowest triangular
146
+ band.
147
+ upper_edge_hertz: The desired top edge of the highest frequency band.
148
+
149
+ Returns:
150
+ An np.array with shape (num_spectrogram_bins, num_mel_bins).
151
+
152
+ Raises:
153
+ ValueError: if frequency edges are incorrectly ordered or out of range.
154
+ """
155
+ nyquist_hertz = audio_sample_rate / 2.
156
+ if lower_edge_hertz < 0.0:
157
+ raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
158
+ if lower_edge_hertz >= upper_edge_hertz:
159
+ raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
160
+ (lower_edge_hertz, upper_edge_hertz))
161
+ if upper_edge_hertz > nyquist_hertz:
162
+ raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
163
+ (upper_edge_hertz, nyquist_hertz))
164
+ spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
165
+ spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
166
+ # The i'th mel band (starting from i=1) has center frequency
167
+ # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
168
+ # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
169
+ # the band_edges_mel arrays.
170
+ band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
171
+ hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
172
+ # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
173
+ # of spectrogram values.
174
+ mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
175
+ for i in range(num_mel_bins):
176
+ lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
177
+ # Calculate lower and upper slopes for every spectrogram bin.
178
+ # Line segments are linear in the *mel* domain, not hertz.
179
+ lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
180
+ (center_mel - lower_edge_mel))
181
+ upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
182
+ (upper_edge_mel - center_mel))
183
+ # .. then intersect them with each other and zero.
184
+ mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
185
+ upper_slope))
186
+ # HTK excludes the spectrogram DC bin; make sure it always gets a zero
187
+ # coefficient.
188
+ mel_weights_matrix[0, :] = 0.0
189
+ return mel_weights_matrix
190
+
191
+
192
+ def log_mel_spectrogram(data,
193
+ audio_sample_rate=8000,
194
+ log_offset=0.0,
195
+ window_length_secs=0.025,
196
+ hop_length_secs=0.010,
197
+ **kwargs):
198
+ """Convert waveform to a log magnitude mel-frequency spectrogram.
199
+
200
+ Args:
201
+ data: 1D np.array of waveform data.
202
+ audio_sample_rate: The sampling rate of data.
203
+ log_offset: Add this to values when taking log to avoid -Infs.
204
+ window_length_secs: Duration of each window to analyze.
205
+ hop_length_secs: Advance between successive analysis windows.
206
+ **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
207
+
208
+ Returns:
209
+ 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
210
+ magnitudes for successive frames.
211
+ """
212
+ window_length_samples = int(round(audio_sample_rate * window_length_secs))
213
+ hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
214
+ fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
215
+ spectrogram = stft_magnitude(
216
+ data,
217
+ fft_length=fft_length,
218
+ hop_length=hop_length_samples,
219
+ window_length=window_length_samples)
220
+ mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
221
+ num_spectrogram_bins=spectrogram.shape[1],
222
+ audio_sample_rate=audio_sample_rate, **kwargs))
223
+ return np.log(mel_spectrogram + log_offset)
avs.code/v1m.code/model/audio/torchvggish/vggish.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import hub
5
+
6
+ from . import vggish_input, vggish_params
7
+
8
+
9
+ class VGG(nn.Module):
10
+ def __init__(self, features):
11
+ super(VGG, self).__init__()
12
+ self.features = features
13
+ self.embeddings = nn.Sequential(
14
+ nn.Linear(512 * 4 * 6, 4096),
15
+ nn.ReLU(True),
16
+ nn.Linear(4096, 4096),
17
+ nn.ReLU(True),
18
+ nn.Linear(4096, 128),
19
+ nn.ReLU(True))
20
+
21
+ def forward(self, x):
22
+ x = self.features(x)
23
+
24
+ # Transpose the output from features to
25
+ # remain compatible with vggish embeddings
26
+ x = torch.transpose(x, 1, 3)
27
+ x = torch.transpose(x, 1, 2)
28
+ x = x.contiguous()
29
+ x = x.view(x.size(0), -1)
30
+
31
+ return self.embeddings(x)
32
+
33
+
34
+ class Postprocessor(nn.Module):
35
+ """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
36
+ numpy array in order to preserve the gradient.
37
+
38
+ "The initial release of AudioSet included 128-D VGGish embeddings for each
39
+ segment of AudioSet. These released embeddings were produced by applying
40
+ a PCA transformation (technically, a whitening transform is included as well)
41
+ and 8-bit quantization to the raw embedding output from VGGish, in order to
42
+ stay compatible with the YouTube-8M project which provides visual embeddings
43
+ in the same format for a large set of YouTube videos. This class implements
44
+ the same PCA (with whitening) and quantization transformations."
45
+ """
46
+
47
+ def __init__(self):
48
+ """Constructs a postprocessor."""
49
+ super(Postprocessor, self).__init__()
50
+ # Create empty matrix, for user's state_dict to load
51
+ self.pca_eigen_vectors = torch.empty(
52
+ (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
53
+ dtype=torch.float,
54
+ )
55
+ self.pca_means = torch.empty(
56
+ (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
57
+ )
58
+
59
+ self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
60
+ self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
61
+
62
+ def postprocess(self, embeddings_batch):
63
+ """Applies tensor postprocessing to a batch of embeddings.
64
+
65
+ Args:
66
+ embeddings_batch: An tensor of shape [batch_size, embedding_size]
67
+ containing output from the embedding layer of VGGish.
68
+
69
+ Returns:
70
+ A tensor of the same shape as the input, containing the PCA-transformed,
71
+ quantized, and clipped version of the input.
72
+ """
73
+ assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
74
+ embeddings_batch.shape,
75
+ )
76
+ assert (
77
+ embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
78
+ ), "Bad batch shape: %r" % (embeddings_batch.shape,)
79
+
80
+ # Apply PCA.
81
+ # - Embeddings come in as [batch_size, embedding_size].
82
+ # - Transpose to [embedding_size, batch_size].
83
+ # - Subtract pca_means column vector from each column.
84
+ # - Premultiply by PCA matrix of shape [output_dims, input_dims]
85
+ # where both are are equal to embedding_size in our case.
86
+ # - Transpose result back to [batch_size, embedding_size].
87
+ pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t()
88
+
89
+ # Quantize by:
90
+ # - clipping to [min, max] range
91
+ clipped_embeddings = torch.clamp(
92
+ pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
93
+ )
94
+ # - convert to 8-bit in range [0.0, 255.0]
95
+ quantized_embeddings = torch.round(
96
+ (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
97
+ * (
98
+ 255.0
99
+ / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
100
+ )
101
+ )
102
+ return torch.squeeze(quantized_embeddings)
103
+
104
+ def forward(self, x):
105
+ return self.postprocess(x)
106
+
107
+
108
+ def make_layers():
109
+ layers = []
110
+ in_channels = 1
111
+ for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]:
112
+ if v == "M":
113
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
114
+ else:
115
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
116
+ layers += [conv2d, nn.ReLU(inplace=True)]
117
+ in_channels = v
118
+ return nn.Sequential(*layers)
119
+
120
+
121
+ def _vgg():
122
+ return VGG(make_layers())
123
+
124
+
125
+ # def _spectrogram():
126
+ # config = dict(
127
+ # sr=16000,
128
+ # n_fft=400,
129
+ # n_mels=64,
130
+ # hop_length=160,
131
+ # window="hann",
132
+ # center=False,
133
+ # pad_mode="reflect",
134
+ # htk=True,
135
+ # fmin=125,
136
+ # fmax=7500,
137
+ # output_format='Magnitude',
138
+ # # device=device,
139
+ # )
140
+ # return Spectrogram.MelSpectrogram(**config)
141
+
142
+
143
+ class VGGish(VGG):
144
+ def __init__(self, cfg, device=None):
145
+ super().__init__(make_layers())
146
+ if cfg.FREEZE_AUDIO_EXTRACTOR:
147
+ state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH)
148
+ super().load_state_dict(state_dict)
149
+ print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}')
150
+
151
+ if device is None:
152
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
153
+ print("device: ", device)
154
+ self.device = device
155
+
156
+ self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL
157
+ self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA
158
+ if self.postprocess:
159
+ self.pproc = Postprocessor()
160
+ if cfg.FREEZE_AUDIO_EXTRACTOR:
161
+ state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH)
162
+ # TODO: Convert the state_dict to torch
163
+ state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
164
+ state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
165
+ )
166
+ state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
167
+ state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
168
+ )
169
+ self.pproc.load_state_dict(state_dict)
170
+ self.to(self.device)
171
+
172
+ def forward(self, x):
173
+ if self.preprocess:
174
+ print(">>> pre processing...")
175
+ x = self._preprocess(x)
176
+ x = x.to(self.device)
177
+ x = VGG.forward(self, x)
178
+ if self.postprocess:
179
+ print(">>> post processing...")
180
+ x = self._postprocess(x)
181
+ return x
182
+
183
+ def _preprocess(self, x):
184
+ # if isinstance(x, np.ndarray):
185
+ # x = vggish_input.waveform_to_examples(x, fs)
186
+ if isinstance(x, str):
187
+ x = vggish_input.wavfile_to_examples(x)
188
+ else:
189
+ raise AttributeError
190
+ return x
191
+
192
+ def _postprocess(self, x):
193
+ return self.pproc(x)
avs.code/v1m.code/model/audio/torchvggish/vggish_input.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Compute input examples for VGGish from audio waveform."""
17
+
18
+ # Modification: Return torch tensors rather than numpy arrays
19
+ import torch
20
+
21
+ import numpy as np
22
+ import resampy
23
+
24
+ from . import mel_features
25
+ from . import vggish_params
26
+
27
+ import soundfile as sf
28
+
29
+
30
+ def waveform_to_examples(data, sample_rate, return_tensor=True):
31
+ """Converts audio waveform into an array of examples for VGGish.
32
+
33
+ Args:
34
+ data: np.array of either one dimension (mono) or two dimensions
35
+ (multi-channel, with the outer dimension representing channels).
36
+ Each sample is generally expected to lie in the range [-1.0, +1.0],
37
+ although this is not required.
38
+ sample_rate: Sample rate of data.
39
+ return_tensor: Return data as a Pytorch tensor ready for VGGish
40
+
41
+ Returns:
42
+ 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
43
+ a sequence of examples, each of which contains a patch of log mel
44
+ spectrogram, covering num_frames frames of audio and num_bands mel frequency
45
+ bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
46
+
47
+ """
48
+ # Convert to mono.
49
+ if len(data.shape) > 1:
50
+ data = np.mean(data, axis=1)
51
+ # Resample to the rate assumed by VGGish.
52
+ if sample_rate != vggish_params.SAMPLE_RATE:
53
+ data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
54
+
55
+ # Compute log mel spectrogram features.
56
+ log_mel = mel_features.log_mel_spectrogram(
57
+ data,
58
+ audio_sample_rate=vggish_params.SAMPLE_RATE,
59
+ log_offset=vggish_params.LOG_OFFSET,
60
+ window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
61
+ hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
62
+ num_mel_bins=vggish_params.NUM_MEL_BINS,
63
+ lower_edge_hertz=vggish_params.MEL_MIN_HZ,
64
+ upper_edge_hertz=vggish_params.MEL_MAX_HZ)
65
+
66
+ # Frame features into examples.
67
+ features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
68
+ example_window_length = int(round(
69
+ vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
70
+ example_hop_length = int(round(
71
+ vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
72
+ log_mel_examples = mel_features.frame(
73
+ log_mel,
74
+ window_length=example_window_length,
75
+ hop_length=example_hop_length)
76
+
77
+ if return_tensor:
78
+ log_mel_examples = torch.tensor(
79
+ log_mel_examples, requires_grad=True)[:, None, :, :].float()
80
+
81
+ return log_mel_examples
82
+
83
+
84
+ def wavfile_to_examples(wav_file, return_tensor=True):
85
+ """Convenience wrapper around waveform_to_examples() for a common WAV format.
86
+
87
+ Args:
88
+ wav_file: String path to a file, or a file-like object. The file
89
+ is assumed to contain WAV audio data with signed 16-bit PCM samples.
90
+ torch: Return data as a Pytorch tensor ready for VGGish
91
+
92
+ Returns:
93
+ See waveform_to_examples.
94
+ """
95
+ wav_data, sr = sf.read(wav_file, dtype='int16')
96
+ assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
97
+ samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
98
+ return waveform_to_examples(samples, sr, return_tensor)
avs.code/v1m.code/model/audio/torchvggish/vggish_params.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Global parameters for the VGGish model.
17
+
18
+ See vggish_slim.py for more information.
19
+ """
20
+
21
+ # Architectural constants.
22
+ NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23
+ NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24
+ EMBEDDING_SIZE = 128 # Size of embedding layer.
25
+
26
+ # Hyperparameters used in feature and example generation.
27
+ SAMPLE_RATE = 16000
28
+ STFT_WINDOW_LENGTH_SECONDS = 0.025
29
+ STFT_HOP_LENGTH_SECONDS = 0.010
30
+ NUM_MEL_BINS = NUM_BANDS
31
+ MEL_MIN_HZ = 125
32
+ MEL_MAX_HZ = 7500
33
+ LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34
+ EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35
+ EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36
+
37
+ # Parameters used for embedding postprocessing.
38
+ PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39
+ PCA_MEANS_NAME = 'pca_means'
40
+ QUANTIZE_MIN_VAL = -2.0
41
+ QUANTIZE_MAX_VAL = +2.0
42
+
43
+ # Hyperparameters used in training.
44
+ INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45
+ LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46
+ ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47
+
48
+ # Names of ops, tensors, and features.
49
+ INPUT_OP_NAME = 'vggish/input_features'
50
+ INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51
+ OUTPUT_OP_NAME = 'vggish/embedding'
52
+ OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53
+ AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
avs.code/v1m.code/model/aural_fuser.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from model.audio.torchvggish import vggish
6
+ from timm.models.layers import DropPath, trunc_normal_
7
+
8
+ from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine
9
+
10
+
11
+ class ProjectionHead(nn.Module):
12
+ def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d):
13
+ super().__init__()
14
+ self.proj = nn.Sequential(
15
+ conv_layer(dim_in, proj_dim, kernel_size=1),
16
+ norm_act(proj_dim),
17
+ conv_layer(proj_dim, proj_dim, kernel_size=1),
18
+ )
19
+
20
+ def forward(self, x):
21
+ return torch.nn.functional.normalize(self.proj(x), p=2, dim=1)
22
+
23
+ class AuralFuser(torch.nn.Module):
24
+ """Fuses VGGish audio with SAM2 FPN maps via patch embeds, fusion blocks, and projection heads."""
25
+
26
+ def __init__(self, hyp_param):
27
+ self.hyp_param = hyp_param
28
+ super().__init__()
29
+ self.vgg = vggish.VGGish(self.hyp_param.audio)
30
+ if not getattr(self.hyp_param, "train_vggish", False):
31
+ for p in self.vgg.parameters():
32
+ p.requires_grad = False
33
+
34
+ self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None,
35
+ temperature=10000)
36
+
37
+ # Populated in main.py / inference.py via Hydra compose('auralfuser/architecture.yaml') → hyp_param.aural_fuser
38
+ if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None:
39
+ raise ValueError(
40
+ "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser."
41
+ )
42
+ arch_cfg = self.hyp_param.aural_fuser
43
+
44
+ _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]]
45
+ _f_depths = arch_cfg["f_depths"]
46
+ _block_kw = dict(arch_cfg["block_kw"])
47
+ _block_kw["norm_layer"] = nn.LayerNorm
48
+ _one_d_kw = dict(arch_cfg["one_d_kw"])
49
+ _one_d_kw["norm_layer"] = nn.LayerNorm
50
+ self.patch_embeds = nn.ModuleList(
51
+ nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs
52
+ )
53
+
54
+ self.f_blocks = nn.ModuleList(
55
+ nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths
56
+ )
57
+
58
+ self.a_blocks = nn.ModuleList(
59
+ nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3)
60
+ )
61
+
62
+ self.fusion_modules = nn.ModuleList(
63
+ AudioVisualFusionModule(in_channels=256, mode='dot') for _ in range(3)
64
+ )
65
+ self.smooth_convs = nn.ModuleList(
66
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2)
67
+ )
68
+
69
+ self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128)
70
+
71
+ self.train_proj_a1 = ProjectionHead(dim_in=256, norm_act=nn.BatchNorm1d, conv_layer=nn.Conv1d, proj_dim=128)
72
+
73
+ @staticmethod
74
+ def positionalencoding1d(d_model, length):
75
+ if d_model % 2 != 0:
76
+ raise ValueError("Cannot use sin/cos positional encoding with "
77
+ "odd dim (got dim={:d})".format(d_model))
78
+ pe = torch.zeros(length, d_model)
79
+ position = torch.arange(0, length).unsqueeze(1)
80
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
81
+ -(math.log(10000.0) / d_model)))
82
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
83
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
84
+
85
+ return pe
86
+
87
+ def forward(self, feature_dicts, spect=None):
88
+ image_embed_shape = [self.hyp_param.image_embedding_size] * 2
89
+ H, W = image_embed_shape[0], image_embed_shape[1]
90
+ d = torch.cat(
91
+ [
92
+ self.vgg(spect[:, 0, ...].unsqueeze(1)),
93
+ self.vgg(spect[:, 1, ...].unsqueeze(1)),
94
+ ],
95
+ dim=-1,
96
+ )
97
+ length = d.shape[-1]
98
+ fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device)
99
+ fpn = list(feature_dicts["backbone_fpn"])
100
+ patch_embeds = list(self.patch_embeds)
101
+ f_blocks = list(self.f_blocks)
102
+ a_blocks = list(self.a_blocks)
103
+ tpavi = list(self.fusion_modules)
104
+ smooths = [None, self.smooth_convs[0], self.smooth_convs[1]]
105
+
106
+ feats = [None, None, None]
107
+ d_outputs = []
108
+
109
+ for i in range(3):
110
+ x = fpn[i]
111
+ x = patch_embeds[i](x)
112
+ x_pos = self.position_encoding_func(x)
113
+ x = x.flatten(2).permute(0, 2, 1)
114
+ x_pos = x_pos.flatten(2).permute(0, 2, 1)
115
+
116
+ if i == 0:
117
+ x = x + x_pos
118
+ d = d + fix_audio_pos
119
+ else:
120
+ x = x + feats[i - 1]
121
+ x = smooths[i](
122
+ x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W)
123
+ ).flatten(2).permute(0, 2, 1)
124
+ x = x + x_pos
125
+ d = d + fix_audio_pos
126
+
127
+ for blks in f_blocks[i]:
128
+ x = blks(x, H, W, x_pos)
129
+ for blks in a_blocks[i]:
130
+ d = blks(d, fix_audio_pos)
131
+
132
+ x = x + x_pos
133
+ d = d + fix_audio_pos
134
+ x, d_out, _, _ = tpavi[i](x, H, W, x_pos, d, length)
135
+ d = d_out
136
+ feats[i] = x
137
+ d_outputs.append(d_out)
138
+
139
+ a, b, c = feats
140
+ d1, d2, d3 = d_outputs
141
+
142
+ feature_residual = [a, b, c]
143
+ audio_out = [d1, d2, d3]
144
+
145
+ proj_feature_out = [
146
+ [
147
+ self.train_proj_v1(a.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)),
148
+ self.train_proj_v1(b.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)),
149
+ self.train_proj_v1(c.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape)),
150
+ ],
151
+ [
152
+ self.train_proj_a1(d1.unsqueeze(-1)),
153
+ self.train_proj_a1(d2.unsqueeze(-1)),
154
+ self.train_proj_a1(d3.unsqueeze(-1)),
155
+ ],
156
+ ]
157
+
158
+ return feature_residual, audio_out, proj_feature_out
159
+
160
+
161
+ class AudioVisualFusionModule(nn.Module):
162
+ def __init__(self, in_channels, inter_channels=None, mode='dot',
163
+ dimension=3):
164
+ super().__init__()
165
+ assert mode == 'dot'
166
+ self.mode = mode
167
+ self.dimension = dimension
168
+
169
+ self.in_channels = in_channels
170
+ self.inter_channels = in_channels // 2
171
+
172
+ self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1)
173
+ self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1)
174
+
175
+ self.norm_layer = nn.LayerNorm(in_channels)
176
+
177
+ if dimension == 3:
178
+ conv_nd = nn.Conv3d
179
+ bn = nn.BatchNorm3d
180
+ elif dimension == 2:
181
+ conv_nd = nn.Conv2d
182
+ bn = nn.BatchNorm2d
183
+ else:
184
+ conv_nd = nn.Conv1d
185
+ bn = nn.BatchNorm1d
186
+
187
+ self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
188
+
189
+ self.W_z = nn.Sequential(
190
+ conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
191
+ bn(self.in_channels)
192
+ )
193
+ nn.init.constant_(self.W_z[1].weight, 0)
194
+ nn.init.constant_(self.W_z[1].bias, 0)
195
+
196
+ self.W_z2 = nn.Sequential(
197
+ nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
198
+ nn.BatchNorm1d(self.in_channels)
199
+ )
200
+ nn.init.constant_(self.W_z2[1].weight, 0)
201
+ nn.init.constant_(self.W_z2[1].bias, 0)
202
+ self.norm_layer2 = nn.LayerNorm(self.in_channels)
203
+
204
+ self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
205
+ self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
206
+ self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
207
+
208
+ self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)
209
+ self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)
210
+ self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)
211
+
212
+ def forward(self, frame, H_x, W_x, tmp1, audio, tmp2):
213
+ frame = frame.permute(0, 2, 1)
214
+ frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x)
215
+ frame = frame.unsqueeze(2)
216
+ audio = self.align_channel(audio.unsqueeze(-1))
217
+
218
+ batch_size, _ = frame.size(0), frame.size(1)
219
+ q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels)
220
+ k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels)
221
+ v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels)
222
+ q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels)
223
+ k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels)
224
+ v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels)
225
+ f = torch.matmul(q_frame, k_audio.mT)
226
+ f_normalise = f / f.size(1)
227
+
228
+ frame_attn = torch.matmul(f_normalise, v_audio)
229
+
230
+ frame_attn = frame_attn.permute(0, 2, 1).contiguous()
231
+ frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:])
232
+ frame_attn = self.W_z(frame_attn)
233
+ frame = frame_attn + frame
234
+
235
+ frame = frame.permute(0, 2, 3, 4, 1)
236
+ frame = self.norm_layer(frame)
237
+ frame = frame.permute(0, 4, 1, 2, 3)
238
+ frame = frame.squeeze().flatten(start_dim=2).permute(0, 2, 1)
239
+
240
+ a = torch.matmul(q_audio, k_frame.mT)
241
+ a_normalise = a / a.size(-1)
242
+
243
+ audio_attn = torch.matmul(a_normalise, v_frame)
244
+ audio_attn = audio_attn.permute(0, 2, 1).contiguous()
245
+
246
+ audio_attn = audio_attn.view(batch_size, self.inter_channels).unsqueeze(-1)
247
+ audio_attn = self.W_z2(audio_attn)
248
+
249
+ audio = audio_attn + audio
250
+
251
+ audio = self.norm_layer2(audio.squeeze()).squeeze()
252
+
253
+ return frame, audio, frame_attn, audio_attn
254
+
255
+
256
+ class OneDBlock(nn.Module):
257
+
258
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
259
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
260
+ super().__init__()
261
+ self.norm1 = norm_layer(dim)
262
+ self.attn = OneDAttention(
263
+ dim,
264
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
265
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
266
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
267
+ self.norm2 = norm_layer(dim)
268
+ mlp_hidden_dim = int(dim * mlp_ratio)
269
+ self.mlp = OneDMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
270
+ linear=linear)
271
+
272
+ self.apply(self._init_weights)
273
+
274
+ def _init_weights(self, m):
275
+ if isinstance(m, nn.Linear):
276
+ trunc_normal_(m.weight, std=.02)
277
+ if isinstance(m, nn.Linear) and m.bias is not None:
278
+ nn.init.constant_(m.bias, 0)
279
+ elif isinstance(m, nn.LayerNorm):
280
+ nn.init.constant_(m.bias, 0)
281
+ nn.init.constant_(m.weight, 1.0)
282
+ elif isinstance(m, nn.Conv2d):
283
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
284
+ fan_out //= m.groups
285
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
286
+ if m.bias is not None:
287
+ m.bias.data.zero_()
288
+
289
+ def forward(self, x, _pos):
290
+ x = x + self.drop_path(self.attn(self.norm1(x)))
291
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
292
+
293
+ return x
294
+
295
+
296
+ class OneDAttention(nn.Module):
297
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
298
+ linear=False):
299
+ super().__init__()
300
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
301
+
302
+ self.dim = dim
303
+ self.num_heads = num_heads
304
+ head_dim = dim // num_heads
305
+ self.scale = qk_scale or head_dim ** -0.5
306
+
307
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
308
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
309
+ self.attn_drop = nn.Dropout(attn_drop)
310
+ self.proj = nn.Linear(dim, dim)
311
+ self.proj_drop = nn.Dropout(proj_drop)
312
+
313
+ self.linear = linear
314
+ self.sr_ratio = sr_ratio
315
+ if not linear:
316
+ if sr_ratio > 1:
317
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
318
+ self.norm = nn.LayerNorm(dim)
319
+ else:
320
+ self.pool = nn.AdaptiveAvgPool2d(7)
321
+ self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
322
+ self.norm = nn.LayerNorm(dim)
323
+ self.act = nn.GELU()
324
+ self.apply(self._init_weights)
325
+
326
+ def _init_weights(self, m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+ elif isinstance(m, nn.Conv2d):
335
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
336
+ fan_out //= m.groups
337
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
338
+ if m.bias is not None:
339
+ m.bias.data.zero_()
340
+
341
+ def forward(self, x):
342
+ x = x.unsqueeze(0)
343
+
344
+ B, N, C = x.shape
345
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
346
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
347
+
348
+ k, v = kv[0], kv[1]
349
+ attn = (q @ k.transpose(-2, -1)) * self.scale
350
+ attn = attn.softmax(dim=-1)
351
+ attn = self.attn_drop(attn)
352
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
353
+ x = self.proj(x)
354
+ x = self.proj_drop(x)
355
+
356
+ x = x.squeeze()
357
+ return x
358
+
359
+
360
+ class OneDMlp(nn.Module):
361
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
362
+ super().__init__()
363
+ out_features = out_features or in_features
364
+ hidden_features = hidden_features or in_features
365
+ self.fc1 = nn.Linear(in_features, hidden_features)
366
+ self.dwconv = DWConv(hidden_features)
367
+ self.act = act_layer()
368
+ self.fc2 = nn.Linear(hidden_features, out_features)
369
+ self.drop = nn.Dropout(drop)
370
+ self.linear = linear
371
+
372
+ if self.linear:
373
+ self.relu = nn.ReLU(inplace=True)
374
+ self.apply(self._init_weights)
375
+
376
+ def _init_weights(self, m):
377
+ if isinstance(m, nn.Linear):
378
+ trunc_normal_(m.weight, std=.02)
379
+ if isinstance(m, nn.Linear) and m.bias is not None:
380
+ nn.init.constant_(m.bias, 0)
381
+ elif isinstance(m, nn.LayerNorm):
382
+ nn.init.constant_(m.bias, 0)
383
+ nn.init.constant_(m.weight, 1.0)
384
+ elif isinstance(m, nn.Conv2d):
385
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
386
+ fan_out //= m.groups
387
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
388
+ if m.bias is not None:
389
+ m.bias.data.zero_()
390
+
391
+ def forward(self, x):
392
+ x = self.fc1(x)
393
+ if self.linear:
394
+ x = self.relu(x)
395
+ x = self.act(x)
396
+ x = self.drop(x)
397
+ x = self.fc2(x)
398
+ x = self.drop(x)
399
+ return x
400
+
401
+
402
+ class Block(nn.Module):
403
+
404
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
405
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
406
+ super().__init__()
407
+ self.norm1 = norm_layer(dim)
408
+ self.attn = Attention(
409
+ dim,
410
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
411
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
412
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
413
+ self.norm2 = norm_layer(dim)
414
+ mlp_hidden_dim = int(dim * mlp_ratio)
415
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)
416
+
417
+ self.apply(self._init_weights)
418
+
419
+ def _init_weights(self, m):
420
+ if isinstance(m, nn.Linear):
421
+ trunc_normal_(m.weight, std=.02)
422
+ if isinstance(m, nn.Linear) and m.bias is not None:
423
+ nn.init.constant_(m.bias, 0)
424
+ elif isinstance(m, nn.LayerNorm):
425
+ nn.init.constant_(m.bias, 0)
426
+ nn.init.constant_(m.weight, 1.0)
427
+ elif isinstance(m, nn.Conv2d):
428
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
429
+ fan_out //= m.groups
430
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
431
+ if m.bias is not None:
432
+ m.bias.data.zero_()
433
+
434
+ def forward(self, x, H, W, _pos):
435
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
436
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
437
+
438
+ return x
439
+
440
+
441
+ class Attention(nn.Module):
442
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
443
+ linear=False):
444
+ super().__init__()
445
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
446
+
447
+ self.dim = dim
448
+ self.num_heads = num_heads
449
+ head_dim = dim // num_heads
450
+ self.scale = qk_scale or head_dim ** -0.5
451
+
452
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
453
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
454
+ self.attn_drop = nn.Dropout(attn_drop)
455
+ self.proj = nn.Linear(dim, dim)
456
+ self.proj_drop = nn.Dropout(proj_drop)
457
+
458
+ self.linear = linear
459
+ self.sr_ratio = sr_ratio
460
+ if not linear:
461
+ if sr_ratio > 1:
462
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
463
+ self.norm = nn.LayerNorm(dim)
464
+ else:
465
+ self.pool = nn.AdaptiveAvgPool2d(7)
466
+ self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
467
+ self.norm = nn.LayerNorm(dim)
468
+ self.act = nn.GELU()
469
+ self.apply(self._init_weights)
470
+
471
+ def _init_weights(self, m):
472
+ if isinstance(m, nn.Linear):
473
+ trunc_normal_(m.weight, std=.02)
474
+ if isinstance(m, nn.Linear) and m.bias is not None:
475
+ nn.init.constant_(m.bias, 0)
476
+ elif isinstance(m, nn.LayerNorm):
477
+ nn.init.constant_(m.bias, 0)
478
+ nn.init.constant_(m.weight, 1.0)
479
+ elif isinstance(m, nn.Conv2d):
480
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
481
+ fan_out //= m.groups
482
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
483
+ if m.bias is not None:
484
+ m.bias.data.zero_()
485
+
486
+ def forward(self, x, H, W):
487
+ B, N, C = x.shape
488
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
489
+ if not self.linear:
490
+ if self.sr_ratio > 1:
491
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
492
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
493
+ x_ = self.norm(x_)
494
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
495
+ else:
496
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
497
+ else:
498
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
499
+ x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
500
+ x_ = self.norm(x_)
501
+ x_ = self.act(x_)
502
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
503
+ k, v = kv[0], kv[1]
504
+ attn = (q @ k.transpose(-2, -1)) * self.scale
505
+ attn = attn.softmax(dim=-1)
506
+ attn = self.attn_drop(attn)
507
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
508
+ x = self.proj(x)
509
+ x = self.proj_drop(x)
510
+
511
+ return x
512
+
513
+
514
+ class Mlp(nn.Module):
515
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
516
+ super().__init__()
517
+ out_features = out_features or in_features
518
+ hidden_features = hidden_features or in_features
519
+ self.fc1 = nn.Linear(in_features, hidden_features)
520
+ self.dwconv = DWConv(hidden_features)
521
+ self.act = act_layer()
522
+ self.fc2 = nn.Linear(hidden_features, out_features)
523
+ self.drop = nn.Dropout(drop)
524
+ self.linear = linear
525
+
526
+ if self.linear:
527
+ self.relu = nn.ReLU(inplace=True)
528
+ self.apply(self._init_weights)
529
+
530
+ def _init_weights(self, m):
531
+ if isinstance(m, nn.Linear):
532
+ trunc_normal_(m.weight, std=.02)
533
+ if isinstance(m, nn.Linear) and m.bias is not None:
534
+ nn.init.constant_(m.bias, 0)
535
+ elif isinstance(m, nn.LayerNorm):
536
+ nn.init.constant_(m.bias, 0)
537
+ nn.init.constant_(m.weight, 1.0)
538
+ elif isinstance(m, nn.Conv2d):
539
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
540
+ fan_out //= m.groups
541
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
542
+ if m.bias is not None:
543
+ m.bias.data.zero_()
544
+
545
+ def forward(self, x, H, W):
546
+ x = self.fc1(x)
547
+ if self.linear:
548
+ x = self.relu(x)
549
+ x = self.dwconv(x, H, W)
550
+ x = self.act(x)
551
+ x = self.drop(x)
552
+ x = self.fc2(x)
553
+ x = self.drop(x)
554
+ return x
555
+
556
+
557
+ class DWConv(nn.Module):
558
+ def __init__(self, dim=768):
559
+ super(DWConv, self).__init__()
560
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
561
+
562
+ def forward(self, x, H, W):
563
+ B, N, C = x.shape
564
+ x = x.transpose(1, 2).view(B, C, H, W)
565
+ x = self.dwconv(x)
566
+ x = x.flatten(2).transpose(1, 2)
567
+ return x
avs.code/v1m.code/model/mymodel.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ from PIL.Image import Image
9
+
10
+ from model.visual.sam2.modeling.sam2_base import SAM2Base
11
+
12
+ from model.visual.sam2.modeling.backbones.hieradet import Hiera
13
+ from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck
14
+ from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder
15
+ from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine
16
+
17
+ from model.visual.sam2.modeling.memory_attention import MemoryAttention
18
+ from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer
19
+ from model.visual.sam2.modeling.sam.transformer import RoPEAttention
20
+ from model.visual.sam2.modeling.memory_encoder import MemoryEncoder
21
+ from model.visual.sam2.modeling.memory_encoder import MaskDownSampler
22
+ from model.visual.sam2.modeling.memory_encoder import Fuser
23
+ from model.visual.sam2.modeling.memory_encoder import CXBlock
24
+
25
+ from model.visual.sam2.utils.transforms import SAM2Transforms
26
+ from model.visual.sam2.modeling.backbones.hieradet import do_pool
27
+ from model.visual.sam2.modeling.backbones.utils import (
28
+ PatchEmbed,
29
+ window_partition,
30
+ window_unpartition,
31
+ )
32
+
33
+
34
+ class AVmodel(torch.nn.Module):
35
+ """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head."""
36
+
37
+ def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, ):
38
+ super().__init__()
39
+ self.param = param
40
+ self.mask_threshold = mask_threshold
41
+ self._bb_feat_sizes = [(int(self.param.image_size / 4), int(self.param.image_size / 4)),
42
+ (int(self.param.image_size / 8), int(self.param.image_size / 8)),
43
+ (int(self.param.image_size / 16), int(self.param.image_size / 16))]
44
+
45
+ from model.visual.sam2.build_sam import build_sam2_visual_predictor
46
+ self.v_model = build_sam2_visual_predictor(self.param.sam_config_path, self.param.backbone_weight,
47
+ apply_postprocessing=True, mode='train')
48
+ self._transforms = SAM2Transforms(
49
+ resolution=self.v_model.image_size,
50
+ mask_threshold=mask_threshold,
51
+ max_hole_area=max_hole_area,
52
+ max_sprinkle_area=max_sprinkle_area,
53
+ )
54
+ from model.aural_fuser import AuralFuser
55
+ self.aural_fuser = AuralFuser(hyp_param=self.param)
56
+
57
+
58
+
59
+ def _prepare_backbone_features(self, backbone_out):
60
+ """Prepare and flatten visual features."""
61
+ backbone_out = backbone_out.copy()
62
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
63
+ assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels
64
+
65
+ feature_maps = backbone_out["backbone_fpn"][-self.v_model.num_feature_levels:]
66
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.v_model.num_feature_levels:]
67
+
68
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
69
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
70
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
71
+
72
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
73
+
74
+ def forward_frame(self, frame_):
75
+ frame = torch.nn.functional.interpolate(frame_, (self.param.image_size, self.param.image_size),
76
+ antialias=True, align_corners=False, mode='bilinear')
77
+ return self.v_model.image_encoder(frame)
78
+
79
+ def forward(self, frames, spect, prompts, sam_process=False):
80
+ """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path."""
81
+ backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
82
+ audio_residual_feats = self.aural_fuser(backbone_feats, spect)
83
+ visual_resfeats, audio_resfeats, proj_feats = audio_residual_feats
84
+
85
+ map_res = visual_resfeats[::-1]
86
+ vec_res = audio_resfeats[::-1]
87
+
88
+ av_feats = (map_res, vec_res)
89
+ backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
90
+ backbone_feats = self.v_model.dont_prepare_prompt_inputs(backbone_feats, num_frames=frames.shape[0],
91
+ cond_frame=int(frames.shape[0]/2) if self.training else 0)
92
+ outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
93
+ return outputs, proj_feats
94
+
95
+ @property
96
+ def device(self) -> torch.device:
97
+ return self.v_model.device
98
+
99
+ def freeze_sam_parameters(self):
100
+ self.v_model.eval()
101
+ for name, parameter in self.v_model.named_parameters():
102
+ parameter.requires_grad = False
avs.code/v1m.code/model/visual/sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+ from hydra.core.global_hydra import GlobalHydra
9
+
10
+ if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("configs", version_base="1.2")
avs.code/v1m.code/model/visual/sam2/build_sam.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+ from hydra import compose
12
+ from hydra.utils import instantiate
13
+ from omegaconf import OmegaConf
14
+ '''
15
+ import sam2
16
+
17
+ # Check if the user is running Python from the parent directory of the sam2 repo
18
+ # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
+ # it could shadow the sam2 package and cause issues.
20
+ if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
+ # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
22
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
+ # This typically happens because the user is running Python from the parent directory
24
+ # that contains the sam2 repo they cloned.
25
+ raise RuntimeError(
26
+ "You're likely running Python from the parent directory of the sam2 repository "
27
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
+ "This is not supported since the `sam2` Python package could be shadowed by the "
29
+ "repository name (the repository is also named `sam2` and contains the Python package "
30
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
32
+ )
33
+ '''
34
+
35
+ HF_MODEL_ID_TO_FILENAMES = {
36
+ "facebook/sam2-hiera-tiny": (
37
+ "sam2/sam2_hiera_t.yaml",
38
+ "sam2_hiera_tiny.pt",
39
+ ),
40
+ "facebook/sam2-hiera-small": (
41
+ "sam2/sam2_hiera_s.yaml",
42
+ "sam2_hiera_small.pt",
43
+ ),
44
+ "facebook/sam2-hiera-base-plus": (
45
+ "sam2/sam2_hiera_b+.yaml",
46
+ "sam2_hiera_base_plus.pt",
47
+ ),
48
+ "facebook/sam2-hiera-large": (
49
+ "sam2/sam2_hiera_l.yaml",
50
+ "sam2_hiera_large.pt",
51
+ ),
52
+ "facebook/sam2.1-hiera-tiny": (
53
+ "sam2.1/sam2.1_hiera_t.yaml",
54
+ "sam2.1_hiera_tiny.pt",
55
+ ),
56
+ "facebook/sam2.1-hiera-small": (
57
+ "sam2.1/sam2.1_hiera_s.yaml",
58
+ "sam2.1_hiera_small.pt",
59
+ ),
60
+ "facebook/sam2.1-hiera-base-plus": (
61
+ "sam2.1/sam2.1_hiera_b+.yaml",
62
+ "sam2.1_hiera_base_plus.pt",
63
+ ),
64
+ "facebook/sam2.1-hiera-large": (
65
+ "sam2.1/sam2.1_hiera_l.yaml",
66
+ "sam2.1_hiera_large.pt",
67
+ ),
68
+ }
69
+
70
+
71
+ def build_sam2(
72
+ config_file,
73
+ ckpt_path=None,
74
+ device="cuda",
75
+ mode="eval",
76
+ hydra_overrides_extra=[],
77
+ apply_postprocessing=True,
78
+ **kwargs,
79
+ ):
80
+
81
+ if apply_postprocessing:
82
+ hydra_overrides_extra = hydra_overrides_extra.copy()
83
+ hydra_overrides_extra += [
84
+ # dynamically fall back to multi-mask if the single mask is not stable
85
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
+ ]
89
+ # Read config and init model
90
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
+ OmegaConf.resolve(cfg)
92
+ model = instantiate(cfg.model, _recursive_=True)
93
+ _load_checkpoint(model, ckpt_path)
94
+ model = model.to(device)
95
+ if mode == "eval":
96
+ model.eval()
97
+ return model
98
+
99
+
100
+ def build_sam2_visual_predictor(
101
+ config_file,
102
+ ckpt_path=None,
103
+ mode="eval",
104
+ hydra_overrides_extra=[],
105
+ apply_postprocessing=True,
106
+ **kwargs,
107
+ ):
108
+ # visual
109
+ hydra_overrides = []
110
+ # "++model._target_=model.visual.sam2.organised_sam2_train.SAM2Train",
111
+ # ]
112
+ # hydra_overrides = [
113
+ # "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
114
+ # ]
115
+ if apply_postprocessing:
116
+ hydra_overrides_extra = hydra_overrides_extra.copy()
117
+ hydra_overrides_extra += [
118
+
119
+ # dynamically fall back to multi-mask if the single mask is not stable
120
+ # "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
121
+ # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
122
+ # "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
123
+
124
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
125
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
126
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
127
+ # "++model.fill_hole_area=8",
128
+ ]
129
+ hydra_overrides.extend(hydra_overrides_extra)
130
+
131
+ # Read config and init model
132
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
133
+ OmegaConf.resolve(cfg)
134
+ model = instantiate(cfg.model, _recursive_=True)
135
+ _load_checkpoint(model, ckpt_path)
136
+ if mode == "eval":
137
+ model.eval()
138
+ return model
139
+
140
+
141
+ def _hf_download(model_id):
142
+ from huggingface_hub import hf_hub_download
143
+
144
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
145
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
146
+ return config_name, ckpt_path
147
+
148
+
149
+ def build_sam2_hf(model_id, **kwargs):
150
+ config_name, ckpt_path = _hf_download(model_id)
151
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
152
+
153
+
154
+ # def build_sam2_video_predictor_hf(model_id, **kwargs):
155
+ # config_name, ckpt_path = _hf_download(model_id)
156
+ # return build_sam2_video_predictor(
157
+ # config_file=config_name, ckpt_path=ckpt_path, **kwargs
158
+ # )
159
+
160
+
161
+ def _load_checkpoint(model, ckpt_path):
162
+ if ckpt_path is not None:
163
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
164
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
165
+ if missing_keys:
166
+ logging.error(missing_keys)
167
+ raise RuntimeError()
168
+ if unexpected_keys:
169
+ logging.error(unexpected_keys)
170
+ raise RuntimeError()
171
+ logging.info("Loaded checkpoint sucessfully")
avs.code/v1m.code/model/visual/sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import partial
9
+ from typing import List, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from iopath.common.file_io import g_pathmgr
15
+
16
+ from model.visual.sam2.modeling.backbones.utils import (
17
+ PatchEmbed,
18
+ window_partition,
19
+ window_unpartition,
20
+ )
21
+
22
+ from model.visual.sam2.modeling.sam2_utils import DropPath, MLP
23
+
24
+
25
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26
+ if pool is None:
27
+ return x
28
+ # (B, H, W, C) -> (B, C, H, W)
29
+ x = x.permute(0, 3, 1, 2)
30
+ x = pool(x)
31
+ # (B, C, H', W') -> (B, H', W', C)
32
+ x = x.permute(0, 2, 3, 1)
33
+ if norm:
34
+ x = norm(x)
35
+
36
+ return x
37
+
38
+
39
+ class MultiScaleAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ dim_out: int,
44
+ num_heads: int,
45
+ q_pool: nn.Module = None,
46
+ ):
47
+ super().__init__()
48
+
49
+ self.dim = dim
50
+ self.dim_out = dim_out
51
+ self.num_heads = num_heads
52
+ self.q_pool = q_pool
53
+ self.qkv = nn.Linear(dim, dim_out * 3)
54
+ self.proj = nn.Linear(dim_out, dim_out)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ B, H, W, _ = x.shape
58
+ # qkv with shape (B, H * W, 3, nHead, C)
59
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60
+ # q, k, v with shape (B, H * W, nheads, C)
61
+ q, k, v = torch.unbind(qkv, 2)
62
+
63
+ # Q pooling (for downsample at stage changes)
64
+ if self.q_pool:
65
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66
+ H, W = q.shape[1:3] # downsampled shape
67
+ q = q.reshape(B, H * W, self.num_heads, -1)
68
+
69
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70
+ x = F.scaled_dot_product_attention(
71
+ q.transpose(1, 2),
72
+ k.transpose(1, 2),
73
+ v.transpose(1, 2),
74
+ )
75
+ # Transpose back
76
+ x = x.transpose(1, 2)
77
+ x = x.reshape(B, H, W, -1)
78
+
79
+ x = self.proj(x)
80
+
81
+ return x
82
+
83
+
84
+ class MultiScaleBlock(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ dim_out: int,
89
+ num_heads: int,
90
+ mlp_ratio: float = 4.0,
91
+ drop_path: float = 0.0,
92
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
93
+ q_stride: Tuple[int, int] = None,
94
+ act_layer: nn.Module = nn.GELU,
95
+ window_size: int = 0,
96
+ ):
97
+ super().__init__()
98
+
99
+ if isinstance(norm_layer, str):
100
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101
+
102
+ self.dim = dim
103
+ self.dim_out = dim_out
104
+ self.norm1 = norm_layer(dim)
105
+
106
+ self.window_size = window_size
107
+
108
+ self.pool, self.q_stride = None, q_stride
109
+ if self.q_stride:
110
+ self.pool = nn.MaxPool2d(
111
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
112
+ )
113
+
114
+ self.attn = MultiScaleAttention(
115
+ dim,
116
+ dim_out,
117
+ num_heads=num_heads,
118
+ q_pool=self.pool,
119
+ )
120
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121
+
122
+ self.norm2 = norm_layer(dim_out)
123
+ self.mlp = MLP(
124
+ dim_out,
125
+ int(dim_out * mlp_ratio),
126
+ dim_out,
127
+ num_layers=2,
128
+ activation=act_layer,
129
+ )
130
+
131
+ if dim != dim_out:
132
+ self.proj = nn.Linear(dim, dim_out)
133
+
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ shortcut = x # B, H, W, C
136
+ x = self.norm1(x)
137
+
138
+ # Skip connection
139
+ if self.dim != self.dim_out:
140
+ shortcut = do_pool(self.proj(x), self.pool)
141
+
142
+ # Window partition
143
+ window_size = self.window_size
144
+ if window_size > 0:
145
+ H, W = x.shape[1], x.shape[2]
146
+ x, pad_hw = window_partition(x, window_size)
147
+
148
+ # Window Attention + Q Pooling (if stage change)
149
+ x = self.attn(x)
150
+ if self.q_stride:
151
+ # Shapes have changed due to Q pooling
152
+ window_size = self.window_size // self.q_stride[0]
153
+ H, W = shortcut.shape[1:3]
154
+
155
+ pad_h = (window_size - H % window_size) % window_size
156
+ pad_w = (window_size - W % window_size) % window_size
157
+ pad_hw = (H + pad_h, W + pad_w)
158
+
159
+ # Reverse window partition
160
+ if self.window_size > 0:
161
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
162
+
163
+ x = shortcut + self.drop_path(x)
164
+ # MLP
165
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
166
+ return x
167
+
168
+
169
+ class Hiera(nn.Module):
170
+ """
171
+ Reference: https://arxiv.org/abs/2306.00989
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ embed_dim: int = 96, # initial embed dim
177
+ num_heads: int = 1, # initial number of heads
178
+ drop_path_rate: float = 0.0, # stochastic depth
179
+ q_pool: int = 3, # number of q_pool stages
180
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
183
+ head_mul: float = 2.0, # head_mul factor at stage shift
184
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185
+ # window size per stage, when not using global att.
186
+ window_spec: Tuple[int, ...] = (
187
+ 8,
188
+ 4,
189
+ 14,
190
+ 7,
191
+ ),
192
+ # global attn in these blocks
193
+ global_att_blocks: Tuple[int, ...] = (
194
+ 12,
195
+ 16,
196
+ 20,
197
+ ),
198
+ weights_path=None,
199
+ return_interm_layers=True, # return feats from every stage
200
+ ):
201
+ super().__init__()
202
+
203
+ assert len(stages) == len(window_spec)
204
+ self.window_spec = window_spec
205
+
206
+ depth = sum(stages)
207
+ self.q_stride = q_stride
208
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
210
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211
+ self.return_interm_layers = return_interm_layers
212
+
213
+ self.patch_embed = PatchEmbed(
214
+ embed_dim=embed_dim,
215
+ )
216
+ # Which blocks have global att?
217
+ self.global_att_blocks = global_att_blocks
218
+
219
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
+ self.pos_embed = nn.Parameter(
222
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
+ )
224
+ self.pos_embed_window = nn.Parameter(
225
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
+ )
227
+
228
+ dpr = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
+ ] # stochastic depth decay rule
231
+
232
+ cur_stage = 1
233
+ self.blocks = nn.ModuleList()
234
+
235
+ for i in range(depth):
236
+ dim_out = embed_dim
237
+ # lags by a block, so first block of
238
+ # next stage uses an initial window size
239
+ # of previous stage and final window size of current stage
240
+ window_size = self.window_spec[cur_stage - 1]
241
+
242
+ if self.global_att_blocks is not None:
243
+ window_size = 0 if i in self.global_att_blocks else window_size
244
+
245
+ if i - 1 in self.stage_ends:
246
+ dim_out = int(embed_dim * dim_mul)
247
+ num_heads = int(num_heads * head_mul)
248
+ cur_stage += 1
249
+
250
+ block = MultiScaleBlock(
251
+ dim=embed_dim,
252
+ dim_out=dim_out,
253
+ num_heads=num_heads,
254
+ drop_path=dpr[i],
255
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
256
+ window_size=window_size,
257
+ )
258
+
259
+ embed_dim = dim_out
260
+ self.blocks.append(block)
261
+
262
+ self.channel_list = (
263
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
+ if return_interm_layers
265
+ else [self.blocks[-1].dim_out]
266
+ )
267
+
268
+ if weights_path is not None:
269
+ with g_pathmgr.open(weights_path, "rb") as f:
270
+ chkpt = torch.load(f, map_location="cpu")
271
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
272
+
273
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274
+ h, w = hw
275
+ window_embed = self.pos_embed_window
276
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277
+ pos_embed = pos_embed + window_embed.tile(
278
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279
+ )
280
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
281
+ return pos_embed
282
+
283
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284
+ x = self.patch_embed(x)
285
+ # x: (B, H, W, C)
286
+
287
+ # Add pos embed
288
+ x = x + self._get_pos_embed(x.shape[1:3])
289
+
290
+ outputs = []
291
+ for i, blk in enumerate(self.blocks):
292
+ x = blk(x)
293
+ if (i == self.stage_ends[-1]) or (
294
+ i in self.stage_ends and self.return_interm_layers
295
+ ):
296
+ feats = x.permute(0, 3, 1, 2)
297
+ outputs.append(feats)
298
+
299
+ return outputs
300
+
301
+ def get_layer_id(self, layer_name):
302
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303
+ num_layers = self.get_num_layers()
304
+
305
+ if layer_name.find("rel_pos") != -1:
306
+ return num_layers + 1
307
+ elif layer_name.find("pos_embed") != -1:
308
+ return 0
309
+ elif layer_name.find("patch_embed") != -1:
310
+ return 0
311
+ elif layer_name.find("blocks") != -1:
312
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313
+ else:
314
+ return num_layers + 1
315
+
316
+ def get_num_layers(self) -> int:
317
+ return len(self.blocks)
avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param trunk: the backbone
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param neck_norm: the normalization to use
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ self.d_model = d_model
75
+ for dim in backbone_channel_list:
76
+ current = nn.Sequential()
77
+ current.add_module(
78
+ "conv",
79
+ nn.Conv2d(
80
+ in_channels=dim,
81
+ out_channels=d_model,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding,
85
+ ),
86
+ )
87
+
88
+ self.convs.append(current)
89
+ self.fpn_interp_model = fpn_interp_model
90
+ assert fuse_type in ["sum", "avg"]
91
+ self.fuse_type = fuse_type
92
+
93
+ # levels to have top-down features in its outputs
94
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95
+ # have top-down propagation, while outputs of level 0 and level 1 have only
96
+ # lateral features from the same backbone level.
97
+ if fpn_top_down_levels is None:
98
+ # default is to have top-down features on all levels
99
+ fpn_top_down_levels = range(len(self.convs))
100
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
101
+
102
+ def forward(self, xs: List[torch.Tensor]):
103
+
104
+ out = [None] * len(self.convs)
105
+ pos = [None] * len(self.convs)
106
+ assert len(xs) == len(self.convs)
107
+ # fpn forward pass
108
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109
+ prev_features = None
110
+ # forward in top-down order (from low to high resolution)
111
+ n = len(self.convs) - 1
112
+ for i in range(n, -1, -1):
113
+ x = xs[i]
114
+ lateral_features = self.convs[n - i](x)
115
+ if i in self.fpn_top_down_levels and prev_features is not None:
116
+ top_down_features = F.interpolate(
117
+ prev_features.to(dtype=torch.float32),
118
+ scale_factor=2.0,
119
+ mode=self.fpn_interp_model,
120
+ align_corners=(
121
+ None if self.fpn_interp_model == "nearest" else False
122
+ ),
123
+ antialias=False,
124
+ )
125
+ prev_features = lateral_features + top_down_features
126
+ if self.fuse_type == "avg":
127
+ prev_features /= 2
128
+ else:
129
+ prev_features = lateral_features
130
+ x_out = prev_features
131
+ out[i] = x_out
132
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133
+
134
+ return out, pos
avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = (
36
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ )
38
+ return windows, (Hp, Wp)
39
+
40
+
41
+ def window_unpartition(windows, window_size, pad_hw, hw):
42
+ """
43
+ Window unpartition into original sequences and removing padding.
44
+ Args:
45
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46
+ window_size (int): window size.
47
+ pad_hw (Tuple): padded height and width (Hp, Wp).
48
+ hw (Tuple): original height and width (H, W) before padding.
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(
56
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57
+ )
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ class PatchEmbed(nn.Module):
66
+ """
67
+ Image to Patch Embedding.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ kernel_size: Tuple[int, ...] = (7, 7),
73
+ stride: Tuple[int, ...] = (4, 4),
74
+ padding: Tuple[int, ...] = (3, 3),
75
+ in_chans: int = 3,
76
+ embed_dim: int = 768,
77
+ ):
78
+ """
79
+ Args:
80
+ kernel_size (Tuple): kernel size of the projection layer.
81
+ stride (Tuple): stride of the projection layer.
82
+ padding (Tuple): padding size of the projection layer.
83
+ in_chans (int): Number of input image channels.
84
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
85
+ """
86
+ super().__init__()
87
+ self.proj = nn.Conv2d(
88
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ x = self.proj(x)
93
+ # B C H W -> B H W C
94
+ x = x.permute(0, 2, 3, 1)
95
+ return x
avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from model.visual.sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # sigmoid, so that less domain shift from gt masks which are bool
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class PositionEmbeddingSine(nn.Module):
17
+ """
18
+ This is a more standard version of the position embedding, very similar to the one
19
+ used by the Attention Is All You Need paper, generalized to work on images.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ num_pos_feats,
25
+ temperature: int = 10000,
26
+ normalize: bool = True,
27
+ scale: Optional[float] = None,
28
+ ):
29
+ super().__init__()
30
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
31
+ self.num_pos_feats = num_pos_feats // 2
32
+ self.temperature = temperature
33
+ self.normalize = normalize
34
+ if scale is not None and normalize is False:
35
+ raise ValueError("normalize should be True if scale is passed")
36
+ if scale is None:
37
+ scale = 2 * math.pi
38
+ self.scale = scale
39
+
40
+ self.cache = {}
41
+
42
+ def _encode_xy(self, x, y):
43
+ # The positions are expected to be normalized
44
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
45
+ x_embed = x * self.scale
46
+ y_embed = y * self.scale
47
+
48
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
49
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
50
+
51
+ pos_x = x_embed[:, None] / dim_t
52
+ pos_y = y_embed[:, None] / dim_t
53
+ pos_x = torch.stack(
54
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
55
+ ).flatten(1)
56
+ pos_y = torch.stack(
57
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
58
+ ).flatten(1)
59
+ return pos_x, pos_y
60
+
61
+ @torch.no_grad()
62
+ def encode_boxes(self, x, y, w, h):
63
+ pos_x, pos_y = self._encode_xy(x, y)
64
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
65
+ return pos
66
+
67
+ encode = encode_boxes # Backwards compatibility
68
+
69
+ @torch.no_grad()
70
+ def encode_points(self, x, y, labels):
71
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
72
+ assert bx == by and nx == ny and bx == bl and nx == nl
73
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
74
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
75
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
76
+ return pos
77
+
78
+ @torch.no_grad()
79
+ def forward(self, x: torch.Tensor):
80
+ cache_key = (x.shape[-2], x.shape[-1])
81
+ if cache_key in self.cache:
82
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
83
+ y_embed = (
84
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
85
+ .view(1, -1, 1)
86
+ .repeat(x.shape[0], 1, x.shape[-1])
87
+ )
88
+ x_embed = (
89
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
90
+ .view(1, 1, -1)
91
+ .repeat(x.shape[0], x.shape[-2], 1)
92
+ )
93
+
94
+ if self.normalize:
95
+ eps = 1e-6
96
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
97
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
98
+
99
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101
+
102
+ pos_x = x_embed[:, :, :, None] / dim_t
103
+ pos_y = y_embed[:, :, :, None] / dim_t
104
+ pos_x = torch.stack(
105
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
106
+ ).flatten(3)
107
+ pos_y = torch.stack(
108
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
109
+ ).flatten(3)
110
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
111
+ self.cache[cache_key] = pos[0]
112
+ return pos
113
+
114
+
115
+ class PositionEmbeddingRandom(nn.Module):
116
+ """
117
+ Positional encoding using random spatial frequencies.
118
+ """
119
+
120
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
121
+ super().__init__()
122
+ if scale is None or scale <= 0.0:
123
+ scale = 1.0
124
+ self.register_buffer(
125
+ "positional_encoding_gaussian_matrix",
126
+ scale * torch.randn((2, num_pos_feats)),
127
+ )
128
+
129
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
130
+ """Positionally encode points that are normalized to [0,1]."""
131
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
132
+ coords = 2 * coords - 1
133
+ coords = coords @ self.positional_encoding_gaussian_matrix
134
+ coords = 2 * np.pi * coords
135
+ # outputs d_1 x ... x d_n x C shape
136
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
137
+
138
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
139
+ """Generate positional encoding for a grid of the specified size."""
140
+ h, w = size
141
+ device: Any = self.positional_encoding_gaussian_matrix.device
142
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
143
+ y_embed = grid.cumsum(dim=0) - 0.5
144
+ x_embed = grid.cumsum(dim=1) - 0.5
145
+ y_embed = y_embed / h
146
+ x_embed = x_embed / w
147
+
148
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
149
+ return pe.permute(2, 0, 1) # C x H x W
150
+
151
+ def forward_with_coords(
152
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
153
+ ) -> torch.Tensor:
154
+ """Positionally encode points that are not normalized to [0,1]."""
155
+ coords = coords_input.clone()
156
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
157
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
158
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
159
+
160
+
161
+ # Rotary Positional Encoding, adapted from:
162
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
163
+ # 2. https://github.com/naver-ai/rope-vit
164
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
165
+
166
+
167
+ def init_t_xy(end_x: int, end_y: int):
168
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
169
+ t_x = (t % end_x).float()
170
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
171
+ return t_x, t_y
172
+
173
+
174
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
175
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
176
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
177
+
178
+ t_x, t_y = init_t_xy(end_x, end_y)
179
+ freqs_x = torch.outer(t_x, freqs_x)
180
+ freqs_y = torch.outer(t_y, freqs_y)
181
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
182
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
183
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
184
+
185
+
186
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
187
+ ndim = x.ndim
188
+ assert 0 <= 1 < ndim
189
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
190
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
191
+ return freqs_cis.view(*shape)
192
+
193
+
194
+ def apply_rotary_enc(
195
+ xq: torch.Tensor,
196
+ xk: torch.Tensor,
197
+ freqs_cis: torch.Tensor,
198
+ repeat_freqs_k: bool = False,
199
+ ):
200
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
201
+ xk_ = (
202
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
203
+ if xk.shape[-2] != 0
204
+ else None
205
+ )
206
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
207
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
208
+ if xk_ is None:
209
+ # no keys to rotate, due to dropout
210
+ return xq_out.type_as(xq).to(xq.device), xk
211
+ # repeat freqs along seq_len dim to match k seq_len
212
+ if repeat_freqs_k:
213
+ r = xk_.shape[-2] // xq_.shape[-2]
214
+ if freqs_cis.is_cuda:
215
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
216
+ else:
217
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
218
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
219
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
220
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
221
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
+ class MaskDecoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ transformer_dim: int,
20
+ transformer: nn.Module,
21
+ num_multimask_outputs: int = 3,
22
+ activation: Type[nn.Module] = nn.GELU,
23
+ iou_head_depth: int = 3,
24
+ iou_head_hidden_dim: int = 256,
25
+ use_high_res_features: bool = False,
26
+ iou_prediction_use_sigmoid=False,
27
+ dynamic_multimask_via_stability=False,
28
+ dynamic_multimask_stability_delta=0.05,
29
+ dynamic_multimask_stability_thresh=0.98,
30
+ pred_obj_scores: bool = False,
31
+ pred_obj_scores_mlp: bool = False,
32
+ use_multimask_token_for_obj_ptr: bool = False,
33
+ ) -> None:
34
+ """
35
+ Predicts masks given an image and prompt embeddings, using a
36
+ transformer architecture.
37
+
38
+ Arguments:
39
+ transformer_dim (int): the channel dimension of the transformer
40
+ transformer (nn.Module): the transformer used to predict masks
41
+ num_multimask_outputs (int): the number of masks to predict
42
+ when disambiguating masks
43
+ activation (nn.Module): the type of activation to use when
44
+ upscaling masks
45
+ iou_head_depth (int): the depth of the MLP used to predict
46
+ mask quality
47
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
48
+ used to predict mask quality
49
+ """
50
+ super().__init__()
51
+ self.transformer_dim = transformer_dim
52
+ self.transformer = transformer
53
+
54
+ self.num_multimask_outputs = num_multimask_outputs
55
+
56
+ self.iou_token = nn.Embedding(1, transformer_dim)
57
+ self.num_mask_tokens = num_multimask_outputs + 1
58
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59
+
60
+ self.pred_obj_scores = pred_obj_scores
61
+ if self.pred_obj_scores:
62
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
63
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
+
65
+ self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(
67
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
+ ),
69
+ LayerNorm2d(transformer_dim // 4),
70
+ activation(),
71
+ nn.ConvTranspose2d(
72
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
+ ),
74
+ activation(),
75
+ )
76
+ self.use_high_res_features = use_high_res_features
77
+ if use_high_res_features:
78
+ self.conv_s0 = nn.Conv2d(
79
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
+ )
81
+ self.conv_s1 = nn.Conv2d(
82
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
+ )
84
+
85
+ self.output_hypernetworks_mlps = nn.ModuleList(
86
+ [
87
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
+ for i in range(self.num_mask_tokens)
89
+ ]
90
+ )
91
+
92
+ self.iou_prediction_head = MLP(
93
+ transformer_dim,
94
+ iou_head_hidden_dim,
95
+ self.num_mask_tokens,
96
+ iou_head_depth,
97
+ sigmoid_output=iou_prediction_use_sigmoid,
98
+ )
99
+ if self.pred_obj_scores:
100
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101
+ if pred_obj_scores_mlp:
102
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103
+
104
+ # When outputting a single mask, optionally we can dynamically fall back to the best
105
+ # multimask output token if the single mask output token gives low stability scores.
106
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ sparse_prompt_embeddings: torch.Tensor,
115
+ dense_prompt_embeddings: torch.Tensor,
116
+ multimask_output: bool,
117
+ repeat_image: bool,
118
+ high_res_features: Optional[List[torch.Tensor]] = None,
119
+ audio_res_features: Optional[List[torch.Tensor]] = None,
120
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ """
122
+ Predict masks given image and prompt embeddings.
123
+
124
+ Arguments:
125
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
126
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
127
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
128
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
129
+ multimask_output (bool): Whether to return multiple masks or a single
130
+ mask.
131
+
132
+ Returns:
133
+ torch.Tensor: batched predicted masks
134
+ torch.Tensor: batched predictions of mask quality
135
+ torch.Tensor: batched SAM token for mask output
136
+ """
137
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
138
+ image_embeddings=image_embeddings,
139
+ image_pe=image_pe,
140
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
141
+ dense_prompt_embeddings=dense_prompt_embeddings,
142
+ repeat_image=repeat_image,
143
+ high_res_features=high_res_features,
144
+ audio_res_features_=audio_res_features
145
+ )
146
+
147
+ # Select the correct mask or masks for output
148
+ if multimask_output:
149
+ masks = masks[:, 1:, :, :]
150
+ iou_pred = iou_pred[:, 1:]
151
+ elif self.dynamic_multimask_via_stability and not self.training:
152
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
153
+ else:
154
+ masks = masks[:, 0:1, :, :]
155
+ iou_pred = iou_pred[:, 0:1]
156
+
157
+
158
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
159
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
160
+ else:
161
+ # Take the mask output token. Here we *always* use the token for single mask output.
162
+ # At test time, even if we track after 1-click (and using multimask_output=True),
163
+ # we still take the single mask token here. The rationale is that we always track
164
+ # after multiple clicks during training, so the past tokens seen during training
165
+ # are always the single mask token (and we'll let it be the object-memory token).
166
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
167
+
168
+ # Prepare output
169
+ return masks, iou_pred, sam_tokens_out, object_score_logits
170
+
171
+ def predict_masks(
172
+ self,
173
+ image_embeddings: torch.Tensor,
174
+ image_pe: torch.Tensor,
175
+ sparse_prompt_embeddings: torch.Tensor,
176
+ dense_prompt_embeddings: torch.Tensor,
177
+ repeat_image: bool,
178
+ high_res_features: Optional[List[torch.Tensor]] = None,
179
+ audio_res_features_: Optional[List[torch.Tensor]] = None
180
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
181
+ """Predicts masks. See 'forward' for more details."""
182
+ # Concatenate output tokens
183
+ s = 0
184
+ if self.pred_obj_scores:
185
+ output_tokens = torch.cat(
186
+ [
187
+ self.obj_score_token.weight,
188
+ self.iou_token.weight,
189
+ self.mask_tokens.weight,
190
+ ],
191
+ dim=0,
192
+ )
193
+ s = 1
194
+ else:
195
+ output_tokens = torch.cat(
196
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
197
+ )
198
+ output_tokens = output_tokens.unsqueeze(0).expand(
199
+ sparse_prompt_embeddings.size(0), -1, -1
200
+ )
201
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
202
+
203
+ # Expand per-image data in batch direction to be per-mask
204
+ if repeat_image:
205
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
206
+ else:
207
+ assert image_embeddings.shape[0] == tokens.shape[0]
208
+ src = image_embeddings
209
+ src = src + dense_prompt_embeddings
210
+ assert (
211
+ image_pe.size(0) == 1
212
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
213
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
214
+ b, c, h, w = src.shape
215
+
216
+ # Run the transformer
217
+ hs, src = self.transformer(src, pos_src, tokens, audio_res_features_)
218
+ iou_token_out = hs[:, s, :]
219
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
220
+
221
+ # Upscale mask embeddings and predict masks using the mask tokens
222
+ src = src.transpose(1, 2).view(b, c, h, w)
223
+
224
+ if not self.use_high_res_features:
225
+ upscaled_embedding = self.output_upscaling(src)
226
+ else:
227
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
228
+ feat_s0, feat_s1 = high_res_features
229
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
230
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
231
+
232
+ hyper_in_list: List[torch.Tensor] = []
233
+ for i in range(self.num_mask_tokens):
234
+ hyper_in_list.append(
235
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
236
+ )
237
+ hyper_in = torch.stack(hyper_in_list, dim=1)
238
+ b, c, h, w = upscaled_embedding.shape
239
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
240
+
241
+ # Generate mask quality predictions
242
+ iou_pred = self.iou_prediction_head(iou_token_out)
243
+ if self.pred_obj_scores:
244
+ assert s == 1
245
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
246
+ else:
247
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
248
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
249
+
250
+ return masks, iou_pred, mask_tokens_out, object_score_logits
251
+
252
+ def _get_stability_scores(self, mask_logits):
253
+ """
254
+ Compute stability scores of the mask logits based on the IoU between upper and
255
+ lower thresholds.
256
+ """
257
+ mask_logits = mask_logits.flatten(-2)
258
+ stability_delta = self.dynamic_multimask_stability_delta
259
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
260
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
261
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
262
+ return stability_scores
263
+
264
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
265
+ """
266
+ When outputting a single mask, if the stability score from the current single-mask
267
+ output (based on output token 0) falls below a threshold, we instead select from
268
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
269
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
270
+ """
271
+ # The best mask from multimask output tokens (1~3)
272
+ multimask_logits = all_mask_logits[:, 1:, :, :]
273
+ multimask_iou_scores = all_iou_scores[:, 1:]
274
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
275
+ batch_inds = torch.arange(
276
+ multimask_iou_scores.size(0), device=all_iou_scores.device
277
+ )
278
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
279
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
280
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
281
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
282
+
283
+ # The mask from singlemask output token 0 and its stability score
284
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
285
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
286
+ stability_scores = self._get_stability_scores(singlemask_logits)
287
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
288
+
289
+ # Dynamically fall back to best multimask output upon low stability scores.
290
+ mask_logits_out = torch.where(
291
+ is_stable[..., None, None].expand_as(singlemask_logits),
292
+ singlemask_logits,
293
+ best_multimask_logits,
294
+ )
295
+ iou_scores_out = torch.where(
296
+ is_stable.expand_as(singlemask_iou_scores),
297
+ singlemask_iou_scores,
298
+ best_multimask_iou_scores,
299
+ )
300
+ return mask_logits_out, iou_scores_out
avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from model.visual.sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ mask_in_chans: int,
24
+ activation: Type[nn.Module] = nn.GELU,
25
+ ) -> None:
26
+ """
27
+ Encodes prompts for input to SAM's mask decoder.
28
+
29
+ Arguments:
30
+ embed_dim (int): The prompts' embedding dimension
31
+ image_embedding_size (tuple(int, int)): The spatial size of the
32
+ image embedding, as (H, W).
33
+ input_image_size (int): The padded size of the image as input
34
+ to the image encoder, as (H, W).
35
+ mask_in_chans (int): The number of hidden channels used for
36
+ encoding input masks.
37
+ activation (nn.Module): The activation to use when encoding
38
+ input masks.
39
+ """
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.input_image_size = input_image_size
43
+ self.image_embedding_size = image_embedding_size
44
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
+
46
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
+ point_embeddings = [
48
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
+ ]
50
+ self.point_embeddings = nn.ModuleList(point_embeddings)
51
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
+
53
+ self.mask_input_size = (
54
+ 4 * image_embedding_size[0],
55
+ 4 * image_embedding_size[1],
56
+ )
57
+ self.mask_downscaling = nn.Sequential(
58
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59
+ LayerNorm2d(mask_in_chans // 4),
60
+ activation(),
61
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62
+ LayerNorm2d(mask_in_chans),
63
+ activation(),
64
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
+ )
66
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
67
+
68
+ def get_dense_pe(self) -> torch.Tensor:
69
+ """
70
+ Returns the positional encoding used to encode point prompts,
71
+ applied to a dense set of points the shape of the image encoding.
72
+
73
+ Returns:
74
+ torch.Tensor: Positional encoding with shape
75
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
76
+ """
77
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
+
79
+ def _embed_points(
80
+ self,
81
+ points: torch.Tensor,
82
+ labels: torch.Tensor,
83
+ pad: bool,
84
+ ) -> torch.Tensor:
85
+ """Embeds point prompts."""
86
+ points = points + 0.5 # Shift to center of pixel
87
+ if pad:
88
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
+ points = torch.cat([points, padding_point], dim=1)
91
+ labels = torch.cat([labels, padding_label], dim=1)
92
+ point_embedding = self.pe_layer.forward_with_coords(
93
+ points, self.input_image_size
94
+ )
95
+ point_embedding[labels == -1] = 0.0
96
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
97
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
98
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
99
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
100
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
101
+ return point_embedding
102
+
103
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104
+ """Embeds box prompts."""
105
+ boxes = boxes + 0.5 # Shift to center of pixel
106
+ coords = boxes.reshape(-1, 2, 2)
107
+ corner_embedding = self.pe_layer.forward_with_coords(
108
+ coords, self.input_image_size
109
+ )
110
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112
+ return corner_embedding
113
+
114
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115
+ """Embeds mask inputs."""
116
+ mask_embedding = self.mask_downscaling(masks)
117
+ return mask_embedding
118
+
119
+ def _get_batch_size(
120
+ self,
121
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122
+ boxes: Optional[torch.Tensor],
123
+ masks: Optional[torch.Tensor],
124
+ ) -> int:
125
+ """
126
+ Gets the batch size of the output given the batch size of the input prompts.
127
+ """
128
+ if points is not None:
129
+ return points[0].shape[0]
130
+ elif boxes is not None:
131
+ return boxes.shape[0]
132
+ elif masks is not None:
133
+ return masks.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Embeds different types of prompts, returning both sparse and dense
148
+ embeddings.
149
+
150
+ Arguments:
151
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152
+ and labels to embed.
153
+ boxes (torch.Tensor or none): boxes to embed
154
+ masks (torch.Tensor or none): masks to embed
155
+
156
+ Returns:
157
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
158
+ BxNx(embed_dim), where N is determined by the number of input points
159
+ and boxes.
160
+ torch.Tensor: dense embeddings for the masks, in the shape
161
+ Bx(embed_dim)x(embed_H)x(embed_W)
162
+ """
163
+ # we only utilise sounding as prompt.
164
+ bs = self._get_batch_size(points, boxes, masks)
165
+ sparse_embeddings = torch.empty(
166
+ (bs, 0, self.embed_dim), device=self._get_device()
167
+ )
168
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
169
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
170
+ )
171
+ '''
172
+ if points is not None:
173
+ coords, labels = points
174
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
175
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
176
+ if boxes is not None:
177
+ box_embeddings = self._embed_boxes(boxes)
178
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
179
+
180
+ if masks is not None:
181
+ dense_embeddings = self._embed_masks(masks)
182
+ else:
183
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
184
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
185
+ )
186
+ '''
187
+ return sparse_embeddings, dense_embeddings
188
+
avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import contextlib
8
+ import math
9
+ import warnings
10
+ from functools import partial
11
+ from typing import Tuple, Type
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn, Tensor
16
+
17
+ from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
18
+ from model.visual.sam2.modeling.sam2_utils import MLP
19
+ from model.visual.sam2.utils.misc import get_sdpa_settings
20
+
21
+ warnings.simplefilter(action="ignore", category=FutureWarning)
22
+ # Check whether Flash Attention is available (and use it by default)
23
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
24
+ # A fallback setting to allow all available kernels if Flash Attention fails
25
+ ALLOW_ALL_KERNELS = False
26
+
27
+
28
+ def sdp_kernel_context(dropout_p):
29
+ """
30
+ Get the context for the attention scaled dot-product kernel. We use Flash Attention
31
+ by default, but fall back to all available kernels if Flash Attention fails.
32
+ """
33
+ if ALLOW_ALL_KERNELS:
34
+ return contextlib.nullcontext()
35
+
36
+ return torch.backends.cuda.sdp_kernel(
37
+ enable_flash=USE_FLASH_ATTN,
38
+ # if Flash attention kernel is off, then math kernel needs to be enabled
39
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
40
+ enable_mem_efficient=OLD_GPU,
41
+ )
42
+
43
+
44
+ class TwoWayTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ depth: int,
48
+ embedding_dim: int,
49
+ num_heads: int,
50
+ mlp_dim: int,
51
+ activation: Type[nn.Module] = nn.ReLU,
52
+ attention_downsample_rate: int = 2,
53
+ ) -> None:
54
+ """
55
+ A transformer decoder that attends to an input image using
56
+ queries whose positional embedding is supplied.
57
+
58
+ Args:
59
+ depth (int): number of layers in the transformer
60
+ embedding_dim (int): the channel dimension for the input embeddings
61
+ num_heads (int): the number of heads for multihead attention. Must
62
+ divide embedding_dim
63
+ mlp_dim (int): the channel dimension internal to the MLP block
64
+ activation (nn.Module): the activation to use in the MLP block
65
+ """
66
+ super().__init__()
67
+ self.depth = depth
68
+ self.embedding_dim = embedding_dim
69
+ self.num_heads = num_heads
70
+ self.mlp_dim = mlp_dim
71
+ self.layers = nn.ModuleList()
72
+
73
+ for i in range(depth):
74
+ self.layers.append(
75
+ TwoWayAttentionBlock(
76
+ embedding_dim=embedding_dim,
77
+ num_heads=num_heads,
78
+ mlp_dim=mlp_dim,
79
+ activation=activation,
80
+ attention_downsample_rate=attention_downsample_rate,
81
+ skip_first_layer_pe=(i == 0),
82
+ )
83
+ )
84
+
85
+ self.final_attn_token_to_image = Attention(
86
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
87
+ )
88
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
89
+
90
+ def forward(
91
+ self,
92
+ image_embedding: Tensor,
93
+ image_pe: Tensor,
94
+ point_embedding: Tensor,
95
+ audio_res: [],
96
+ ) -> Tuple[Tensor, Tensor]:
97
+ """
98
+ Args:
99
+ image_embedding (torch.Tensor): image to attend to. Should be shape
100
+ B x embedding_dim x h x w for any h and w.
101
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
102
+ have the same shape as image_embedding.
103
+ point_embedding (torch.Tensor): the embedding to add to the query points.
104
+ Must have shape B x N_points x embedding_dim for any N_points.
105
+
106
+ Returns:
107
+ torch.Tensor: the processed point_embedding
108
+ torch.Tensor: the processed image_embedding
109
+ """
110
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
111
+ bs, c, h, w = image_embedding.shape
112
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
113
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
114
+
115
+ visual_res, audio_res = audio_res
116
+
117
+ # Prepare queries
118
+ queries = point_embedding
119
+ keys = image_embedding
120
+ # Apply transformer blocks and final layernorm
121
+ for i, layer in enumerate(self.layers):
122
+ keys = keys + visual_res[i]
123
+ queries[:, 2:6] = queries[:, 2:6] + audio_res[i]
124
+ queries, keys = layer(
125
+ queries=queries,
126
+ keys=keys,
127
+ query_pe=point_embedding,
128
+ key_pe=image_pe,
129
+ )
130
+
131
+ queries[:, 2:6] = queries[:, 2:6] + audio_res[-1]
132
+ keys = keys + visual_res[-1]
133
+
134
+ # Apply the final attention layer from the points to the image
135
+ q = queries + point_embedding
136
+ k = keys + image_pe
137
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
138
+ queries = queries + attn_out
139
+ queries = self.norm_final_attn(queries)
140
+
141
+ return queries, keys
142
+
143
+
144
+ class TwoWayAttentionBlock(nn.Module):
145
+ def __init__(
146
+ self,
147
+ embedding_dim: int,
148
+ num_heads: int,
149
+ mlp_dim: int = 2048,
150
+ activation: Type[nn.Module] = nn.ReLU,
151
+ attention_downsample_rate: int = 2,
152
+ skip_first_layer_pe: bool = False,
153
+ ) -> None:
154
+ """
155
+ A transformer block with four layers: (1) self-attention of sparse
156
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
157
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
158
+ inputs.
159
+
160
+ Arguments:
161
+ embedding_dim (int): the channel dimension of the embeddings
162
+ num_heads (int): the number of heads in the attention layers
163
+ mlp_dim (int): the hidden dimension of the mlp block
164
+ activation (nn.Module): the activation of the mlp block
165
+ skip_first_layer_pe (bool): skip the PE on the first layer
166
+ """
167
+ super().__init__()
168
+ self.self_attn = Attention(embedding_dim, num_heads)
169
+ self.norm1 = nn.LayerNorm(embedding_dim)
170
+
171
+ self.cross_attn_token_to_image = Attention(
172
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
173
+ )
174
+ self.norm2 = nn.LayerNorm(embedding_dim)
175
+
176
+ self.mlp = MLP(
177
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
178
+ )
179
+ self.norm3 = nn.LayerNorm(embedding_dim)
180
+
181
+ self.norm4 = nn.LayerNorm(embedding_dim)
182
+ self.cross_attn_image_to_token = Attention(
183
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
184
+ )
185
+
186
+ self.skip_first_layer_pe = skip_first_layer_pe
187
+
188
+ def forward(
189
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
190
+ ) -> Tuple[Tensor, Tensor]:
191
+ # Self attention block
192
+ if self.skip_first_layer_pe:
193
+ queries = self.self_attn(q=queries, k=queries, v=queries)
194
+ else:
195
+ q = queries + query_pe
196
+ attn_out = self.self_attn(q=q, k=q, v=queries)
197
+ queries = queries + attn_out
198
+ queries = self.norm1(queries)
199
+
200
+ # Cross attention block, tokens attending to image embedding
201
+ q = queries + query_pe
202
+ k = keys + key_pe
203
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
204
+ queries = queries + attn_out
205
+ queries = self.norm2(queries)
206
+
207
+ # MLP block
208
+ mlp_out = self.mlp(queries)
209
+ queries = queries + mlp_out
210
+ queries = self.norm3(queries)
211
+
212
+ # Cross attention block, image embedding attending to tokens
213
+ q = queries + query_pe
214
+ k = keys + key_pe
215
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
216
+ keys = keys + attn_out
217
+ keys = self.norm4(keys)
218
+
219
+ return queries, keys
220
+
221
+
222
+ class Attention(nn.Module):
223
+ """
224
+ An attention layer that allows for downscaling the size of the embedding
225
+ after projection to queries, keys, and values.
226
+ """
227
+
228
+ def __init__(
229
+ self,
230
+ embedding_dim: int,
231
+ num_heads: int,
232
+ downsample_rate: int = 1,
233
+ dropout: float = 0.0,
234
+ kv_in_dim: int = None,
235
+ ) -> None:
236
+ super().__init__()
237
+ self.embedding_dim = embedding_dim
238
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
239
+ self.internal_dim = embedding_dim // downsample_rate
240
+ self.num_heads = num_heads
241
+ assert (
242
+ self.internal_dim % num_heads == 0
243
+ ), "num_heads must divide embedding_dim."
244
+
245
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
246
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
247
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
248
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
249
+
250
+ self.dropout_p = dropout
251
+
252
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
253
+ b, n, c = x.shape
254
+ x = x.reshape(b, n, num_heads, c // num_heads)
255
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
256
+
257
+ def _recombine_heads(self, x: Tensor) -> Tensor:
258
+ b, n_heads, n_tokens, c_per_head = x.shape
259
+ x = x.transpose(1, 2)
260
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
261
+
262
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
263
+ # Input projections
264
+ q = self.q_proj(q)
265
+ k = self.k_proj(k)
266
+ v = self.v_proj(v)
267
+
268
+ # Separate into heads
269
+ q = self._separate_heads(q, self.num_heads)
270
+ k = self._separate_heads(k, self.num_heads)
271
+ v = self._separate_heads(v, self.num_heads)
272
+
273
+ dropout_p = self.dropout_p if self.training else 0.0
274
+ # Attention
275
+ try:
276
+ with sdp_kernel_context(dropout_p):
277
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
278
+ except Exception as e:
279
+ # Fall back to all kernels if the Flash attention kernel fails
280
+ warnings.warn(
281
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
282
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
283
+ category=UserWarning,
284
+ stacklevel=2,
285
+ )
286
+ global ALLOW_ALL_KERNELS
287
+ ALLOW_ALL_KERNELS = True
288
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
289
+
290
+ out = self._recombine_heads(out)
291
+ out = self.out_proj(out)
292
+
293
+ return out
294
+
295
+
296
+ class RoPEAttention(Attention):
297
+ """Attention with rotary position encoding."""
298
+
299
+ def __init__(
300
+ self,
301
+ *args,
302
+ rope_theta=10000.0,
303
+ # whether to repeat q rope to match k length
304
+ # this is needed for cross-attention to memories
305
+ rope_k_repeat=False,
306
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
307
+ **kwargs,
308
+ ):
309
+ super().__init__(*args, **kwargs)
310
+
311
+ self.compute_cis = partial(
312
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
313
+ )
314
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
315
+ self.freqs_cis = freqs_cis
316
+ self.rope_k_repeat = rope_k_repeat
317
+
318
+ def forward(
319
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
320
+ ) -> Tensor:
321
+ # Input projections
322
+ q = self.q_proj(q)
323
+ k = self.k_proj(k)
324
+ v = self.v_proj(v)
325
+
326
+ # Separate into heads
327
+ q = self._separate_heads(q, self.num_heads)
328
+ k = self._separate_heads(k, self.num_heads)
329
+ v = self._separate_heads(v, self.num_heads)
330
+
331
+ # Apply rotary position encoding
332
+ w = h = math.sqrt(q.shape[-2])
333
+ self.freqs_cis = self.freqs_cis.to(q.device)
334
+ if self.freqs_cis.shape[0] != q.shape[-2]:
335
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
336
+ if q.shape[-2] != k.shape[-2]:
337
+ assert self.rope_k_repeat
338
+
339
+ num_k_rope = k.size(-2) - num_k_exclude_rope
340
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
341
+ q,
342
+ k[:, :, :num_k_rope],
343
+ freqs_cis=self.freqs_cis,
344
+ repeat_freqs_k=self.rope_k_repeat,
345
+ )
346
+
347
+ dropout_p = self.dropout_p if self.training else 0.0
348
+ # Attention
349
+ try:
350
+ with sdp_kernel_context(dropout_p):
351
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
352
+ except Exception as e:
353
+ # Fall back to all kernels if the Flash attention kernel fails
354
+ warnings.warn(
355
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
356
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
357
+ category=UserWarning,
358
+ stacklevel=2,
359
+ )
360
+ global ALLOW_ALL_KERNELS
361
+ ALLOW_ALL_KERNELS = True
362
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
363
+
364
+ out = self._recombine_heads(out)
365
+ out = self.out_proj(out)
366
+
367
+ return out
avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py ADDED
@@ -0,0 +1,940 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed
9
+ import torch.nn.functional as F
10
+
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder
15
+ from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer
16
+ from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
17
+
18
+ # a large negative value as a placeholder score for missing objects
19
+ NO_OBJ_SCORE = -1024.0
20
+
21
+
22
+ class SAM2Base(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_encoder,
26
+ memory_attention,
27
+ memory_encoder,
28
+ num_maskmem=7, # default 1 input frame + 6 previous frames
29
+ image_size=512,
30
+ backbone_stride=16, # stride of the image backbone output
31
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
32
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
33
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
34
+ binarize_mask_from_pts_for_mem_enc=False,
35
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
36
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
37
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
38
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
39
+ max_cond_frames_in_attn=-1,
40
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
41
+ # (instead of using the transformer encoder)
42
+ directly_add_no_mem_embed=False,
43
+ # whether to use high-resolution feature maps in the SAM mask decoder
44
+ use_high_res_features_in_sam=False,
45
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
46
+ multimask_output_in_sam=False,
47
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
48
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
49
+ multimask_min_pt_num=1,
50
+ multimask_max_pt_num=1,
51
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
52
+ multimask_output_for_tracking=False,
53
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
54
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
55
+ use_multimask_token_for_obj_ptr: bool = False,
56
+ # whether to use sigmoid to restrict ious prediction to [0-1]
57
+ iou_prediction_use_sigmoid=False,
58
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
59
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
60
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
61
+ memory_temporal_stride_for_eval=1,
62
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
63
+ non_overlap_masks_for_mem_enc=False,
64
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
65
+ use_obj_ptrs_in_encoder=False,
66
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
67
+ max_obj_ptrs_in_encoder=16,
68
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
69
+ add_tpos_enc_to_obj_ptrs=True,
70
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
71
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
72
+ proj_tpos_enc_in_obj_ptrs=False,
73
+ # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
74
+ # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
75
+ use_signed_tpos_enc_to_obj_ptrs=False,
76
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
77
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
78
+ only_obj_ptrs_in_the_past_for_eval=False,
79
+ # Whether to predict if there is an object in the frame
80
+ pred_obj_scores: bool = False,
81
+ # Whether to use an MLP to predict object scores
82
+ pred_obj_scores_mlp: bool = False,
83
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
84
+ # Whether to have a fixed no obj pointer when there is no object present
85
+ # or to use it as an additive embedding with obj_ptr produced by decoder
86
+ fixed_no_obj_ptr: bool = False,
87
+ # Soft no object, i.e. mix in no_obj_ptr softly,
88
+ # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
89
+ soft_no_obj_ptr: bool = False,
90
+ use_mlp_for_obj_ptr_proj: bool = False,
91
+ # add no obj embedding to spatial frames
92
+ no_obj_embed_spatial: bool = False,
93
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
94
+ sam_mask_decoder_extra_args=None,
95
+ compile_image_encoder: bool = False,
96
+ ):
97
+ super().__init__()
98
+
99
+ # Part 1: the image backbone
100
+ self.image_encoder = image_encoder
101
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
102
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
103
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
104
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
105
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
106
+ if use_obj_ptrs_in_encoder:
107
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
108
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
109
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
110
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
111
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
112
+ if proj_tpos_enc_in_obj_ptrs:
113
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
114
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
115
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
116
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
117
+
118
+ # Part 2: memory attention to condition current frame's visual features
119
+ # with memories (and obj ptrs) from past frames
120
+ self.memory_attention = memory_attention
121
+
122
+ #### this is for Version 2.0
123
+ # self.hidden_dim = memory_attention.d_model
124
+ #### this is for Version 2.1
125
+ # self.hidden_dim = image_encoder.neck.d_model
126
+ self.hidden_dim = 256 # well, it is always 256 anyway.
127
+
128
+ # Part 3: memory encoder for the previous frame's outputs
129
+ self.memory_encoder = memory_encoder
130
+ self.mem_dim = self.hidden_dim
131
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
132
+ self.memory_encoder.out_proj, "weight"
133
+ ):
134
+ # if there is compression of memories along channel dim
135
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
136
+ self.num_maskmem = num_maskmem # Number of memories accessible
137
+ # Temporal encoding of the memories
138
+ self.maskmem_tpos_enc = torch.nn.Parameter(
139
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
140
+ )
141
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
142
+ # a single token to indicate no memory embedding from previous frames
143
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
144
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
145
+ trunc_normal_(self.no_mem_embed, std=0.02)
146
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
147
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
148
+ # Apply sigmoid to the output raw mask logits (to turn them from
149
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
150
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
151
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
152
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
153
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
154
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
155
+ # On frames with mask input, whether to directly output the input mask without
156
+ # using a SAM prompt encoder + mask decoder
157
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
158
+ self.multimask_output_in_sam = multimask_output_in_sam
159
+ self.multimask_min_pt_num = multimask_min_pt_num
160
+ self.multimask_max_pt_num = multimask_max_pt_num
161
+ self.multimask_output_for_tracking = multimask_output_for_tracking
162
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
163
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
164
+
165
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
166
+ # and SAM-style mask decoder for the final mask output
167
+ self.image_size = image_size
168
+ self.backbone_stride = backbone_stride
169
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
170
+ self.pred_obj_scores = pred_obj_scores
171
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
172
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
173
+ self.soft_no_obj_ptr = soft_no_obj_ptr
174
+ if self.fixed_no_obj_ptr:
175
+ assert self.pred_obj_scores
176
+ assert self.use_obj_ptrs_in_encoder
177
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
178
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
179
+ trunc_normal_(self.no_obj_ptr, std=0.02)
180
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
181
+ self.no_obj_embed_spatial = None
182
+ if no_obj_embed_spatial:
183
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
184
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
185
+
186
+ self._build_sam_heads()
187
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
188
+
189
+ # Model compilation
190
+ if compile_image_encoder:
191
+ # Compile the forward function (not the full module) to allow loading checkpoints.
192
+ print(
193
+ "Image encoder compilation is enabled. First forward pass will be slow."
194
+ )
195
+ self.image_encoder.forward = torch.compile(
196
+ self.image_encoder.forward,
197
+ mode="max-autotune",
198
+ fullgraph=True,
199
+ dynamic=False,
200
+ )
201
+
202
+ ### we fix the use_mask_input_as_output_without_sam to be turned off.
203
+ self.use_mask_input_as_output_without_sam = False
204
+
205
+
206
+ @property
207
+ def device(self):
208
+ return next(self.parameters()).device
209
+
210
+ def forward(self, *args, **kwargs):
211
+ raise NotImplementedError(
212
+ "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
213
+ "See notebooks/video_predictor_example.ipynb for an inference example."
214
+ )
215
+
216
+ def _build_sam_heads(self):
217
+ """Build SAM-style prompt encoder and mask decoder."""
218
+ self.sam_prompt_embed_dim = self.hidden_dim
219
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
220
+
221
+ # build PromptEncoder and MaskDecoder from SAM
222
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
223
+ self.sam_prompt_encoder = PromptEncoder(
224
+ embed_dim=self.sam_prompt_embed_dim,
225
+ image_embedding_size=(
226
+ self.sam_image_embedding_size,
227
+ self.sam_image_embedding_size,
228
+ ),
229
+ input_image_size=(self.image_size, self.image_size),
230
+ mask_in_chans=16,
231
+ )
232
+ self.sam_mask_decoder = MaskDecoder(
233
+ num_multimask_outputs=3,
234
+ transformer=TwoWayTransformer(
235
+ depth=2,
236
+ embedding_dim=self.sam_prompt_embed_dim,
237
+ mlp_dim=2048,
238
+ num_heads=8,
239
+ ),
240
+ transformer_dim=self.sam_prompt_embed_dim,
241
+ iou_head_depth=3,
242
+ iou_head_hidden_dim=256,
243
+ use_high_res_features=self.use_high_res_features_in_sam,
244
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
245
+ pred_obj_scores=self.pred_obj_scores,
246
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
247
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
248
+ **(self.sam_mask_decoder_extra_args or {}),
249
+ )
250
+ if self.use_obj_ptrs_in_encoder:
251
+ # a linear projection on SAM output tokens to turn them into object pointers
252
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
253
+ if self.use_mlp_for_obj_ptr_proj:
254
+ self.obj_ptr_proj = MLP(
255
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
256
+ )
257
+ else:
258
+ self.obj_ptr_proj = torch.nn.Identity()
259
+ if self.proj_tpos_enc_in_obj_ptrs:
260
+ # a linear projection on temporal positional encoding in object pointers to
261
+ # avoid potential interference with spatial positional encoding
262
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
263
+ else:
264
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
265
+
266
+ def _forward_sam_heads(
267
+ self,
268
+ backbone_features,
269
+ point_inputs=None,
270
+ mask_inputs=None,
271
+ high_res_features=None,
272
+ multimask_output=False,
273
+ audio_res=None
274
+ ):
275
+ """
276
+ Forward SAM prompt encoders and mask heads.
277
+
278
+ Inputs:
279
+ - backbone_features: image features of [B, C, H, W] shape
280
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
281
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
282
+ absolute pixel-unit coordinate in (x, y) format of the P input points
283
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
284
+ positive clicks, 0 means negative clicks, and -1 means padding
285
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
286
+ same spatial size as the image.
287
+ - high_res_features: either 1) None or 2) or a list of length 2 containing
288
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
289
+ which will be used as high-resolution feature maps for SAM decoder.
290
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
291
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
292
+ its corresponding IoU estimate.
293
+
294
+ Outputs:
295
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
296
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
297
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
298
+ the resolution (1/4 stride) of the input backbone_features.
299
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
300
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
301
+ upsampled from the low-resolution masks, with shape size as the image
302
+ (stride is 1 pixel).
303
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
304
+ if `multimask_output=False`), the estimated IoU of each output mask.
305
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
306
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
307
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
308
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
309
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
310
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
311
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
312
+ based on the output token from the SAM mask decoder.
313
+ """
314
+ B = backbone_features.size(0)
315
+ device = backbone_features.device
316
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
317
+ assert backbone_features.size(2) == self.sam_image_embedding_size
318
+ assert backbone_features.size(3) == self.sam_image_embedding_size
319
+
320
+ '''
321
+ # a) Handle point prompts
322
+ if point_inputs is not None:
323
+ sam_point_coords = point_inputs["point_coords"]
324
+ sam_point_labels = point_inputs["point_labels"]
325
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
326
+ raise NotImplementedError
327
+ else:
328
+ # If no points are provide, pad with an empty point (with label -1)
329
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
330
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
331
+
332
+ # b) Handle mask prompts
333
+ if mask_inputs is not None:
334
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
335
+ # and feed it as a dense mask prompt into the SAM mask encoder
336
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
337
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
338
+ sam_mask_prompt = F.interpolate(
339
+ mask_inputs.float(),
340
+ size=self.sam_prompt_encoder.mask_input_size,
341
+ align_corners=False,
342
+ mode="bilinear",
343
+ antialias=True, # use antialias for downsampling
344
+ )
345
+ else:
346
+ sam_mask_prompt = mask_inputs
347
+ raise NotImplementedError
348
+ else:
349
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
350
+ # a learned `no_mask_embed` to indicate no mask input in this case).
351
+ sam_mask_prompt = None
352
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
353
+ points=(sam_point_coords, sam_point_labels),
354
+ boxes=None,
355
+ masks=sam_mask_prompt,
356
+ )
357
+ '''
358
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
359
+ points=None,
360
+ boxes=None,
361
+ masks=None,
362
+ )
363
+
364
+ (
365
+ low_res_multimasks,
366
+ ious,
367
+ sam_output_tokens,
368
+ object_score_logits,
369
+ ) = self.sam_mask_decoder(
370
+ image_embeddings=backbone_features,
371
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
372
+ sparse_prompt_embeddings=sparse_embeddings,
373
+ dense_prompt_embeddings=dense_embeddings,
374
+ multimask_output=multimask_output,
375
+ repeat_image=False, # the image is already batched
376
+ high_res_features=high_res_features,
377
+ audio_res_features=audio_res
378
+ )
379
+ '''
380
+ if self.pred_obj_scores:
381
+ is_obj_appearing = object_score_logits > 0
382
+
383
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
384
+ # consistent with the actual mask prediction
385
+ low_res_multimasks = torch.where(
386
+ is_obj_appearing[:, None, None],
387
+ low_res_multimasks,
388
+ NO_OBJ_SCORE,
389
+ )
390
+ '''
391
+ # convert masks from possibly bfloat16 (or float16) to float32
392
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
393
+ low_res_multimasks = low_res_multimasks.float()
394
+ high_res_multimasks = F.interpolate(
395
+ low_res_multimasks,
396
+ size=(self.image_size, self.image_size),
397
+ mode="bilinear",
398
+ align_corners=False,
399
+ )
400
+ sam_output_token = sam_output_tokens[:, 0]
401
+ if multimask_output:
402
+ # comment this line temporarily.
403
+ # take the best mask prediction (with the highest IoU estimation)
404
+ best_iou_inds = torch.argmax(ious, dim=-1)
405
+ batch_inds = torch.arange(B, device=device)
406
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
407
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
408
+ if sam_output_tokens.size(1) > 1:
409
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
410
+ else:
411
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
412
+
413
+ # Extract object pointer from the SAM output token (with occlusion handling)
414
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
415
+
416
+ # don't train occlusion at the moment, command temporarily.
417
+ if self.pred_obj_scores:
418
+ is_obj_appearing = object_score_logits > 0
419
+ # Allow *soft* no obj ptr, unlike for masks
420
+ if self.soft_no_obj_ptr:
421
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
422
+ else:
423
+ lambda_is_obj_appearing = is_obj_appearing.float()
424
+
425
+ if self.fixed_no_obj_ptr:
426
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
427
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
428
+ return (
429
+ low_res_multimasks,
430
+ high_res_multimasks,
431
+ ious,
432
+ low_res_masks,
433
+ high_res_masks,
434
+ obj_ptr,
435
+ object_score_logits,
436
+ )
437
+
438
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
439
+ """
440
+ Directly turn binary `mask_inputs` into a output mask logits without using SAM.
441
+ (same input and output shapes as in _forward_sam_heads above).
442
+ """
443
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
444
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
445
+ mask_inputs_float = mask_inputs.float()
446
+ high_res_masks = mask_inputs_float * out_scale + out_bias
447
+ low_res_masks = F.interpolate(
448
+ high_res_masks,
449
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
450
+ align_corners=False,
451
+ mode="bilinear",
452
+ antialias=True, # use antialias for downsampling
453
+ )
454
+ # a dummy IoU prediction of all 1's under mask input
455
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
456
+ if not self.use_obj_ptrs_in_encoder:
457
+ # all zeros as a dummy object pointer (of shape [B, C])
458
+ obj_ptr = torch.zeros(
459
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
460
+ )
461
+ else:
462
+ # produce an object pointer using the SAM decoder from the mask input
463
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
464
+ backbone_features=backbone_features,
465
+ mask_inputs=self.mask_downsample(mask_inputs_float),
466
+ high_res_features=high_res_features,
467
+ )
468
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
469
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
470
+ # on the object_scores from the SAM decoder.
471
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
472
+ is_obj_appearing = is_obj_appearing[..., None]
473
+ lambda_is_obj_appearing = is_obj_appearing.float()
474
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
475
+ if self.pred_obj_scores:
476
+ if self.fixed_no_obj_ptr:
477
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
478
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
479
+
480
+ return (
481
+ low_res_masks,
482
+ high_res_masks,
483
+ ious,
484
+ low_res_masks,
485
+ high_res_masks,
486
+ obj_ptr,
487
+ object_score_logits,
488
+ )
489
+
490
+ def precompute_high_res_features(self, backbone_out):
491
+ if self.use_high_res_features_in_sam:
492
+ # precompute projected level 0 and level 1 features in SAM decoder
493
+ # to avoid running it again on every SAM click
494
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
495
+ backbone_out["backbone_fpn"][0]
496
+ )
497
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
498
+ backbone_out["backbone_fpn"][1]
499
+ )
500
+ return backbone_out
501
+
502
+ def forward_image(self, img_batch: torch.Tensor, pre_compute=True):
503
+ """Get the image feature on the input batch."""
504
+ backbone_out = self.image_encoder(img_batch)
505
+ return backbone_out if not pre_compute else self.precompute_high_res_features(backbone_out)
506
+
507
+ def _prepare_backbone_features(self, backbone_out):
508
+ """Prepare and flatten visual features."""
509
+ backbone_out = backbone_out.copy()
510
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
511
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
512
+
513
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
514
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
515
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
516
+ # flatten NxCxHxW to HWxNxC
517
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
518
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
519
+
520
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
521
+
522
+ def _prepare_memory_conditioned_features(
523
+ self,
524
+ frame_idx,
525
+ is_init_cond_frame,
526
+ current_vision_feats,
527
+ current_vision_pos_embeds,
528
+ feat_sizes,
529
+ output_dict,
530
+ num_frames,
531
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
532
+ ):
533
+ """Fuse the current frame's visual feature map with previous memory."""
534
+ B = current_vision_feats[-1].size(1) # batch size on this frame
535
+ C = self.hidden_dim
536
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
537
+ device = current_vision_feats[-1].device
538
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
539
+ # In this case, we skip the fusion with any memory.
540
+ if self.num_maskmem == 0: # Disable memory and skip fusion
541
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
542
+ return pix_feat
543
+
544
+ num_obj_ptr_tokens = 0
545
+ tpos_sign_mul = -1 if track_in_reverse else 1
546
+ # Step 1: condition the visual features of the current frame on previous memories
547
+ if not is_init_cond_frame:
548
+ # Retrieve the memories encoded with the maskmem backbone
549
+ to_cat_memory, to_cat_memory_pos_embed = [], []
550
+ # Add conditioning frames's output first (all cond frames have t_pos=0 for
551
+ # when getting temporal positional embedding below)
552
+ assert len(output_dict["cond_frame_outputs"]) > 0
553
+ # Select a maximum number of temporally closest cond frames for cross attention
554
+ cond_outputs = output_dict["cond_frame_outputs"]
555
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
556
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
557
+ )
558
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
559
+ # for t_pos in range(1, min(self.num_maskmem, frame_idx)):
560
+ # out = output_dict["non_cond_frame_outputs"].get(t_pos, None)
561
+ # t_pos_and_prevs.append((t_pos, out))
562
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
563
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
564
+ # We also allow taking the memory frame non-consecutively (with stride>1), in which case
565
+ # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
566
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval
567
+
568
+ for t_pos in range(1, self.num_maskmem):
569
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
570
+ if t_rel == 1:
571
+ # for t_rel == 1, we take the last frame (regardless of r)
572
+ if not track_in_reverse:
573
+ # the frame immediately before this frame (i.e. frame_idx - 1)
574
+ prev_frame_idx = frame_idx - t_rel
575
+ else:
576
+ # the frame immediately after this frame (i.e. frame_idx + 1)
577
+ prev_frame_idx = frame_idx + t_rel
578
+ else:
579
+ # for t_rel >= 2, we take the memory frame from every r-th frames
580
+ if not track_in_reverse:
581
+ # first find the nearest frame among every r-th frames before this frame
582
+ # for r=1, this would be (frame_idx - 2)
583
+ prev_frame_idx = ((frame_idx - 2) // stride) * stride
584
+ # then seek further among every r-th frames
585
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
586
+ else:
587
+ # first find the nearest frame among every r-th frames after this frame
588
+ # for r=1, this would be (frame_idx + 2)
589
+ prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
590
+ # then seek further among every r-th frames
591
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
592
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
593
+ if out is None:
594
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
595
+ # frames, we still attend to it as if it's a non-conditioning frame.
596
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
597
+ t_pos_and_prevs.append((t_pos, out))
598
+
599
+ for t_pos, prev in t_pos_and_prevs:
600
+ if prev is None:
601
+ continue # skip padding frames
602
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
603
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
604
+ feats = prev["maskmem_features"].to(device, non_blocking=True)
605
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
606
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
607
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
608
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
609
+ # Temporal positional encoding
610
+ maskmem_enc = (
611
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
612
+ )
613
+ to_cat_memory_pos_embed.append(maskmem_enc)
614
+ # Construct the list of past object pointers
615
+ if self.use_obj_ptrs_in_encoder:
616
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
617
+ # First add those object pointers from selected conditioning frames
618
+ # (optionally, only include object pointers in the past during evaluation)
619
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
620
+ ptr_cond_outputs = {
621
+ t: out
622
+ for t, out in selected_cond_outputs.items()
623
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
624
+ }
625
+ else:
626
+ ptr_cond_outputs = selected_cond_outputs
627
+ pos_and_ptrs = [
628
+ # Temporal pos encoding contains how far away each pointer is from current frame
629
+ (
630
+ (
631
+ (frame_idx - t) * tpos_sign_mul
632
+ if self.use_signed_tpos_enc_to_obj_ptrs
633
+ else abs(frame_idx - t)
634
+ ),
635
+ out["obj_ptr"],
636
+ )
637
+ for t, out in ptr_cond_outputs.items()
638
+ ]
639
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
640
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
641
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
642
+ if t < 0 or (num_frames is not None and t >= num_frames):
643
+ break
644
+ out = output_dict["non_cond_frame_outputs"].get(
645
+ t, unselected_cond_outputs.get(t, None)
646
+ )
647
+ if out is not None:
648
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
649
+ # If we have at least one object pointer, add them to the across attention
650
+ if len(pos_and_ptrs) > 0:
651
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
652
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
653
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
654
+ # a temporal positional embedding based on how far each object pointer is from
655
+ # the current frame (sine embedding normalized by the max pointer num).
656
+ # default false.
657
+ if self.add_tpos_enc_to_obj_ptrs:
658
+ t_diff_max = max_obj_ptrs_in_encoder - 1
659
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
660
+ obj_pos = torch.tensor(pos_list, device=device)
661
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
662
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
663
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
664
+ else:
665
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
666
+ if self.mem_dim < C:
667
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
668
+ obj_ptrs = obj_ptrs.reshape(
669
+ -1, B, C // self.mem_dim, self.mem_dim
670
+ )
671
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
672
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
673
+ to_cat_memory.append(obj_ptrs)
674
+ to_cat_memory_pos_embed.append(obj_pos)
675
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
676
+ else:
677
+ num_obj_ptr_tokens = 0
678
+ else:
679
+ # for initial conditioning frames, encode them without using any previous memory
680
+ if self.directly_add_no_mem_embed:
681
+ # directly add no-mem embedding (instead of using the transformer encoder)
682
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
683
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
684
+ return pix_feat_with_mem
685
+ # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
686
+ # the Following lines will never be triggered.
687
+ raise NotImplementedError
688
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
689
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
690
+
691
+ # Step 2: Concatenate the memories and forward through the transformer encoder
692
+ memory = torch.cat(to_cat_memory, dim=0)
693
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
694
+
695
+ pix_feat_with_mem = self.memory_attention(
696
+ curr=current_vision_feats,
697
+ curr_pos=current_vision_pos_embeds,
698
+ memory=memory,
699
+ memory_pos=memory_pos_embed,
700
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
701
+ )
702
+ # reshape the output (HW)BC => BCHW
703
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
704
+ return pix_feat_with_mem
705
+
706
+ def _encode_new_memory(
707
+ self,
708
+ current_vision_feats,
709
+ feat_sizes,
710
+ pred_masks_high_res,
711
+ object_score_logits,
712
+ is_mask_from_pts,
713
+ ):
714
+ """Encode the current image and its prediction into a memory feature."""
715
+ B = current_vision_feats[-1].size(1) # batch size on this frame
716
+ C = self.hidden_dim
717
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
718
+ # top-level feature, (HW)BC => BCHW
719
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
720
+ if self.non_overlap_masks_for_mem_enc and not self.training:
721
+ # optionally, apply non-overlapping constraints to the masks (it's applied
722
+ # in the batch dimension and should only be used during eval, where all
723
+ # the objects come from the same video under batch size 1).
724
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
725
+ pred_masks_high_res
726
+ )
727
+ raise NotImplementedError
728
+ # scale the raw mask logits with a temperature before applying sigmoid
729
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
730
+ if binarize and not self.training:
731
+ mask_for_mem = (pred_masks_high_res > 0).float()
732
+ else:
733
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
734
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
735
+ # apply scale and bias terms to the sigmoid probabilities
736
+ if self.sigmoid_scale_for_mem_enc != 1.0:
737
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
738
+ if self.sigmoid_bias_for_mem_enc != 0.0:
739
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
740
+ maskmem_out = self.memory_encoder(
741
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
742
+ )
743
+ maskmem_features = maskmem_out["vision_features"]
744
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
745
+ # add a no-object embedding to the spatial memory to indicate that the frame
746
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
747
+ if self.no_obj_embed_spatial is not None:
748
+ is_obj_appearing = (object_score_logits > 0).float()
749
+ maskmem_features += (
750
+ 1 - is_obj_appearing[..., None, None]
751
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
752
+ *maskmem_features.shape
753
+ )
754
+ # it will be used in sam2.1
755
+ # raise NotImplementedError
756
+
757
+ return maskmem_features, maskmem_pos_enc
758
+
759
+ def _track_step(
760
+ self,
761
+ frame_idx,
762
+ is_init_cond_frame,
763
+ current_vision_feats,
764
+ current_vision_pos_embeds,
765
+ feat_sizes,
766
+ point_inputs,
767
+ mask_inputs,
768
+ output_dict,
769
+ num_frames,
770
+ track_in_reverse,
771
+ prev_sam_mask_logits,
772
+ ):
773
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
774
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
775
+ if len(current_vision_feats) > 1:
776
+ high_res_features = [
777
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
778
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
779
+ ]
780
+ else:
781
+ high_res_features = None
782
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
783
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
784
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
785
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
786
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
787
+ sam_outputs = self._use_mask_as_output(
788
+ pix_feat, high_res_features, mask_inputs
789
+ )
790
+ else:
791
+ # fused the visual feature with previous memory features in the memory bank
792
+ pix_feat = self._prepare_memory_conditioned_features(
793
+ frame_idx=frame_idx,
794
+ is_init_cond_frame=is_init_cond_frame,
795
+ current_vision_feats=current_vision_feats[-1:],
796
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
797
+ feat_sizes=feat_sizes[-1:],
798
+ output_dict=output_dict,
799
+ num_frames=num_frames,
800
+ track_in_reverse=track_in_reverse,
801
+ )
802
+ # apply SAM-style segmentation head
803
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
804
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
805
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
806
+ if prev_sam_mask_logits is not None:
807
+ assert point_inputs is not None and mask_inputs is None
808
+ mask_inputs = prev_sam_mask_logits
809
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
810
+ sam_outputs = self._forward_sam_heads(
811
+ backbone_features=pix_feat,
812
+ point_inputs=point_inputs,
813
+ mask_inputs=mask_inputs,
814
+ high_res_features=high_res_features,
815
+ multimask_output=multimask_output,
816
+ )
817
+
818
+ return current_out, sam_outputs, high_res_features, pix_feat
819
+
820
+ def _encode_memory_in_output(
821
+ self,
822
+ current_vision_feats,
823
+ feat_sizes,
824
+ point_inputs,
825
+ run_mem_encoder,
826
+ high_res_masks,
827
+ object_score_logits,
828
+ current_out,
829
+ ):
830
+ if run_mem_encoder and self.num_maskmem > 0:
831
+ high_res_masks_for_mem_enc = high_res_masks
832
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
833
+ current_vision_feats=current_vision_feats,
834
+ feat_sizes=feat_sizes,
835
+ pred_masks_high_res=high_res_masks_for_mem_enc,
836
+ object_score_logits=object_score_logits,
837
+ is_mask_from_pts=(point_inputs is not None),
838
+ )
839
+ current_out["maskmem_features"] = maskmem_features
840
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
841
+ else:
842
+ current_out["maskmem_features"] = None
843
+ current_out["maskmem_pos_enc"] = None
844
+
845
+ def track_step(
846
+ self,
847
+ frame_idx,
848
+ is_init_cond_frame,
849
+ current_vision_feats,
850
+ current_vision_pos_embeds,
851
+ feat_sizes,
852
+ point_inputs,
853
+ mask_inputs,
854
+ output_dict,
855
+ num_frames,
856
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
857
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
858
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
859
+ # in demo we might call `track_step` multiple times for each user click,
860
+ # and only encode the memory when the user finalizes their clicks. And in ablation
861
+ # settings like SAM training on static images, we don't need the memory encoder.
862
+ run_mem_encoder=True,
863
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
864
+ prev_sam_mask_logits=None,
865
+ ):
866
+ current_out, sam_outputs, _, _ = self._track_step(
867
+ frame_idx,
868
+ is_init_cond_frame,
869
+ current_vision_feats,
870
+ current_vision_pos_embeds,
871
+ feat_sizes,
872
+ point_inputs,
873
+ mask_inputs,
874
+ output_dict,
875
+ num_frames,
876
+ track_in_reverse,
877
+ prev_sam_mask_logits,
878
+ )
879
+
880
+ (
881
+ _,
882
+ _,
883
+ _,
884
+ low_res_masks,
885
+ high_res_masks,
886
+ obj_ptr,
887
+ object_score_logits,
888
+ ) = sam_outputs
889
+
890
+ current_out["pred_masks"] = low_res_masks
891
+ current_out["pred_masks_high_res"] = high_res_masks
892
+ current_out["obj_ptr"] = obj_ptr
893
+ if not self.training:
894
+ # Only add this in inference (to avoid unused param in activation checkpointing;
895
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
896
+ current_out["object_score_logits"] = object_score_logits
897
+
898
+ # Finally run the memory encoder on the predicted mask to encode
899
+ # it into a new memory feature (that can be used in future frames)
900
+ self._encode_memory_in_output(
901
+ current_vision_feats,
902
+ feat_sizes,
903
+ point_inputs,
904
+ run_mem_encoder,
905
+ high_res_masks,
906
+ object_score_logits,
907
+ current_out,
908
+ )
909
+
910
+ return current_out
911
+
912
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
913
+ """Whether to use multimask output in the SAM head."""
914
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
915
+ multimask_output = (
916
+ self.multimask_output_in_sam
917
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
918
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
919
+ )
920
+ return multimask_output
921
+
922
+ def _apply_non_overlapping_constraints(self, pred_masks):
923
+ """
924
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
925
+ keep only the highest scoring object at each spatial location in pred_masks.
926
+ """
927
+ batch_size = pred_masks.size(0)
928
+ if batch_size == 1:
929
+ return pred_masks
930
+
931
+ device = pred_masks.device
932
+ # "max_obj_inds": object index of the object with the highest score at each location
933
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
934
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
935
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
936
+ keep = max_obj_inds == batch_obj_inds
937
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
938
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
939
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
940
+ return pred_masks
avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import copy
9
+ from typing import Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from model.visual.sam2.utils.misc import mask_to_box
17
+
18
+
19
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
20
+ """
21
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
22
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
23
+ - a) the closest conditioning frame before `frame_idx` (if any);
24
+ - b) the closest conditioning frame after `frame_idx` (if any);
25
+ - c) any other temporally closest conditioning frames until reaching a total
26
+ of `max_cond_frame_num` conditioning frames.
27
+
28
+ Outputs:
29
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
30
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
31
+ """
32
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
33
+ selected_outputs = cond_frame_outputs
34
+ unselected_outputs = {}
35
+ else:
36
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
37
+ selected_outputs = {}
38
+
39
+ # the closest conditioning frame before `frame_idx` (if any)
40
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
41
+ if idx_before is not None:
42
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
43
+
44
+ # the closest conditioning frame after `frame_idx` (if any)
45
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
46
+ if idx_after is not None:
47
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
48
+
49
+ # add other temporally closest conditioning frames until reaching a total
50
+ # of `max_cond_frame_num` conditioning frames.
51
+ num_remain = max_cond_frame_num - len(selected_outputs)
52
+ inds_remain = sorted(
53
+ (t for t in cond_frame_outputs if t not in selected_outputs),
54
+ key=lambda x: abs(x - frame_idx),
55
+ )[:num_remain]
56
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
57
+ unselected_outputs = {
58
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
59
+ }
60
+
61
+ return selected_outputs, unselected_outputs
62
+
63
+
64
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
65
+ """
66
+ Get 1D sine positional embedding as in the original Transformer paper.
67
+ """
68
+ pe_dim = dim // 2
69
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
70
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
71
+
72
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
73
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
74
+ return pos_embed
75
+
76
+
77
+ def get_activation_fn(activation):
78
+ """Return an activation function given a string"""
79
+ if activation == "relu":
80
+ return F.relu
81
+ if activation == "gelu":
82
+ return F.gelu
83
+ if activation == "glu":
84
+ return F.glu
85
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
86
+
87
+
88
+ def get_clones(module, N):
89
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
90
+
91
+
92
+ class DropPath(nn.Module):
93
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
94
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
95
+ super(DropPath, self).__init__()
96
+ self.drop_prob = drop_prob
97
+ self.scale_by_keep = scale_by_keep
98
+
99
+ def forward(self, x):
100
+ if self.drop_prob == 0.0 or not self.training:
101
+ return x
102
+ keep_prob = 1 - self.drop_prob
103
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
104
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
105
+ if keep_prob > 0.0 and self.scale_by_keep:
106
+ random_tensor.div_(keep_prob)
107
+ return x * random_tensor
108
+
109
+
110
+ # Lightly adapted from
111
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
112
+ class MLP(nn.Module):
113
+ def __init__(
114
+ self,
115
+ input_dim: int,
116
+ hidden_dim: int,
117
+ output_dim: int,
118
+ num_layers: int,
119
+ activation: nn.Module = nn.ReLU,
120
+ sigmoid_output: bool = False,
121
+ ) -> None:
122
+ super().__init__()
123
+ self.num_layers = num_layers
124
+ h = [hidden_dim] * (num_layers - 1)
125
+ self.layers = nn.ModuleList(
126
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
127
+ )
128
+ self.sigmoid_output = sigmoid_output
129
+ self.act = activation()
130
+
131
+ def forward(self, x):
132
+ for i, layer in enumerate(self.layers):
133
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
134
+ if self.sigmoid_output:
135
+ x = F.sigmoid(x)
136
+ return x
137
+
138
+
139
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
140
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
141
+ class LayerNorm2d(nn.Module):
142
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
143
+ super().__init__()
144
+ self.weight = nn.Parameter(torch.ones(num_channels))
145
+ self.bias = nn.Parameter(torch.zeros(num_channels))
146
+ self.eps = eps
147
+
148
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
149
+ u = x.mean(1, keepdim=True)
150
+ s = (x - u).pow(2).mean(1, keepdim=True)
151
+ x = (x - u) / torch.sqrt(s + self.eps)
152
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
153
+ return x
154
+
155
+
156
+ def sample_box_points(
157
+ masks: torch.Tensor,
158
+ noise: float = 0.1, # SAM default
159
+ noise_bound: int = 20, # SAM default
160
+ top_left_label: int = 2,
161
+ bottom_right_label: int = 3,
162
+ ) -> Tuple[np.array, np.array]:
163
+ """
164
+ Sample a noised version of the top left and bottom right corners of a given `bbox`
165
+
166
+ Inputs:
167
+ - masks: [B, 1, H,W] boxes, dtype=torch.Tensor
168
+ - noise: noise as a fraction of box width and height, dtype=float
169
+ - noise_bound: maximum amount of noise (in pure pixesl), dtype=int
170
+
171
+ Returns:
172
+ - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
173
+ - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
174
+ """
175
+ device = masks.device
176
+ box_coords = mask_to_box(masks)
177
+ B, _, H, W = masks.shape
178
+ box_labels = torch.tensor(
179
+ [top_left_label, bottom_right_label], dtype=torch.int, device=device
180
+ ).repeat(B)
181
+ if noise > 0.0:
182
+ if not isinstance(noise_bound, torch.Tensor):
183
+ noise_bound = torch.tensor(noise_bound, device=device)
184
+ bbox_w = box_coords[..., 2] - box_coords[..., 0]
185
+ bbox_h = box_coords[..., 3] - box_coords[..., 1]
186
+ max_dx = torch.min(bbox_w * noise, noise_bound)
187
+ max_dy = torch.min(bbox_h * noise, noise_bound)
188
+ box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
189
+ box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
190
+
191
+ box_coords = box_coords + box_noise
192
+ img_bounds = (
193
+ torch.tensor([W, H, W, H], device=device) - 1
194
+ ) # uncentered pixel coords
195
+ box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
196
+
197
+ box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
198
+ box_labels = box_labels.reshape(-1, 2)
199
+ return box_coords, box_labels
200
+
201
+
202
+ def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
203
+ """
204
+ Sample `num_pt` random points (along with their labels) independently from the error regions.
205
+
206
+ Inputs:
207
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
208
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
209
+ - num_pt: int, number of points to sample independently for each of the B error maps
210
+
211
+ Outputs:
212
+ - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
213
+ - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
214
+ negative clicks
215
+ """
216
+ if pred_masks is None: # if pred_masks is not provided, treat it as empty
217
+ pred_masks = torch.zeros_like(gt_masks)
218
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
219
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
220
+ assert num_pt >= 0
221
+
222
+ B, _, H_im, W_im = gt_masks.shape
223
+ device = gt_masks.device
224
+
225
+ # false positive region, a new point sampled in this region should have
226
+ # negative label to correct the FP error
227
+ fp_masks = ~gt_masks & pred_masks
228
+ # false negative region, a new point sampled in this region should have
229
+ # positive label to correct the FN error
230
+ fn_masks = gt_masks & ~pred_masks
231
+ # whether the prediction completely match the ground-truth on each mask
232
+ all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
233
+ all_correct = all_correct[..., None, None]
234
+
235
+ # channel 0 is FP map, while channel 1 is FN map
236
+ pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
237
+ # sample a negative new click from FP region or a positive new click
238
+ # from FN region, depend on where the maximum falls,
239
+ # and in case the predictions are all correct (no FP or FN), we just
240
+ # sample a negative click from the background region
241
+ pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
242
+ pts_noise[..., 1] *= fn_masks
243
+ pts_idx = pts_noise.flatten(2).argmax(dim=2)
244
+ labels = (pts_idx % 2).to(torch.int32)
245
+ pts_idx = pts_idx // 2
246
+ pts_x = pts_idx % W_im
247
+ pts_y = pts_idx // W_im
248
+ points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
249
+ return points, labels
250
+
251
+
252
+ def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
253
+ """
254
+ Sample 1 random point (along with its label) from the center of each error region,
255
+ that is, the point with the largest distance to the boundary of each error region.
256
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
257
+
258
+ Inputs:
259
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
260
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
261
+ - padding: if True, pad with boundary of 1 px for distance transform
262
+
263
+ Outputs:
264
+ - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
265
+ - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
266
+ """
267
+ import cv2
268
+
269
+ if pred_masks is None:
270
+ pred_masks = torch.zeros_like(gt_masks)
271
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
272
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
273
+
274
+ B, _, _, W_im = gt_masks.shape
275
+ device = gt_masks.device
276
+
277
+ # false positive region, a new point sampled in this region should have
278
+ # negative label to correct the FP error
279
+ fp_masks = ~gt_masks & pred_masks
280
+ # false negative region, a new point sampled in this region should have
281
+ # positive label to correct the FN error
282
+ fn_masks = gt_masks & ~pred_masks
283
+
284
+ fp_masks = fp_masks.cpu().numpy()
285
+ fn_masks = fn_masks.cpu().numpy()
286
+ points = torch.zeros(B, 1, 2, dtype=torch.float)
287
+ labels = torch.ones(B, 1, dtype=torch.int32)
288
+ for b in range(B):
289
+ fn_mask = fn_masks[b, 0]
290
+ fp_mask = fp_masks[b, 0]
291
+ if padding:
292
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
293
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
294
+ # compute the distance of each point in FN/FP region to its boundary
295
+ fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
296
+ fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
297
+ if padding:
298
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
299
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
300
+
301
+ # take the point in FN/FP region with the largest distance to its boundary
302
+ fn_mask_dt_flat = fn_mask_dt.reshape(-1)
303
+ fp_mask_dt_flat = fp_mask_dt.reshape(-1)
304
+ fn_argmax = np.argmax(fn_mask_dt_flat)
305
+ fp_argmax = np.argmax(fp_mask_dt_flat)
306
+ is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
307
+ pt_idx = fn_argmax if is_positive else fp_argmax
308
+ points[b, 0, 0] = pt_idx % W_im # x
309
+ points[b, 0, 1] = pt_idx // W_im # y
310
+ labels[b, 0] = int(is_positive)
311
+
312
+ points = points.to(device)
313
+ labels = labels.to(device)
314
+ return points, labels
315
+
316
+
317
+ def get_next_point(gt_masks, pred_masks, method):
318
+ if method == "uniform":
319
+ return sample_random_points_from_errors(gt_masks, pred_masks)
320
+ elif method == "center":
321
+ return sample_one_point_from_error_center(gt_masks, pred_masks)
322
+ else:
323
+ raise ValueError(f"unknown sampling method {method}")
avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed
12
+ from model.visual.sam2.modeling.sam2_base import SAM2Base
13
+ from model.visual.sam2.modeling.sam2_utils import (
14
+ get_1d_sine_pe,
15
+ get_next_point,
16
+ sample_box_points,
17
+ select_closest_cond_frames,
18
+ )
19
+
20
+ from utils.misc import concat_points
21
+
22
+ from utils.data_utils import BatchedVideoDatapoint
23
+
24
+
25
+ class SAM2Train(SAM2Base):
26
+ def __init__(
27
+ self,
28
+ image_encoder,
29
+ memory_attention=None,
30
+ memory_encoder=None,
31
+ prob_to_use_pt_input_for_train=0.0,
32
+ prob_to_use_pt_input_for_eval=0.0,
33
+ prob_to_use_box_input_for_train=0.0,
34
+ prob_to_use_box_input_for_eval=0.0,
35
+ # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames
36
+ num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
37
+ num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
38
+ rand_frames_to_correct_for_train=False,
39
+ rand_frames_to_correct_for_eval=False,
40
+ # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
41
+ # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
42
+ # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
43
+ # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
44
+ # these are initial conditioning frames because as we track the video, more conditioning frames might be added
45
+ # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
46
+ num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
47
+ num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
48
+ rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader)
49
+ rand_init_cond_frames_for_eval=False,
50
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
51
+ # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
52
+ add_all_frames_to_correct_as_cond=False,
53
+ # how many additional correction points to sample (on each frame selected to be corrected)
54
+ # note that the first frame receives an initial input click (in addition to any correction clicks)
55
+ num_correction_pt_per_frame=7,
56
+ # method for point sampling during evaluation
57
+ # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
58
+ # default to "center" to be consistent with evaluation in the SAM paper
59
+ pt_sampling_for_eval="center",
60
+ # During training, we optionally allow sampling the correction points from GT regions
61
+ # instead of the prediction error regions with a small probability. This might allow the
62
+ # model to overfit less to the error regions in training datasets
63
+ prob_to_sample_from_gt_for_train=0.0,
64
+ use_act_ckpt_iterative_pt_sampling=False,
65
+ # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
66
+ # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
67
+ forward_backbone_per_frame_for_eval=False,
68
+ freeze_image_encoder=False,
69
+ **kwargs,
70
+ ):
71
+ super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
72
+ self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
73
+ self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
74
+
75
+ # Point sampler and conditioning frames
76
+ self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
77
+ self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
78
+ self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
79
+ self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
80
+ if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
81
+ logging.info(
82
+ f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
83
+ )
84
+ assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
85
+ assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval
86
+
87
+ self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
88
+ self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
89
+ self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
90
+ self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
91
+ # Initial multi-conditioning frames
92
+ self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
93
+ self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
94
+ self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
95
+ self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
96
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
97
+ self.num_correction_pt_per_frame = num_correction_pt_per_frame
98
+ self.pt_sampling_for_eval = pt_sampling_for_eval
99
+ self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
100
+ # A random number generator with a fixed initial seed across GPUs
101
+ self.rng = np.random.default_rng(seed=42)
102
+ if freeze_image_encoder:
103
+ for p in self.image_encoder.parameters():
104
+ p.requires_grad = False
105
+
106
+
107
+ def forward(self, input: BatchedVideoDatapoint):
108
+ if self.training or not self.forward_backbone_per_frame_for_eval:
109
+ # precompute image features on all frames before tracking
110
+ backbone_out = self.forward_image(input.flat_img_batch)
111
+ else:
112
+ # defer image feature computation on a frame until it's being tracked
113
+ backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
114
+ backbone_out = self.prepare_prompt_inputs(backbone_out, input)
115
+ previous_stages_out = self.forward_tracking(backbone_out, input)
116
+
117
+ return previous_stages_out
118
+
119
+ def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
120
+ """Compute the image backbone features on the fly for the given img_ids."""
121
+ # Only forward backbone on unique image ids to avoid repetitive computation
122
+ # (if `img_ids` has only one element, it's already unique so we skip this step).
123
+ if img_ids.numel() > 1:
124
+ unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
125
+ else:
126
+ unique_img_ids, inv_ids = img_ids, None
127
+
128
+ # Compute the image features on those unique image ids
129
+ image = img_batch[unique_img_ids]
130
+ backbone_out = self.forward_image(image)
131
+ (
132
+ _,
133
+ vision_feats,
134
+ vision_pos_embeds,
135
+ feat_sizes,
136
+ ) = self._prepare_backbone_features(backbone_out)
137
+ '''
138
+ vision_feats
139
+ torch.Size([65536, 5, 32])
140
+ torch.Size([16384, 5, 64])
141
+ torch.Size([4096, 5, 256])
142
+ '''
143
+ # Inverse-map image features for `unique_img_ids` to the final image features
144
+ # for the original input `img_ids`.
145
+ if inv_ids is not None:
146
+ image = image[inv_ids]
147
+ vision_feats = [x[:, inv_ids] for x in vision_feats]
148
+ vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
149
+
150
+ return image, vision_feats, vision_pos_embeds, feat_sizes
151
+
152
+ @staticmethod
153
+ def dont_prepare_prompt_inputs(backbone_out, num_frames=5, cond_frame=0):
154
+ backbone_out["gt_masks_per_frame"] = {}
155
+ backbone_out["num_frames"] = num_frames
156
+ backbone_out["use_pt_input"] = False
157
+ # always start from the first frame.
158
+ backbone_out["init_cond_frames"] = [cond_frame]
159
+ backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != cond_frame]
160
+ # backbone_out["init_cond_frames"] = []
161
+ # backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)]
162
+
163
+ backbone_out["mask_inputs_per_frame"] = {}
164
+ backbone_out["point_inputs_per_frame"] = {}
165
+ backbone_out["frames_to_add_correction_pt"] = []
166
+ return backbone_out
167
+
168
+ def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
169
+ """
170
+ Prepare input mask, point or box prompts. Optionally, we allow tracking from
171
+ a custom `start_frame_idx` to the end of the video (for evaluation purposes).
172
+ """
173
+ # Load the ground-truth masks on all frames (so that we can later
174
+ # sample correction points from them)
175
+ # gt_masks_per_frame = {
176
+ # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im]
177
+ # for stage_id, targets in enumerate(input.find_targets)
178
+ # }
179
+ gt_masks_per_frame = {
180
+ stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im]
181
+ for stage_id, masks in enumerate(input.masks)
182
+ }
183
+ # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form
184
+ backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
185
+ num_frames = input.num_frames
186
+ backbone_out["num_frames"] = num_frames
187
+
188
+ # Randomly decide whether to use point inputs or mask inputs
189
+ if self.training:
190
+ prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
191
+ prob_to_use_box_input = self.prob_to_use_box_input_for_train
192
+ num_frames_to_correct = self.num_frames_to_correct_for_train
193
+ rand_frames_to_correct = self.rand_frames_to_correct_for_train
194
+ num_init_cond_frames = self.num_init_cond_frames_for_train
195
+ rand_init_cond_frames = self.rand_init_cond_frames_for_train
196
+ else:
197
+ prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
198
+ prob_to_use_box_input = self.prob_to_use_box_input_for_eval
199
+ num_frames_to_correct = self.num_frames_to_correct_for_eval
200
+ rand_frames_to_correct = self.rand_frames_to_correct_for_eval
201
+ num_init_cond_frames = self.num_init_cond_frames_for_eval
202
+ rand_init_cond_frames = self.rand_init_cond_frames_for_eval
203
+ if num_frames == 1:
204
+ # here we handle a special case for mixing video + SAM on image training,
205
+ # where we force using point input for the SAM task on static images
206
+ prob_to_use_pt_input = 1.0
207
+ num_frames_to_correct = 1
208
+ num_init_cond_frames = 1
209
+ assert num_init_cond_frames >= 1
210
+ # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
211
+ use_pt_input = self.rng.random() < prob_to_use_pt_input
212
+ if rand_init_cond_frames and num_init_cond_frames > 1:
213
+ # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
214
+ num_init_cond_frames = self.rng.integers(
215
+ 1, num_init_cond_frames, endpoint=True
216
+ )
217
+ if (
218
+ use_pt_input
219
+ and rand_frames_to_correct
220
+ and num_frames_to_correct > num_init_cond_frames
221
+ ):
222
+ # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
223
+ # correction clicks (only for the case of point input)
224
+ num_frames_to_correct = self.rng.integers(
225
+ num_init_cond_frames, num_frames_to_correct, endpoint=True
226
+ )
227
+ backbone_out["use_pt_input"] = use_pt_input
228
+
229
+ # Sample initial conditioning frames
230
+ if num_init_cond_frames == 1:
231
+ init_cond_frames = [start_frame_idx] # starting frame
232
+ else:
233
+ # starting frame + randomly selected remaining frames (without replacement)
234
+ init_cond_frames = [start_frame_idx] + self.rng.choice(
235
+ range(start_frame_idx + 1, num_frames),
236
+ num_init_cond_frames - 1,
237
+ replace=False,
238
+ ).tolist()
239
+ backbone_out["init_cond_frames"] = init_cond_frames
240
+ backbone_out["frames_not_in_init_cond"] = [
241
+ t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
242
+ ]
243
+ # Prepare mask or point inputs on initial conditioning frames
244
+ backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: <input_masks>}
245
+ backbone_out["point_inputs_per_frame"] = {} # {frame_idx: <input_points>}
246
+ for t in init_cond_frames:
247
+ if not use_pt_input:
248
+ backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
249
+ else:
250
+ # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
251
+ use_box_input = self.rng.random() < prob_to_use_box_input
252
+ if use_box_input:
253
+ points, labels = sample_box_points(
254
+ gt_masks_per_frame[t],
255
+ )
256
+ else:
257
+ # (here we only sample **one initial point** on initial conditioning frames from the
258
+ # ground-truth mask; we may sample more correction points on the fly)
259
+ points, labels = get_next_point(
260
+ gt_masks=gt_masks_per_frame[t],
261
+ pred_masks=None,
262
+ method=(
263
+ "uniform" if self.training else self.pt_sampling_for_eval
264
+ ),
265
+ )
266
+
267
+ point_inputs = {"point_coords": points, "point_labels": labels}
268
+ backbone_out["point_inputs_per_frame"][t] = point_inputs
269
+
270
+ # Sample frames where we will add correction clicks on the fly
271
+ # based on the error between prediction and ground-truth masks
272
+ if not use_pt_input:
273
+ # no correction points will be sampled when using mask inputs
274
+ frames_to_add_correction_pt = []
275
+ elif num_frames_to_correct == num_init_cond_frames:
276
+ frames_to_add_correction_pt = init_cond_frames
277
+ else:
278
+ assert num_frames_to_correct > num_init_cond_frames
279
+ # initial cond frame + randomly selected remaining frames (without replacement)
280
+ extra_num = num_frames_to_correct - num_init_cond_frames
281
+ frames_to_add_correction_pt = (
282
+ init_cond_frames
283
+ + self.rng.choice(
284
+ backbone_out["frames_not_in_init_cond"], extra_num, replace=False
285
+ ).tolist()
286
+ )
287
+ backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
288
+
289
+ return backbone_out
290
+
291
+ def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False):
292
+ # img_feats_already_computed = True.
293
+ """Forward video tracking on each frame (and sample correction clicks)."""
294
+ # Prepare the backbone features
295
+ # - vision_feats and vision_pos_embeds are in (HW)BC format
296
+ (
297
+ _,
298
+ vision_feats,
299
+ vision_pos_embeds,
300
+ feat_sizes,
301
+ ) = self._prepare_backbone_features(backbone_out)
302
+
303
+ # Starting the stage loop
304
+ num_frames = backbone_out["num_frames"]
305
+ init_cond_frames = backbone_out["init_cond_frames"]
306
+ frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
307
+ # first process all the initial conditioning frames to encode them as memory,
308
+ # and then conditioning on them to track the remaining frames
309
+ processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
310
+ output_dict = {
311
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
312
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
313
+ }
314
+
315
+ av_v_feats, av_a_feats = audio_res
316
+ for stage_id in processing_order:
317
+ # Get the image features for the current frames
318
+ img_ids = stage_id
319
+ # Retrieve image features according to img_ids (if they are already computed).
320
+ current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats] # add unsqueeze to maintain single sample.
321
+ current_vision_pos_embeds = [x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds] # add unsqueeze to maintain single sample.
322
+ current_av_v_feats = [x[img_ids] for x in av_v_feats]
323
+ current_av_a_feats = [x[img_ids] for x in av_a_feats]
324
+
325
+ # Get output masks based on this frame's prompts and previous memory
326
+ current_out = self.track_step_wo_prompt(
327
+ frame_idx=stage_id,
328
+ is_init_cond_frame=stage_id in init_cond_frames,
329
+ current_vision_feats=current_vision_feats,
330
+ current_vision_pos_embeds=current_vision_pos_embeds,
331
+ feat_sizes=feat_sizes,
332
+ point_inputs=None, # backbone_out["point_inputs_per_frame"].get(stage_id, None),
333
+ mask_inputs=None, # backbone_out["mask_inputs_per_frame"].get(stage_id, None),
334
+ gt_masks=None, # backbone_out["gt_masks_per_frame"].get(stage_id, None),
335
+ frames_to_add_correction_pt=None, # frames_to_add_correction_pt,
336
+ output_dict=output_dict,
337
+ num_frames=num_frames,
338
+ audio_res=(current_av_v_feats, current_av_a_feats),
339
+ )
340
+ # Append the output, depending on whether it's a conditioning frame
341
+ add_output_as_cond_frame = stage_id in init_cond_frames or (
342
+ self.add_all_frames_to_correct_as_cond
343
+ and stage_id in frames_to_add_correction_pt
344
+ )
345
+ if add_output_as_cond_frame:
346
+ output_dict["cond_frame_outputs"][stage_id] = current_out
347
+ else:
348
+ output_dict["non_cond_frame_outputs"][stage_id] = current_out
349
+
350
+ if return_dict:
351
+ return output_dict
352
+ # turn `output_dict` into a list for loss function
353
+ all_frame_outputs = {}
354
+ all_frame_outputs.update(output_dict["cond_frame_outputs"])
355
+ all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
356
+ all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
357
+ # Make DDP happy with activation checkpointing by removing unused keys
358
+ all_frame_outputs = [
359
+ {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
360
+ ]
361
+
362
+
363
+ return all_frame_outputs
364
+
365
+ def track_step_wo_prompt(
366
+ self,
367
+ frame_idx,
368
+ is_init_cond_frame,
369
+ current_vision_feats,
370
+ current_vision_pos_embeds,
371
+ feat_sizes,
372
+ point_inputs,
373
+ mask_inputs,
374
+ output_dict,
375
+ num_frames,
376
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
377
+ run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
378
+ prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
379
+ frames_to_add_correction_pt=None,
380
+ gt_masks=None,
381
+ audio_res=None,
382
+ ):
383
+ if frames_to_add_correction_pt is None:
384
+ frames_to_add_correction_pt = []
385
+
386
+ current_out, sam_outputs, high_res_features, pix_feat = self._track_step_wo_prompt(
387
+ frame_idx,
388
+ is_init_cond_frame,
389
+ current_vision_feats,
390
+ current_vision_pos_embeds,
391
+ feat_sizes,
392
+ point_inputs,
393
+ mask_inputs,
394
+ output_dict,
395
+ num_frames,
396
+ track_in_reverse,
397
+ prev_sam_mask_logits,
398
+ audio_res
399
+ )
400
+
401
+ (
402
+ low_res_multimasks,
403
+ high_res_multimasks,
404
+ ious,
405
+ low_res_masks,
406
+ high_res_masks,
407
+ obj_ptr,
408
+ object_score_logits,
409
+ ) = sam_outputs
410
+ current_out["multistep_pred_masks"] = low_res_masks
411
+ current_out["multistep_pred_masks_high_res"] = high_res_masks
412
+ current_out["multistep_pred_multimasks"] = [low_res_multimasks]
413
+ current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
414
+ current_out["multistep_pred_ious"] = [ious]
415
+ current_out["multistep_point_inputs"] = [point_inputs]
416
+ current_out["multistep_object_score_logits"] = [object_score_logits]
417
+
418
+ '''
419
+ # Optionally, sample correction points iteratively to correct the mask
420
+ if frame_idx in frames_to_add_correction_pt:
421
+ point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
422
+ is_init_cond_frame,
423
+ point_inputs,
424
+ gt_masks,
425
+ high_res_features,
426
+ pix_feat,
427
+ low_res_multimasks,
428
+ high_res_multimasks,
429
+ ious,
430
+ low_res_masks,
431
+ high_res_masks,
432
+ object_score_logits,
433
+ current_out,
434
+ )
435
+ (
436
+ _,
437
+ _,
438
+ _,
439
+ low_res_masks,
440
+ high_res_masks,
441
+ obj_ptr,
442
+ object_score_logits,
443
+ ) = final_sam_outputs
444
+ '''
445
+ # Use the final prediction (after all correction steps for output and eval)
446
+ current_out["pred_masks"] = low_res_masks
447
+ current_out["pred_masks_high_res"] = high_res_masks
448
+ current_out["obj_ptr"] = obj_ptr
449
+
450
+ # Finally run the memory encoder on the predicted mask to encode
451
+ # it into a new memory feature (that can be used in future frames)
452
+
453
+ self._encode_memory_in_output(
454
+ current_vision_feats,
455
+ feat_sizes,
456
+ 666., # point_inputs,
457
+ run_mem_encoder,
458
+ # we follow SAM2 predictor, if we have multiple masks output, we only utilise the first one to perform
459
+ # the memory rope attention.
460
+ high_res_masks,
461
+ object_score_logits,
462
+ current_out,
463
+ )
464
+ return current_out
465
+
466
+ def _track_step_wo_prompt(
467
+ self,
468
+ frame_idx,
469
+ is_init_cond_frame,
470
+ current_vision_feats,
471
+ current_vision_pos_embeds,
472
+ feat_sizes,
473
+ point_inputs,
474
+ mask_inputs,
475
+ output_dict,
476
+ num_frames,
477
+ track_in_reverse,
478
+ prev_sam_mask_logits,
479
+ audio_res=None
480
+ ):
481
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
482
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
483
+ if len(current_vision_feats) > 1:
484
+ high_res_features = [
485
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
486
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
487
+ ]
488
+ else:
489
+ high_res_features = None
490
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # False
491
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
492
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
493
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
494
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
495
+ sam_outputs = self._use_mask_as_output(
496
+ pix_feat, high_res_features, mask_inputs
497
+ )
498
+ else:
499
+ # fused the visual feature with previous memory features in the memory bank
500
+ pix_feat = self._prepare_memory_conditioned_features(
501
+ frame_idx=frame_idx,
502
+ is_init_cond_frame=is_init_cond_frame,
503
+ current_vision_feats=current_vision_feats[-1:],
504
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
505
+ feat_sizes=feat_sizes[-1:],
506
+ output_dict=output_dict,
507
+ num_frames=num_frames,
508
+ track_in_reverse=track_in_reverse,
509
+ )
510
+ # current_vision_feats[-1] = current_vision_feats[-1] + self.no_mem_embed
511
+ # pix_feat = current_vision_feats[-1].permute(1, 2, 0)
512
+ # pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
513
+
514
+ # we do not apply any prompts except audio.
515
+ '''
516
+ # apply SAM-style segmentation head
517
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
518
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
519
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
520
+ # if prev_sam_mask_logits is not None:
521
+ # assert point_inputs is not None and mask_inputs is None
522
+ # mask_inputs = prev_sam_mask_logits
523
+
524
+ ## comment this line, as we don't use points as prompts.
525
+ # multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
526
+ '''
527
+
528
+ sam_outputs = self._forward_sam_heads(
529
+ backbone_features=pix_feat,
530
+ point_inputs=point_inputs,
531
+ mask_inputs=mask_inputs,
532
+ high_res_features=high_res_features,
533
+ multimask_output=True,
534
+ audio_res=audio_res
535
+ )
536
+
537
+ return current_out, sam_outputs, high_res_features, pix_feat
538
+
539
+ def forward_tracking(
540
+ self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
541
+ ):
542
+ """Forward video tracking on each frame (and sample correction clicks)."""
543
+ img_feats_already_computed = backbone_out["backbone_fpn"] is not None
544
+ if img_feats_already_computed:
545
+ # Prepare the backbone features
546
+ # - vision_feats and vision_pos_embeds are in (HW)BC format
547
+ (
548
+ _,
549
+ vision_feats,
550
+ vision_pos_embeds,
551
+ feat_sizes,
552
+ ) = self._prepare_backbone_features(backbone_out)
553
+
554
+ # Starting the stage loop
555
+ num_frames = backbone_out["num_frames"]
556
+ init_cond_frames = backbone_out["init_cond_frames"]
557
+ frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
558
+ # first process all the initial conditioning frames to encode them as memory,
559
+ # and then conditioning on them to track the remaining frames
560
+ processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
561
+ output_dict = {
562
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
563
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
564
+ }
565
+ for stage_id in processing_order:
566
+ # Get the image features for the current frames
567
+ # img_ids = input.find_inputs[stage_id].img_ids
568
+ img_ids = input.flat_obj_to_img_idx[stage_id]
569
+ if img_feats_already_computed:
570
+ # Retrieve image features according to img_ids (if they are already computed).
571
+ current_vision_feats = [x[:, img_ids] for x in vision_feats]
572
+ current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
573
+ else:
574
+ # Otherwise, compute the image features on the fly for the given img_ids
575
+ # (this might be used for evaluation on long videos to avoid backbone OOM).
576
+ (
577
+ _,
578
+ current_vision_feats,
579
+ current_vision_pos_embeds,
580
+ feat_sizes,
581
+ ) = self._prepare_backbone_features_per_frame(
582
+ input.flat_img_batch, img_ids
583
+ )
584
+
585
+ # Get output masks based on this frame's prompts and previous memory
586
+ current_out = self.track_step(
587
+ frame_idx=stage_id,
588
+ is_init_cond_frame=stage_id in init_cond_frames,
589
+ current_vision_feats=current_vision_feats,
590
+ current_vision_pos_embeds=current_vision_pos_embeds,
591
+ feat_sizes=feat_sizes,
592
+ point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
593
+ mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
594
+ gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
595
+ frames_to_add_correction_pt=frames_to_add_correction_pt,
596
+ output_dict=output_dict,
597
+ num_frames=num_frames,
598
+ )
599
+ # Append the output, depending on whether it's a conditioning frame
600
+ add_output_as_cond_frame = stage_id in init_cond_frames or (
601
+ self.add_all_frames_to_correct_as_cond
602
+ and stage_id in frames_to_add_correction_pt
603
+ )
604
+ if add_output_as_cond_frame:
605
+ output_dict["cond_frame_outputs"][stage_id] = current_out
606
+ else:
607
+ output_dict["non_cond_frame_outputs"][stage_id] = current_out
608
+
609
+ if return_dict:
610
+ return output_dict
611
+ # turn `output_dict` into a list for loss function
612
+ all_frame_outputs = {}
613
+ all_frame_outputs.update(output_dict["cond_frame_outputs"])
614
+ all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
615
+ all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
616
+ # Make DDP happy with activation checkpointing by removing unused keys
617
+ all_frame_outputs = [
618
+ {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
619
+ ]
620
+
621
+ return all_frame_outputs
622
+
623
+ def track_step(
624
+ self,
625
+ frame_idx,
626
+ is_init_cond_frame,
627
+ current_vision_feats,
628
+ current_vision_pos_embeds,
629
+ feat_sizes,
630
+ point_inputs,
631
+ mask_inputs,
632
+ output_dict,
633
+ num_frames,
634
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
635
+ run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
636
+ prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
637
+ frames_to_add_correction_pt=None,
638
+ gt_masks=None,
639
+ ):
640
+ if frames_to_add_correction_pt is None:
641
+ frames_to_add_correction_pt = []
642
+ current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
643
+ frame_idx,
644
+ is_init_cond_frame,
645
+ current_vision_feats,
646
+ current_vision_pos_embeds,
647
+ feat_sizes,
648
+ point_inputs,
649
+ mask_inputs,
650
+ output_dict,
651
+ num_frames,
652
+ track_in_reverse,
653
+ prev_sam_mask_logits,
654
+ )
655
+
656
+ (
657
+ low_res_multimasks,
658
+ high_res_multimasks,
659
+ ious,
660
+ low_res_masks,
661
+ high_res_masks,
662
+ obj_ptr,
663
+ object_score_logits,
664
+ ) = sam_outputs
665
+
666
+ current_out["multistep_pred_masks"] = low_res_masks
667
+ current_out["multistep_pred_masks_high_res"] = high_res_masks
668
+ current_out["multistep_pred_multimasks"] = [low_res_multimasks]
669
+ current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
670
+ current_out["multistep_pred_ious"] = [ious]
671
+ current_out["multistep_point_inputs"] = [point_inputs]
672
+ current_out["multistep_object_score_logits"] = [object_score_logits]
673
+
674
+ # Optionally, sample correction points iteratively to correct the mask
675
+ if frame_idx in frames_to_add_correction_pt:
676
+ point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
677
+ is_init_cond_frame,
678
+ point_inputs,
679
+ gt_masks,
680
+ high_res_features,
681
+ pix_feat,
682
+ low_res_multimasks,
683
+ high_res_multimasks,
684
+ ious,
685
+ low_res_masks,
686
+ high_res_masks,
687
+ object_score_logits,
688
+ current_out,
689
+ )
690
+ (
691
+ _,
692
+ _,
693
+ _,
694
+ low_res_masks,
695
+ high_res_masks,
696
+ obj_ptr,
697
+ object_score_logits,
698
+ ) = final_sam_outputs
699
+
700
+ # Use the final prediction (after all correction steps for output and eval)
701
+ current_out["pred_masks"] = low_res_masks
702
+ current_out["pred_masks_high_res"] = high_res_masks
703
+ current_out["obj_ptr"] = obj_ptr
704
+
705
+ # Finally run the memory encoder on the predicted mask to encode
706
+ # it into a new memory feature (that can be used in future frames)
707
+ self._encode_memory_in_output(
708
+ current_vision_feats,
709
+ feat_sizes,
710
+ point_inputs,
711
+ run_mem_encoder,
712
+ high_res_masks,
713
+ object_score_logits,
714
+ current_out,
715
+ )
716
+ return current_out
717
+
718
+ def _iter_correct_pt_sampling(
719
+ self,
720
+ is_init_cond_frame,
721
+ point_inputs,
722
+ gt_masks,
723
+ high_res_features,
724
+ pix_feat_with_mem,
725
+ low_res_multimasks,
726
+ high_res_multimasks,
727
+ ious,
728
+ low_res_masks,
729
+ high_res_masks,
730
+ object_score_logits,
731
+ current_out,
732
+ ):
733
+
734
+ assert gt_masks is not None
735
+ all_pred_masks = [low_res_masks]
736
+ all_pred_high_res_masks = [high_res_masks]
737
+ all_pred_multimasks = [low_res_multimasks]
738
+ all_pred_high_res_multimasks = [high_res_multimasks]
739
+ all_pred_ious = [ious]
740
+ all_point_inputs = [point_inputs]
741
+ all_object_score_logits = [object_score_logits]
742
+ for _ in range(self.num_correction_pt_per_frame):
743
+ # sample a new point from the error between prediction and ground-truth
744
+ # (with a small probability, directly sample from GT masks instead of errors)
745
+ if self.training and self.prob_to_sample_from_gt_for_train > 0:
746
+ sample_from_gt = (
747
+ self.rng.random() < self.prob_to_sample_from_gt_for_train
748
+ )
749
+ else:
750
+ sample_from_gt = False
751
+ # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
752
+ pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
753
+ new_points, new_labels = get_next_point(
754
+ gt_masks=gt_masks,
755
+ pred_masks=pred_for_new_pt,
756
+ method="uniform" if self.training else self.pt_sampling_for_eval,
757
+ )
758
+ point_inputs = concat_points(point_inputs, new_points, new_labels)
759
+ # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
760
+ # For tracking, this means that when the user adds a correction click, we also feed
761
+ # the tracking output mask logits along with the click as input to the SAM decoder.
762
+ mask_inputs = low_res_masks
763
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
764
+ if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
765
+ sam_outputs = torch.utils.checkpoint.checkpoint(
766
+ self._forward_sam_heads,
767
+ backbone_features=pix_feat_with_mem,
768
+ point_inputs=point_inputs,
769
+ mask_inputs=mask_inputs,
770
+ high_res_features=high_res_features,
771
+ multimask_output=multimask_output,
772
+ use_reentrant=False,
773
+ )
774
+ else:
775
+ sam_outputs = self._forward_sam_heads(
776
+ backbone_features=pix_feat_with_mem,
777
+ point_inputs=point_inputs,
778
+ mask_inputs=mask_inputs,
779
+ high_res_features=high_res_features,
780
+ multimask_output=multimask_output,
781
+ )
782
+ (
783
+ low_res_multimasks,
784
+ high_res_multimasks,
785
+ ious,
786
+ low_res_masks,
787
+ high_res_masks,
788
+ _,
789
+ object_score_logits,
790
+ ) = sam_outputs
791
+ all_pred_masks.append(low_res_masks)
792
+ all_pred_high_res_masks.append(high_res_masks)
793
+ all_pred_multimasks.append(low_res_multimasks)
794
+ all_pred_high_res_multimasks.append(high_res_multimasks)
795
+ all_pred_ious.append(ious)
796
+ all_point_inputs.append(point_inputs)
797
+ all_object_score_logits.append(object_score_logits)
798
+
799
+ # Concatenate the masks along channel (to compute losses on all of them,
800
+ # using `MultiStepIteractiveMasks`)
801
+ current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
802
+ current_out["multistep_pred_masks_high_res"] = torch.cat(
803
+ all_pred_high_res_masks, dim=1
804
+ )
805
+ current_out["multistep_pred_multimasks"] = all_pred_multimasks
806
+ current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
807
+ current_out["multistep_pred_ious"] = all_pred_ious
808
+ current_out["multistep_point_inputs"] = all_point_inputs
809
+ current_out["multistep_object_score_logits"] = all_object_score_logits
810
+
811
+ return point_inputs, sam_outputs
avs.code/v1m.code/model/visual/sam2/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.