Upload folder using huggingface_hub
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +4 -0
- LICENSE +21 -0
- README.md +26 -3
- avs.code/v1m.code/configs/__init__.py +0 -0
- avs.code/v1m.code/configs/auralfuser/architecture.yaml +30 -0
- avs.code/v1m.code/configs/config.py +85 -0
- avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml +114 -0
- avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml +117 -0
- avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml +116 -0
- avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml +118 -0
- avs.code/v1m.code/configs/training/sam2_training_config.yaml +62 -0
- avs.code/v1m.code/dataloader/audio/audio_augmentation.py +23 -0
- avs.code/v1m.code/dataloader/audio/audio_dataset.py +38 -0
- avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py +223 -0
- avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py +98 -0
- avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py +53 -0
- avs.code/v1m.code/dataloader/dataset.py +67 -0
- avs.code/v1m.code/dataloader/sam2_dataset/__init__.py +5 -0
- avs.code/v1m.code/dataloader/sam2_dataset/transforms.py +528 -0
- avs.code/v1m.code/dataloader/visual/visual_augmentation.py +140 -0
- avs.code/v1m.code/dataloader/visual/visual_dataset.py +127 -0
- avs.code/v1m.code/inference.py +193 -0
- avs.code/v1m.code/loss/training/__init__.py +2 -0
- avs.code/v1m.code/loss/training/contrastive_learning.py +201 -0
- avs.code/v1m.code/loss/training/sam2_training_loss.py +220 -0
- avs.code/v1m.code/main.py +166 -0
- avs.code/v1m.code/model/audio/torchvggish/mel_features.py +223 -0
- avs.code/v1m.code/model/audio/torchvggish/vggish.py +193 -0
- avs.code/v1m.code/model/audio/torchvggish/vggish_input.py +98 -0
- avs.code/v1m.code/model/audio/torchvggish/vggish_params.py +53 -0
- avs.code/v1m.code/model/aural_fuser.py +567 -0
- avs.code/v1m.code/model/mymodel.py +102 -0
- avs.code/v1m.code/model/visual/sam2/__init__.py +11 -0
- avs.code/v1m.code/model/visual/sam2/build_sam.py +171 -0
- avs.code/v1m.code/model/visual/sam2/modeling/__init__.py +5 -0
- avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py +5 -0
- avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py +317 -0
- avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py +134 -0
- avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py +95 -0
- avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py +169 -0
- avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py +181 -0
- avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py +221 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py +5 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py +300 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py +188 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py +367 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py +940 -0
- avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py +323 -0
- avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py +811 -0
- avs.code/v1m.code/model/visual/sam2/utils/__init__.py +5 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ckpts/avs/v1s/nohup.out filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ckpts/avs/v2/nohup.out filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ckpts/ref-avs/nohup.out filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
docs/overview.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Yuyuan Liu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,26 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AuralSAM2
|
| 2 |
+
> **[CVPRF'26]** [AuralSAM2: Enabling SAM2 Hear
|
| 3 |
+
Through Pyramid Audio-Visual Feature Prompting](#)
|
| 4 |
+
>
|
| 5 |
+
> by Yuyuan Liu, Yuanhong Chen, Chong Wang, Junlin Han, Junde Wu, Can Peng, Jingkun Chen, Yu Tian and Gustavo Carneiro
|
| 6 |
+
>
|
| 7 |
+
<img src="./docs/overview.png" width="850" height="300" />
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
Please install the dependencies and prepare the dataset by following this [***installation***](./docs/installation.md) document.
|
| 11 |
+
|
| 12 |
+
## Getting started
|
| 13 |
+
Please follow this [***instructions***](./docs/before_start.md) document to reproduce our results.
|
| 14 |
+
|
| 15 |
+
## Citation
|
| 16 |
+
Please consider citing our work in your publications if it helps your research.
|
| 17 |
+
|
| 18 |
+
```bibtex
|
| 19 |
+
@article{liu2025auralsam2,
|
| 20 |
+
title={AuralSAM2: Enabling SAM2 Hear Through Pyramid Audio-Visual Feature Prompting},
|
| 21 |
+
author={Liu, Yuyuan and Chen, Yuanhong and Wang, Chong and Han, Junlin and Wu, Junde and Peng, Can and Chen, Jingkun and Tian, Yu and Carneiro, Gustavo},
|
| 22 |
+
journal={arXiv preprint arXiv:2506.01015},
|
| 23 |
+
year={2025}
|
| 24 |
+
}
|
| 25 |
+
```
|
| 26 |
+
|
avs.code/v1m.code/configs/__init__.py
ADDED
|
File without changes
|
avs.code/v1m.code/configs/auralfuser/architecture.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
aural_fuser:
|
| 4 |
+
patch_cfgs:
|
| 5 |
+
- [4, 4]
|
| 6 |
+
- [2, 2]
|
| 7 |
+
- [1, 1]
|
| 8 |
+
f_depths: [3, 6, 12]
|
| 9 |
+
block_kw:
|
| 10 |
+
dim: 256
|
| 11 |
+
num_heads: 4
|
| 12 |
+
mlp_ratio: 4
|
| 13 |
+
qkv_bias: true
|
| 14 |
+
qk_scale: null
|
| 15 |
+
drop: 0.1
|
| 16 |
+
attn_drop: 0.1
|
| 17 |
+
drop_path: 0.0
|
| 18 |
+
sr_ratio: 4
|
| 19 |
+
linear: false
|
| 20 |
+
one_d_kw:
|
| 21 |
+
dim: 256
|
| 22 |
+
num_heads: 4
|
| 23 |
+
mlp_ratio: 4
|
| 24 |
+
qkv_bias: true
|
| 25 |
+
qk_scale: null
|
| 26 |
+
drop: 0.1
|
| 27 |
+
attn_drop: 0.1
|
| 28 |
+
drop_path: 0.0
|
| 29 |
+
sr_ratio: 4
|
| 30 |
+
linear: false
|
avs.code/v1m.code/configs/config.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
# v1m.code package root (parent of this `configs/` directory)
|
| 6 |
+
_CODE_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
| 7 |
+
_WORKSPACE_ROOT = os.path.dirname(os.path.dirname(_CODE_ROOT))
|
| 8 |
+
|
| 9 |
+
C = EasyDict()
|
| 10 |
+
config = C
|
| 11 |
+
cfg = C
|
| 12 |
+
|
| 13 |
+
C.seed = 666
|
| 14 |
+
|
| 15 |
+
C.audio = EasyDict()
|
| 16 |
+
C.audio.FREEZE_AUDIO_EXTRACTOR = True
|
| 17 |
+
C.audio.PRETRAINED_VGGISH_MODEL_PATH = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'vggish-10086976.pth')
|
| 18 |
+
C.audio.PREPROCESS_AUDIO_TO_LOG_MEL = False
|
| 19 |
+
C.audio.POSTPROCESS_LOG_MEL_WITH_PCA = False
|
| 20 |
+
C.train_vggish = False
|
| 21 |
+
|
| 22 |
+
"""Root Directory Config"""
|
| 23 |
+
C.repo_name = 'AV'
|
| 24 |
+
C.root_dir = _CODE_ROOT
|
| 25 |
+
|
| 26 |
+
"""Data Dir and Weight Dir"""
|
| 27 |
+
C.data_root_path = os.path.join(_WORKSPACE_ROOT, 'AVSBench')
|
| 28 |
+
C.data_name = 'v1m'
|
| 29 |
+
|
| 30 |
+
C.backbone_weight = os.path.join(_WORKSPACE_ROOT, 'ckpts', 'sam_ckpts', 'sam2_hiera_large.pt')
|
| 31 |
+
C.sam_config_path = os.path.join('sam2', 'sam2_hiera_l.yaml')
|
| 32 |
+
|
| 33 |
+
"""Network Config"""
|
| 34 |
+
C.fix_bias = True
|
| 35 |
+
C.bn_eps = 1e-5
|
| 36 |
+
C.bn_momentum = 0.1
|
| 37 |
+
|
| 38 |
+
"""Image Config"""
|
| 39 |
+
C.num_classes = 2
|
| 40 |
+
|
| 41 |
+
C.image_mean = numpy.array([0.485, 0.456, 0.406])
|
| 42 |
+
C.image_std = numpy.array([0.229, 0.224, 0.225])
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
C.image_size = 1024
|
| 46 |
+
C.image_embedding_size = int(C.image_size / 16)
|
| 47 |
+
C.avsbench_size = (224, 224)
|
| 48 |
+
|
| 49 |
+
C.scale_list = [.5, .75, 1., 1.25, 1.5]
|
| 50 |
+
C.ignore_index = 255
|
| 51 |
+
|
| 52 |
+
"""Train Config"""
|
| 53 |
+
C.lr = 7.5e-5
|
| 54 |
+
C.batch_size = 8
|
| 55 |
+
C.energy_weight = .05
|
| 56 |
+
|
| 57 |
+
C.lr_power = 0.9
|
| 58 |
+
C.momentum = 0.9
|
| 59 |
+
C.weight_decay = 0.05
|
| 60 |
+
|
| 61 |
+
C.num_workers = 4
|
| 62 |
+
|
| 63 |
+
"""Display Config"""
|
| 64 |
+
C.record_info_iter = 20
|
| 65 |
+
C.display_iter = 50
|
| 66 |
+
|
| 67 |
+
"""Wandb Config"""
|
| 68 |
+
# Paste your W&B API key here, or set the WANDB_API_KEY environment variable instead.
|
| 69 |
+
C.wandb_key = ""
|
| 70 |
+
|
| 71 |
+
# Your project [work_space] name
|
| 72 |
+
C.proj_name = "AVS-final-report"
|
| 73 |
+
|
| 74 |
+
C.experiment_name = "v1s-hiera-l"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# False = no wandb logging (see utils/tensorboard.py)
|
| 78 |
+
C.wandb_online = False
|
| 79 |
+
|
| 80 |
+
"""Save Config"""
|
| 81 |
+
C.saved_dir = os.path.join(_WORKSPACE_ROOT, 'ckpts', C.experiment_name)
|
| 82 |
+
|
| 83 |
+
import pathlib
|
| 84 |
+
|
| 85 |
+
pathlib.Path(C.saved_dir).mkdir(parents=True, exist_ok=True)
|
avs.code/v1m.code/configs/sam2/sam2_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: model.visual.sam2.organised_sam2_train.SAM2Train
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [32, 32]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [32, 32]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: model.visual.sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: model.visual.sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 93 |
+
use_high_res_features_in_sam: true
|
| 94 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 95 |
+
multimask_output_in_sam: true
|
| 96 |
+
# SAM heads
|
| 97 |
+
iou_prediction_use_sigmoid: True
|
| 98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 99 |
+
use_obj_ptrs_in_encoder: true
|
| 100 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 102 |
+
# object occlusion prediction
|
| 103 |
+
pred_obj_scores: true
|
| 104 |
+
pred_obj_scores_mlp: true
|
| 105 |
+
fixed_no_obj_ptr: true
|
| 106 |
+
# multimask tracking settings
|
| 107 |
+
multimask_output_for_tracking: true
|
| 108 |
+
use_multimask_token_for_obj_ptr: true
|
| 109 |
+
multimask_min_pt_num: 0
|
| 110 |
+
multimask_max_pt_num: 1
|
| 111 |
+
use_mlp_for_obj_ptr_proj: true
|
| 112 |
+
# Compilation flag
|
| 113 |
+
compile_image_encoder: False
|
| 114 |
+
|
avs.code/v1m.code/configs/sam2/sam2_hiera_l.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: model.visual.sam2.organised_sam2_train.SAM2Train
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [32, 32]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [32, 32]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: model.visual.sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: model.visual.sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
compile_image_encoder: False
|
avs.code/v1m.code/configs/sam2/sam2_hiera_s.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [32, 32]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [32, 32]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 96 |
+
use_high_res_features_in_sam: true
|
| 97 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 98 |
+
multimask_output_in_sam: true
|
| 99 |
+
# SAM heads
|
| 100 |
+
iou_prediction_use_sigmoid: True
|
| 101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 102 |
+
use_obj_ptrs_in_encoder: true
|
| 103 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
avs.code/v1m.code/configs/sam2/sam2_hiera_t.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: model.visual.sam2.organised_sam2_train.SAM2Train
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: model.visual.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: model.visual.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: model.visual.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [32, 32]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: model.visual.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [32, 32]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: model.visual.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: model.visual.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: model.visual.sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: model.visual.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 224 # 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: false
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 118 |
+
compile_image_encoder: False
|
avs.code/v1m.code/configs/training/sam2_training_config.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Video transforms
|
| 4 |
+
|
| 5 |
+
train_transforms:
|
| 6 |
+
- _target_: dataloader.sam2_dataset.transforms.ComposeAPI
|
| 7 |
+
transforms:
|
| 8 |
+
- _target_: dataloader.sam2_dataset.transforms.RandomHorizontalFlip
|
| 9 |
+
consistent_transform: True
|
| 10 |
+
- _target_: dataloader.sam2_dataset.transforms.RandomAffine
|
| 11 |
+
degrees: 25
|
| 12 |
+
shear: 20
|
| 13 |
+
image_interpolation: bilinear
|
| 14 |
+
consistent_transform: True
|
| 15 |
+
- _target_: dataloader.sam2_dataset.transforms.RandomResizeAPI
|
| 16 |
+
sizes: 1024 # ${scratch.resolution}
|
| 17 |
+
square: true
|
| 18 |
+
consistent_transform: True
|
| 19 |
+
- _target_: dataloader.sam2_dataset.transforms.ColorJitter
|
| 20 |
+
consistent_transform: True
|
| 21 |
+
brightness: 0.1
|
| 22 |
+
contrast: 0.03
|
| 23 |
+
saturation: 0.03
|
| 24 |
+
hue: null
|
| 25 |
+
- _target_: dataloader.sam2_dataset.transforms.RandomGrayscale
|
| 26 |
+
p: 0.05
|
| 27 |
+
consistent_transform: True
|
| 28 |
+
- _target_: dataloader.sam2_dataset.transforms.ColorJitter
|
| 29 |
+
consistent_transform: False
|
| 30 |
+
brightness: 0.1
|
| 31 |
+
contrast: 0.05
|
| 32 |
+
saturation: 0.05
|
| 33 |
+
hue: null
|
| 34 |
+
- _target_: dataloader.sam2_dataset.transforms.ToTensorAPI
|
| 35 |
+
- _target_: dataloader.sam2_dataset.transforms.NormalizeAPI
|
| 36 |
+
mean: [0.485, 0.456, 0.406]
|
| 37 |
+
std: [0.229, 0.224, 0.225]
|
| 38 |
+
|
| 39 |
+
loss:
|
| 40 |
+
all:
|
| 41 |
+
_target_: loss.training.sam2_training_loss.MultiStepMultiMasksAndIous
|
| 42 |
+
weight_dict:
|
| 43 |
+
loss_mask: 20 # 20
|
| 44 |
+
loss_dice: 1
|
| 45 |
+
loss_iou: 1
|
| 46 |
+
loss_class: 1
|
| 47 |
+
supervise_all_iou: true
|
| 48 |
+
iou_use_l1_loss: true
|
| 49 |
+
pred_obj_scores: true
|
| 50 |
+
focal_gamma_obj_score: 0.0
|
| 51 |
+
focal_alpha_obj_score: -1.0
|
| 52 |
+
gpu_num: 4.
|
| 53 |
+
|
| 54 |
+
# Contrastive loss (ContrastLoss); loaded in main.py / inference.py → hyp_param.contrastive_learning
|
| 55 |
+
contrastive_learning:
|
| 56 |
+
temperature: 0.10
|
| 57 |
+
ignore_idx: 255
|
| 58 |
+
ood_idx: 254
|
| 59 |
+
max_views: 512
|
| 60 |
+
proj_dim: 512
|
| 61 |
+
sample_limits: 128
|
| 62 |
+
total_limits: 15240
|
avs.code/v1m.code/dataloader/audio/audio_augmentation.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Augmentation(object):
    """Waveform pre-processing shared by the training and inference paths.

    Both paths do exactly one thing: divide the incoming samples by 32768,
    which maps int16 PCM values to floats in [-1, 1). No random augmentation
    is applied for either split.
    """

    def __init__(self, mono=True):
        # Kept as an attribute for callers; the methods below do not read it.
        self.mono = mono

    def train_aug(self, x_, sr_):
        """Rescale a training waveform; `sr_` is accepted but not used."""
        scaled = x_ / 32768.0
        return scaled

    def test_process(self, x_):
        """Rescale a validation/test waveform."""
        scaled = x_ / 32768.0
        return scaled

    def __call__(self, x, sr, split):
        """Dispatch on `split`: "train" uses `train_aug`, anything else `test_process`."""
        if split == "train":
            return self.train_aug(x, sr)
        return self.test_process(x)
|
avs.code/v1m.code/dataloader/audio/audio_dataset.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy
|
| 3 |
+
import os
|
| 4 |
+
from dataloader.audio.preprocess_vgg.vggish_input import waveform_to_examples
|
| 5 |
+
import soundfile
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Audio(torch.utils.data.Dataset):
    """Loads `audio.wav` clips and turns them into two-channel log-mel examples.

    Clips are addressed by directory path through `load_audio_wave`; this
    class does not build a sample index of its own.
    """

    def __init__(self, augmentation, directory_path, split):
        # `augmentation` is called as augmentation(waveform, sample_rate, split).
        self.augmentation = augmentation
        self.directory_path = directory_path
        self.split = split

    @staticmethod
    def _split_channels(wav_data):
        """Return the (first, second) channel waveforms of `wav_data`.

        A mono clip (1-D array, or 2-D with a single column) is duplicated so
        the caller always gets two channels; for clips with two or more
        channels the first two columns are used.
        """
        if wav_data.ndim == 1:
            return wav_data, wav_data
        if wav_data.shape[1] == 1:
            return wav_data[:, 0], wav_data[:, 0]
        return wav_data[:, 0], wav_data[:, 1]

    def load_audio_wave(self, file_index, file_index_mix):
        """Read `audio.wav` under `file_index` and convert it to log-mel examples.

        Args:
            file_index: Directory that contains the primary `audio.wav`.
            file_index_mix: Directory of a second `audio.wav` to blend with the
                first one, or None for no mixing.

        Returns:
            Tensor holding the per-channel examples concatenated along dim 1,
            padded along dim 0 to at least 5 entries by repeating the last one.
        """
        audio_path = os.path.join(file_index, 'audio.wav')
        wav_data, sample_rate = soundfile.read(audio_path, dtype='int16')
        assert wav_data.dtype == numpy.int16, 'Bad sample type: %r' % wav_data.dtype

        if file_index_mix is not None:
            audio_path2 = os.path.join(file_index_mix, 'audio.wav')
            wav_data2, _ = soundfile.read(audio_path2, dtype='int16')
            # Blend weight drawn from Beta(10, 10); both clips are truncated
            # to the shorter length before the weighted sum.
            mix_lambda = numpy.random.beta(10, 10)
            min_length = min(wav_data.shape[0], wav_data2.shape[0])
            wav_data = wav_data[:min_length] * mix_lambda + wav_data2[:min_length] * (1-mix_lambda)

        wav_data = self.augmentation(wav_data, sample_rate, self.split)
        # Mono files used to crash on `wav_data[:, 1]`; duplicate the channel instead.
        first_channel, second_channel = self._split_channels(wav_data)
        audio_log_mel = torch.cat([waveform_to_examples(first_channel, sample_rate, True).detach(),
                                   waveform_to_examples(second_channel, sample_rate, True).detach()], dim=1)

        # for the vgg preprocess, we will need 5 seconds audio log.
        if audio_log_mel.shape[0] < 5:
            audio_log_mel = torch.cat([audio_log_mel,
                                       audio_log_mel[-1].unsqueeze(0).repeat(5-audio_log_mel.shape[0], 1, 1, 1)])
        return audio_log_mel

    def __len__(self):
        """Return the size of `audio_list` when one has been attached.

        Raises:
            TypeError: if no `audio_list` attribute was set; this class never
                sets one itself, so the length is otherwise undefined.
        """
        audio_list = getattr(self, 'audio_list', None)
        if audio_list is None:
            raise TypeError(
                'Audio has no length: clips are loaded by path via load_audio_wave '
                'and no audio_list has been set.')
        return len(audio_list)
|
avs.code/v1m.code/dataloader/audio/preprocess_vgg/mel_features.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Defines routines to compute mel spectrogram features from audio waveform."""
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def frame(data, window_length, hop_length):
    """Convert array into a sequence of successive possibly overlapping frames.

    An n-dimensional array of shape (num_samples, ...) is converted into an
    (n+1)-D array of shape (num_frames, window_length, ...), where each frame
    starts hop_length points after the preceding one.

    This is accomplished using stride_tricks, so the original data is not
    copied. However, there is no zero-padding, so any incomplete frames at the
    end are not included. Input shorter than one window yields zero frames.

    Args:
        data: np.array of dimension N >= 1.
        window_length: Number of samples in each frame.
        hop_length: Advance (in samples) between each window.

    Returns:
        (N+1)-D np.array with as many rows as there are complete frames that can
        be extracted.
    """
    num_samples = data.shape[0]
    # When the input is much shorter than a window the formula goes negative,
    # which as_strided rejects; clamp so the result is simply an empty set of frames.
    num_frames = max(
        0, 1 + int(np.floor((num_samples - window_length) / hop_length)))
    shape = (num_frames, window_length) + data.shape[1:]
    strides = (data.strides[0] * hop_length,) + data.strides
    return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def periodic_hann(window_length):
    """Build a "periodic" Hann window of the requested length.

    np.hanning() produces the symmetric Hann window, a raised cosine whose
    first and last samples are both zero, i.e. slightly more than one cycle
    of a period N-1 cosine. For Fourier analysis a full cycle of a period-N
    cosine (which stops one sample short of the closing zero) is the better
    fit; Matlab names that variant "periodic", and it is what this returns.

    Args:
        window_length: The number of points in the returned window.

    Returns:
        A 1D np.array containing the periodic hann window.
    """
    angular_step = 2 * np.pi / window_length
    sample_phase = angular_step * np.arange(window_length)
    return 0.5 - (0.5 * np.cos(sample_phase))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def stft_magnitude(signal, fft_length,
                   hop_length=None,
                   window_length=None):
    """Compute the magnitude of the short-time Fourier transform.

    Args:
        signal: 1D np.array of the input time-domain signal.
        fft_length: Size of the FFT to apply.
        hop_length: Advance (in samples) between each frame passed to FFT.
        window_length: Length of each block of samples to pass to FFT.

    Returns:
        2D np.array where each row contains the magnitudes of the
        fft_length/2+1 unique values of the FFT for the corresponding frame
        of input samples.
    """
    sliced = frame(signal, window_length, hop_length)
    # Taper every frame with a periodic Hann (cosine of period window_length)
    # rather than np.hanning's symmetric one (period window_length-1).
    taper = periodic_hann(window_length)
    spectrum = np.fft.rfft(sliced * taper, int(fft_length))
    return np.abs(spectrum)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Constants of the HTK mel formula: mel = Q * ln(1 + hz / break_frequency).
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequencies_hertz):
    """Map frequencies in hertz onto the mel scale (HTK formula).

    Args:
        frequencies_hertz: Scalar or np.array of frequencies in hertz.

    Returns:
        Object of same size as frequencies_hertz containing corresponding
        values on the mel scale.
    """
    relative_frequency = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * np.log(1.0 + relative_frequency)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def spectrogram_to_mel_matrix(num_mel_bins=20,
                              num_spectrogram_bins=129,
                              audio_sample_rate=8000,
                              lower_edge_hertz=125.0,
                              upper_edge_hertz=3800.0):
    """Build the matrix that turns spectrogram rows into mel-band rows.

    For a spectrogram S laid out as frames x bins (STFT magnitudes), the
    returned matrix A gives the mel spectrogram M = S A of shape
    frames x num_mel_bins. Each column of A is one triangular band whose
    sides are straight lines in the mel domain.

    Args:
        num_mel_bins: How many bands in the resulting mel spectrum. This is
            the number of columns in the output matrix.
        num_spectrogram_bins: How many bins there are in the source
            spectrogram data, which is understood to be fft_size/2 + 1, i.e.
            the spectrogram only contains the nonredundant FFT bins.
        audio_sample_rate: Samples per second of the audio at the input to
            the spectrogram; used to place each spectrogram bin in hertz.
        lower_edge_hertz: Lower bound on the frequencies to be included in
            the mel spectrum (lower edge of the lowest triangular band).
        upper_edge_hertz: The desired top edge of the highest frequency band.

    Returns:
        An np.array with shape (num_spectrogram_bins, num_mel_bins).

    Raises:
        ValueError: if frequency edges are incorrectly ordered or out of range.
    """
    nyquist_hertz = audio_sample_rate / 2.
    if lower_edge_hertz < 0.0:
        raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
    if lower_edge_hertz >= upper_edge_hertz:
        raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
                         (lower_edge_hertz, upper_edge_hertz))
    if upper_edge_hertz > nyquist_hertz:
        raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
                         (upper_edge_hertz, nyquist_hertz))

    # Mel position of every spectrogram bin, as a column for broadcasting.
    bin_mels = hertz_to_mel(
        np.linspace(0.0, nyquist_hertz, num_spectrogram_bins))[:, np.newaxis]

    # num_mel_bins + 2 equally spaced mel points: band i has its lower edge,
    # centre and upper edge at points i, i + 1 and i + 2.
    edge_mels = np.linspace(hertz_to_mel(lower_edge_hertz),
                            hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
    band_lower = edge_mels[:-2]
    band_center = edge_mels[1:-1]
    band_upper = edge_mels[2:]

    # Rising and falling sides of each triangle, evaluated at every bin.
    rising = (bin_mels - band_lower) / (band_center - band_lower)
    falling = (band_upper - bin_mels) / (band_upper - band_center)

    # A triangle is the lower of its two sides, floored at zero.
    mel_weights_matrix = np.maximum(0.0, np.minimum(rising, falling))

    # HTK excludes the spectrogram DC bin; force its coefficients to zero.
    mel_weights_matrix[0, :] = 0.0
    return mel_weights_matrix
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def log_mel_spectrogram(data,
                        audio_sample_rate=8000,
                        log_offset=0.0,
                        window_length_secs=0.025,
                        hop_length_secs=0.010,
                        **kwargs):
    """Turn a waveform into a log-magnitude mel-frequency spectrogram.

    Args:
        data: 1D np.array of waveform data.
        audio_sample_rate: The sampling rate of data.
        log_offset: Add this to values when taking log to avoid -Infs.
        window_length_secs: Duration of each window to analyze.
        hop_length_secs: Advance between successive analysis windows.
        **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.

    Returns:
        2D np.array of (num_frames, num_mel_bins) consisting of log mel
        filterbank magnitudes for successive frames.
    """
    samples_per_window = int(round(audio_sample_rate * window_length_secs))
    samples_per_hop = int(round(audio_sample_rate * hop_length_secs))
    # Smallest power of two that is >= the window length.
    fft_length = 2 ** int(np.ceil(np.log(samples_per_window) / np.log(2.0)))

    magnitudes = stft_magnitude(
        data,
        fft_length=fft_length,
        hop_length=samples_per_hop,
        window_length=samples_per_window)
    mel_matrix = spectrogram_to_mel_matrix(
        num_spectrogram_bins=magnitudes.shape[1],
        audio_sample_rate=audio_sample_rate, **kwargs)
    mel_magnitudes = np.dot(magnitudes, mel_matrix)
    return np.log(mel_magnitudes + log_offset)
|
avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_input.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Compute input examples for VGGish from audio waveform."""
|
| 17 |
+
|
| 18 |
+
# Modification: Return torch tensors rather than numpy arrays
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import resampy
|
| 23 |
+
|
| 24 |
+
from dataloader.audio.preprocess_vgg import mel_features
|
| 25 |
+
from dataloader.audio.preprocess_vgg import vggish_params
|
| 26 |
+
|
| 27 |
+
import soundfile as sf
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def waveform_to_examples(data, sample_rate, return_tensor=True):
    """Convert an audio waveform into a batch of VGGish input examples.

    Args:
        data: np.array with one dimension (mono) or more. Input with more
            than one dimension is averaged over axis 1 to obtain a mono
            signal. Samples are generally expected to lie in [-1.0, +1.0],
            although this is not required.
        sample_rate: Sample rate of data.
        return_tensor: Return data as a Pytorch tensor ready for VGGish.

    Returns:
        When return_tensor is false, a 3-D np.array of shape
        [num_examples, num_frames, num_bands]: a sequence of log mel
        spectrogram patches, each covering num_frames frames of audio and
        num_bands mel frequency bands, where the frame length is
        vggish_params.STFT_HOP_LENGTH_SECONDS. When return_tensor is true,
        the same data as a float tensor with a singleton axis inserted at
        position 1.
    """
    # Collapse multi-channel input to mono.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)

    # Bring the signal to the sample rate VGGish assumes.
    target_rate = vggish_params.SAMPLE_RATE
    if sample_rate != target_rate:
        data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

    # Log mel spectrogram of the whole clip.
    log_mel = mel_features.log_mel_spectrogram(
        data,
        audio_sample_rate=vggish_params.SAMPLE_RATE,
        log_offset=vggish_params.LOG_OFFSET,
        window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
        hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
        num_mel_bins=vggish_params.NUM_MEL_BINS,
        lower_edge_hertz=vggish_params.MEL_MIN_HZ,
        upper_edge_hertz=vggish_params.MEL_MAX_HZ)

    # Cut the spectrogram into fixed-size examples; the window and hop are
    # given in seconds and converted to a count of spectrogram frames.
    features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
    frames_per_example = int(round(
        vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
    frames_per_hop = int(round(
        vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
    log_mel_examples = mel_features.frame(
        log_mel,
        window_length=frames_per_example,
        hop_length=frames_per_hop)

    if not return_tensor:
        return log_mel_examples

    examples_tensor = torch.tensor(log_mel_examples, requires_grad=True)
    return examples_tensor[:, None, :, :].float()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def wavfile_to_examples(wav_file, return_tensor=True):
    """Read a 16-bit PCM WAV file and pass it to waveform_to_examples().

    Args:
        wav_file: String path to a file, or a file-like object. The file
            is assumed to contain WAV audio data with signed 16-bit PCM
            samples.
        return_tensor: Return data as a Pytorch tensor ready for VGGish.

    Returns:
        See waveform_to_examples.
    """
    pcm, rate = sf.read(wav_file, dtype='int16')
    assert pcm.dtype == np.int16, 'Bad sample type: %r' % pcm.dtype
    # int16 -> float in [-1.0, +1.0)
    return waveform_to_examples(pcm / 32768.0, rate, return_tensor)
|
avs.code/v1m.code/dataloader/audio/preprocess_vgg/vggish_params.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Global parameters for the VGGish model.
|
| 17 |
+
|
| 18 |
+
See vggish_slim.py for more information.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# --- Model architecture ---
NUM_FRAMES = 96  # Frames per input mel-spectrogram patch.
NUM_BANDS = 64  # Frequency bands per input mel-spectrogram patch.
EMBEDDING_SIZE = 128  # Width of the embedding layer.

# --- Feature / example generation ---
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01  # Added before the log of the mel-spectrogram for stability.
EXAMPLE_WINDOW_SECONDS = 0.96  # One example spans 96 frames of 10 ms each,
EXAMPLE_HOP_SECONDS = 0.96  # and consecutive examples do not overlap.

# --- Embedding post-processing ---
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = 2.0

# --- Training ---
INIT_STDDEV = 0.01  # Standard deviation for weight initialisation.
LEARNING_RATE = 1e-4  # Adam learning rate.
ADAM_EPSILON = 1e-8  # Adam epsilon.

# --- Op, tensor and feature names ---
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
|
avs.code/v1m.code/dataloader/dataset.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fused audio-visual dataset for AVSBench-style indexing."""
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import PIL.Image
|
| 5 |
+
import numpy
|
| 6 |
+
import torch
|
| 7 |
+
from dataloader.visual.visual_dataset import Visual
|
| 8 |
+
from dataloader.audio.audio_dataset import Audio
|
| 9 |
+
import pandas
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AV(torch.utils.data.Dataset):
    """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`."""

    def __init__(self, split, augmentation, param, root_path='', data_name='find'):
        """Build the visual/audio sub-datasets and the list of sample directories.

        Args:
            split: Split name matched against the `split` column of the CSV.
            augmentation: Mapping with 'visual' and 'audio' entries, handed to
                `Visual` and `Audio` respectively.
            param: Object providing `image_size` and `image_embedding_size`.
            root_path: Root directory holding the dataset folder and the CSV.
            data_name: Dataset folder name; also matched against the CSV `label` column.
        """
        self.visual_dataset = Visual(augmentation['visual'], os.path.join(root_path, data_name), split, param.image_size, param.image_embedding_size)
        self.audio_dataset = Audio(augmentation['audio'], os.path.join(root_path, data_name), split)
        self.augment = augmentation
        self.split = split
        self.file_path = self.organise_files(self.split, root_path, data_name, csv_name_='avss_index/metadata.csv')

    def __getitem__(self, index):
        """Return one sample as a dict with 'frame', 'label', 'spectrogram', 'id' and 'prompts'."""
        mixing_prob = 0.  # we omit this option.
        # random.random() is still drawn on every call, so the RNG stream is unchanged.
        other_index = random.randint(1, self.__len__()) - 1 if random.random() < mixing_prob and self.split == 'train' else None
        frame, label, prompts = self.visual_dataset.load_data(self.file_path[index])
        if other_index is not None:
            other_frame, other_label, other_prompts = self.visual_dataset.load_data(self.file_path[other_index])
            frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts)
            audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], self.file_path[other_index])
        else:
            audio_mel = self.audio_dataset.load_audio_wave(self.file_path[index], None)

        # The old assert passed `print(...)` as its message, so the message was always None.
        assert self.split != 'test' or other_index is None, 'no mix in validation.'

        return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': self.file_path[index],
                'prompts': prompts}

    def __len__(self):
        return len(self.file_path)

    @staticmethod
    def organise_files(split_, root_path_, data_name_, csv_name_):
        """Read rows from `csv_name_` under `root_path_` matching split and dataset label.

        Returns:
            List of `root_path_/data_name_/<uid>` paths, one per matching row.
        """
        total_files = pandas.read_csv(os.path.join(root_path_, csv_name_))
        files_info = total_files[(total_files["split"] == split_) & (total_files["label"] == data_name_)]['uid']

        files_path = [os.path.join(root_path_, data_name_, files_name) for files_name in files_info]
        del total_files, files_info
        return files_path

    @staticmethod
    def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2):
        """Paste the foreground of the second clip onto the first clip.

        Wherever `label2` is > 0, both the pixels of `frame2` and the values of
        `label2` overwrite the corresponding positions in copies of `frame1`
        and `label1`. The inputs are not modified.

        Args:
            frame1, frame2: Tensors indexed as [i, channel, row, col].
            label1, label2: Tensors indexed as [i, row, col].
            prompts1: Returned unchanged.
            prompts2: Unused.

        Returns:
            Tuple (mixed_frame, mixed_label, prompts1).
        """
        mix_frame = frame1.clone()
        mix_label = label1.clone()

        for i in range(0, mix_frame.shape[0]):
            # The mask covers the whole image. The previous slice end of
            # `shape - 1` was exclusive and left out the last row and column.
            label_canvas_foreground = label2[i] > 0.
            mix_frame[i][:, label_canvas_foreground] = frame2[i][:, label_canvas_foreground]
            mix_label[i][label_canvas_foreground] = label2[i][label_canvas_foreground]

        return mix_frame, mix_label, prompts1
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
avs.code/v1m.code/dataloader/sam2_dataset/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
avs.code/v1m.code/dataloader/sam2_dataset/transforms.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Transforms and data augmentation for both image + bbox.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
import random
|
| 14 |
+
from typing import Iterable
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torchvision.transforms as T
|
| 18 |
+
import torchvision.transforms.functional as F
|
| 19 |
+
import torchvision.transforms.v2.functional as Fv2
|
| 20 |
+
from PIL import Image as PILImage
|
| 21 |
+
# from docutils.nodes import label
|
| 22 |
+
import numpy
|
| 23 |
+
from torchvision.transforms import InterpolationMode
|
| 24 |
+
|
| 25 |
+
# from utils.data_utils import VideoDatapoint
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def hflip(frames, labels, index):
    """Mirror the frame and the label stored at ``index`` left-to-right.

    Both sequences are updated in place at ``index`` and then returned.
    """
    for sequence in (frames, labels):
        sequence[index] = F.hflip(sequence[index])
    return frames, labels
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Compute an ``(height, width)`` whose shorter side equals ``size``.

    ``image_size`` is given as ``(width, height)``. When ``max_size`` is set
    and scaling the shorter side to ``size`` would push the longer side past
    ``max_size``, ``size`` is reduced so the longer side lands on
    ``max_size``. The aspect ratio is kept up to rounding.
    """
    w, h = image_size
    if max_size is not None:
        short_side = float(min(w, h))
        long_side = float(max(w, h))
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    # Nothing to do when the shorter side already has the requested length.
    short_side_matches = (w <= h and w == size) or (h <= w and h == size)
    if short_side_matches:
        return (h, w)

    if w < h:
        return (int(round(size * h / w)), int(round(size)))
    return (int(round(size)), int(round(size * w / h)))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def resize(frames, labels, index, size, max_size=None, square=False, v2=False):
    """Resize the entry at ``index`` to a ``size`` x ``size`` square.

    Only ``square=True`` is implemented; any other call raises
    ``NotImplementedError``. ``max_size`` is accepted but not used.

    With ``v2=False`` both ``frames[index]`` and ``labels[index]`` are
    replaced by their resized versions. With ``v2=True`` only the ``.data``
    attribute of ``frames[index]`` is resized (with antialiasing).
    """
    if not square:
        raise NotImplementedError
    target = (size, size)

    if v2:
        frames[index].data = Fv2.resize(frames[index].data, target, antialias=True)
        return frames, labels

    frames[index] = F.resize(frames[index], target)
    # NOTE(review): the label goes through F.resize's default interpolation;
    # confirm that is intended for masks. The source indentation was lost, so
    # this line is assumed to belong to the non-v2 branch only.
    labels[index] = F.resize(labels[index], target)
    return frames, labels
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def pad(frames, index, padding, v2=False):
    """Pad the entry at ``index`` and every object segment attached to it.

    ``frames[index]`` must expose ``.data``, ``.size`` (read as
    ``(h, w)``) and ``.objects``. A two-element ``padding`` is applied as
    ``(0, 0, padding[0], padding[1])``; otherwise the first four elements
    are used as ``(left, top, right, bottom)``. ``.size`` is updated to the
    padded ``(h, w)``. ``v2`` only selects which pad function is used for
    the object segments.
    """
    entry = frames[index]
    h, w = entry.size
    two_sided = len(padding) == 2

    if two_sided:
        # assumes that we only pad on the bottom right corners
        frame_padding = (0, 0, padding[0], padding[1])
        grow_h, grow_w = padding[1], padding[0]
    else:
        # left, top, right, bottom
        frame_padding = (padding[0], padding[1], padding[2], padding[3])
        grow_h, grow_w = padding[1] + padding[3], padding[0] + padding[2]

    entry.data = F.pad(entry.data, frame_padding)
    entry.size = (h + grow_h, w + grow_w)

    segment_padding = (0, 0, padding[0], padding[1]) if two_sided else tuple(padding)
    for obj in entry.objects:
        if obj.segment is None:
            continue
        if v2:
            obj.segment = Fv2.pad(obj.segment, segment_padding)
        else:
            obj.segment = F.pad(obj.segment, segment_padding)
    return frames
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class RandomHorizontalFlip:
    """Horizontally flip frame/label pairs with probability ``p``.

    With ``consistent_transform`` set, a single random draw decides whether
    every pair is flipped; otherwise each pair gets its own draw.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, frames, labels, **kwargs):
        if not self.consistent_transform:
            # Independent decision per frame.
            for idx in range(len(frames)):
                if random.random() < self.p:
                    frames, labels = hflip(frames, labels, idx)
            return frames, labels

        # One decision shared by the whole clip.
        if random.random() < self.p:
            for idx in range(len(frames)):
                frames, labels = hflip(frames, labels, idx)
        return frames, labels
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class RandomResizeAPI:
    """Resize every frame/label pair to a size picked at random from ``sizes``.

    ``sizes`` may be a single int or an iterable of candidates. With
    ``consistent_transform`` set, one size is drawn for the whole clip;
    otherwise a size is drawn per frame. ``max_size``, ``square`` and ``v2``
    are forwarded to ``resize``.
    """

    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, frames, labels, **kwargs):
        # **kwargs is accepted (and ignored) so that this transform has the
        # same call signature as the others in this file; ComposeAPI forwards
        # its keyword arguments to every transform.
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(frames)):
                frames, labels = resize(
                    frames, labels, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return frames, labels

        for i in range(len(frames)):
            size = random.choice(self.sizes)
            frames, labels = resize(
                frames, labels, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return frames, labels
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class ToTensorAPI:
    """Convert frames with ``F.to_tensor`` and labels to float tensors.

    The ``v2`` code path is not implemented and raises
    ``NotImplementedError`` as soon as there is a frame to convert.
    """

    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, frames, labels, **kwargs):
        for idx, frame in enumerate(frames):
            if self.v2:
                raise NotImplementedError
            frames[idx] = F.to_tensor(frame)
            # Labels go through numpy so the raw values are kept as floats.
            labels[idx] = torch.tensor(numpy.array(labels[idx]), dtype=torch.float)
        return frames, labels
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class NormalizeAPI:
    """Normalise every frame with ``F.normalize(mean, std)``; labels pass through.

    ``v2`` is stored but not consulted.
    """

    def __init__(self, mean, std, v2=False):
        self.mean, self.std = mean, std
        self.v2 = v2

    def __call__(self, frames, labels, **kwargs):
        for idx, frame in enumerate(frames):
            frames[idx] = F.normalize(frame, mean=self.mean, std=self.std)
        return frames, labels
|
| 220 |
+
|
| 221 |
+
'''
|
| 222 |
+
<dataloader.sam2_dataset.transforms.RandomHorizontalFlip object at 0x75c815561b40>
|
| 223 |
+
<dataloader.sam2_dataset.transforms.RandomAffine object at 0x75c815561bd0>
|
| 224 |
+
<dataloader.sam2_dataset.transforms.RandomResizeAPI object at 0x75c815561c60>
|
| 225 |
+
<dataloader.sam2_dataset.transforms.ColorJitter object at 0x75c815561cc0>
|
| 226 |
+
<dataloader.sam2_dataset.transforms.RandomGrayscale object at 0x75c815561cf0>
|
| 227 |
+
<dataloader.sam2_dataset.transforms.ColorJitter object at 0x75c815561de0>
|
| 228 |
+
<dataloader.sam2_dataset.transforms.ToTensorAPI object at 0x75c815507280>
|
| 229 |
+
<dataloader.sam2_dataset.transforms.NormalizeAPI object at 0x75c815507490>
|
| 230 |
+
'''
|
| 231 |
+
class ComposeAPI:
    """Chain transforms: each receives ``(frames, labels, **kwargs)`` and
    its returned pair is handed to the next one."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, frames, labels, **kwargs):
        for transform in self.transforms:
            frames, labels = transform(frames, labels, **kwargs)
        return frames, labels

    def __repr__(self):
        # One transform per line between "ClassName(" and ")".
        pieces = [self.__class__.__name__ + "("]
        pieces.extend("\n" + " {0}".format(t) for t in self.transforms)
        pieces.append("\n)")
        return "".join(pieces)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class RandomGrayscale:
    """Convert frames to 3-channel grayscale with probability ``p``.

    With ``consistent_transform`` set, one random draw decides for the whole
    clip; otherwise every frame gets its own draw. Labels are returned
    unchanged.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.consistent_transform = consistent_transform
        self.p = p
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, frames, labels, **kwargs):
        if not self.consistent_transform:
            # Independent decision per frame.
            for idx, frame in enumerate(frames):
                if random.random() < self.p:
                    frames[idx] = self.Grayscale(frame)
            return frames, labels

        # One decision shared by the whole clip.
        if random.random() < self.p:
            for idx, frame in enumerate(frames):
                frames[idx] = self.Grayscale(frame)
        return frames, labels
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of the frames.

    Each of ``brightness``, ``contrast`` and ``saturation`` is either a
    ready-made ``[min, max]`` list or a scalar ``v`` that is expanded to
    ``[max(0, 1 - v), 1 + v]``. ``hue`` is a list, ``None``, or a scalar
    ``h`` expanded to ``[-h, h]``.

    With ``consistent_transform`` set, one set of jitter parameters is drawn
    and applied to every frame; otherwise parameters are drawn per frame.
    Labels are returned unchanged.

    Fix: the adjusted image used to be bound only to the loop variable and
    was never stored, so the transform left ``frames`` untouched. The result
    is now written back into ``frames``.
    """

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = self._as_range(brightness)
        self.contrast = self._as_range(contrast)
        self.saturation = self._as_range(saturation)
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    @staticmethod
    def _as_range(value):
        """Expand a scalar strength into a ``[min, max]`` factor range."""
        return value if isinstance(value, list) else [max(0, 1 - value), 1 + value]

    def _sample_params(self):
        """Draw an application order and one factor per colour operation."""
        return T.ColorJitter.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

    def __call__(self, frames, labels, **kwargs):
        params = self._sample_params() if self.consistent_transform else None

        for img_idx, img in enumerate(frames):
            if not self.consistent_transform:
                params = self._sample_params()
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = params

            # fn_idx gives the order in which the four operations are applied.
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img = F.adjust_brightness(img, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img = F.adjust_contrast(img, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img = F.adjust_saturation(img, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img = F.adjust_hue(img, hue_factor)

            # The adjust_* helpers return a new image; store it.
            frames[img_idx] = img
        return frames, labels
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class RandomAffine:
    """Apply a random affine warp to frames and their labels."""

    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        label_fill_value=0.,
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        Labels are required by this transform.
        When ``consistent_transform`` is True, the same random affine is
        applied to all frames and labels; otherwise new parameters are drawn
        for every frame/label pair.

        Scalar ``degrees`` / ``shear`` values ``v`` are expanded to
        ``[-v, v]``. ``image_mean`` fills the uncovered image area and
        ``label_fill_value`` the uncovered label area.
        ``image_interpolation`` must be "bicubic" or "bilinear".
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        if isinstance(shear, list):
            self.shear = shear
        else:
            self.shear = [-shear, shear] if shear else None
        self.translate = translate
        self.fill_img = image_mean
        self.fill_label = label_fill_value
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        assert self.num_tentatives >= 1., 'must have at least one if we utilise the augmentation.'

        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
            return
        if image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
            return
        raise NotImplementedError

    def __call__(self, frames, labels, **kwargs):
        # NOTE(review): the source indentation was lost; the return is assumed
        # to sit inside the loop, so the first tentative is always returned
        # and the warp is applied exactly once.
        for _tentative in range(self.num_tentatives):
            return self.transform_frames(frames, labels)

    def _draw_params(self, img_size):
        """Sample one set of affine parameters for an image of ``img_size``."""
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_frames(self, frames, labels):
        """Warp every frame/label pair in place and return both lists."""
        # The size of the first frame is used for all parameter draws.
        _, height, width = F.get_dimensions(frames[0])
        img_size = [width, height]

        if self.consistent_transform:
            # One set of parameters shared by the whole clip.
            affine_params = self._draw_params(img_size)

        for img_idx, img in enumerate(frames):
            if not self.consistent_transform:
                # Fresh parameters for every frame/label pair.
                affine_params = self._draw_params(img_size)

            frames[img_idx] = F.affine(
                img,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
            # Labels keep F.affine's default interpolation.
            labels[img_idx] = F.affine(
                labels[img_idx],
                *affine_params,
                fill=self.fill_label,
            )
        return frames, labels
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
'''
|
| 413 |
+
def random_mosaic_frame(
|
| 414 |
+
datapoint,
|
| 415 |
+
index,
|
| 416 |
+
grid_h,
|
| 417 |
+
grid_w,
|
| 418 |
+
target_grid_y,
|
| 419 |
+
target_grid_x,
|
| 420 |
+
should_hflip,
|
| 421 |
+
):
|
| 422 |
+
# Step 1: downsize the images and paste them into a mosaic
|
| 423 |
+
image_data = datapoint.frames[index].data
|
| 424 |
+
is_pil = isinstance(image_data, PILImage.Image)
|
| 425 |
+
if is_pil:
|
| 426 |
+
H_im = image_data.height
|
| 427 |
+
W_im = image_data.width
|
| 428 |
+
image_data_output = PILImage.new("RGB", (W_im, H_im))
|
| 429 |
+
else:
|
| 430 |
+
H_im = image_data.size(-2)
|
| 431 |
+
W_im = image_data.size(-1)
|
| 432 |
+
image_data_output = torch.zeros_like(image_data)
|
| 433 |
+
|
| 434 |
+
downsize_cache = {}
|
| 435 |
+
for grid_y in range(grid_h):
|
| 436 |
+
for grid_x in range(grid_w):
|
| 437 |
+
y_offset_b = grid_y * H_im // grid_h
|
| 438 |
+
x_offset_b = grid_x * W_im // grid_w
|
| 439 |
+
y_offset_e = (grid_y + 1) * H_im // grid_h
|
| 440 |
+
x_offset_e = (grid_x + 1) * W_im // grid_w
|
| 441 |
+
H_im_downsize = y_offset_e - y_offset_b
|
| 442 |
+
W_im_downsize = x_offset_e - x_offset_b
|
| 443 |
+
|
| 444 |
+
if (H_im_downsize, W_im_downsize) in downsize_cache:
|
| 445 |
+
image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
|
| 446 |
+
else:
|
| 447 |
+
image_data_downsize = F.resize(
|
| 448 |
+
image_data,
|
| 449 |
+
size=(H_im_downsize, W_im_downsize),
|
| 450 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 451 |
+
antialias=True, # antialiasing for downsizing
|
| 452 |
+
)
|
| 453 |
+
downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
|
| 454 |
+
if should_hflip[grid_y, grid_x].item():
|
| 455 |
+
image_data_downsize = F.hflip(image_data_downsize)
|
| 456 |
+
|
| 457 |
+
if is_pil:
|
| 458 |
+
image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
|
| 459 |
+
else:
|
| 460 |
+
image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
|
| 461 |
+
image_data_downsize
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
datapoint.frames[index].data = image_data_output
|
| 465 |
+
|
| 466 |
+
# Step 2: downsize the masks and paste them into the target grid of the mosaic
|
| 467 |
+
for obj in datapoint.frames[index].objects:
|
| 468 |
+
if obj.segment is None:
|
| 469 |
+
continue
|
| 470 |
+
assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
|
| 471 |
+
segment_output = torch.zeros_like(obj.segment)
|
| 472 |
+
|
| 473 |
+
target_y_offset_b = target_grid_y * H_im // grid_h
|
| 474 |
+
target_x_offset_b = target_grid_x * W_im // grid_w
|
| 475 |
+
target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
|
| 476 |
+
target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
|
| 477 |
+
target_H_im_downsize = target_y_offset_e - target_y_offset_b
|
| 478 |
+
target_W_im_downsize = target_x_offset_e - target_x_offset_b
|
| 479 |
+
|
| 480 |
+
segment_downsize = F.resize(
|
| 481 |
+
obj.segment[None, None],
|
| 482 |
+
size=(target_H_im_downsize, target_W_im_downsize),
|
| 483 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 484 |
+
antialias=True, # antialiasing for downsizing
|
| 485 |
+
)[0, 0]
|
| 486 |
+
if should_hflip[target_grid_y, target_grid_x].item():
|
| 487 |
+
segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
|
| 488 |
+
|
| 489 |
+
segment_output[
|
| 490 |
+
target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
|
| 491 |
+
] = segment_downsize
|
| 492 |
+
obj.segment = segment_output
|
| 493 |
+
|
| 494 |
+
return datapoint
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class RandomMosaicVideoAPI:
|
| 498 |
+
def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
|
| 499 |
+
self.prob = prob
|
| 500 |
+
self.grid_h = grid_h
|
| 501 |
+
self.grid_w = grid_w
|
| 502 |
+
self.use_random_hflip = use_random_hflip
|
| 503 |
+
|
| 504 |
+
def __call__(self, frames, **kwargs):
|
| 505 |
+
if random.random() > self.prob:
|
| 506 |
+
return datapoint
|
| 507 |
+
|
| 508 |
+
# select a random location to place the target mask in the mosaic
|
| 509 |
+
target_grid_y = random.randint(0, self.grid_h - 1)
|
| 510 |
+
target_grid_x = random.randint(0, self.grid_w - 1)
|
| 511 |
+
# whether to flip each grid in the mosaic horizontally
|
| 512 |
+
if self.use_random_hflip:
|
| 513 |
+
should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
|
| 514 |
+
else:
|
| 515 |
+
should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
|
| 516 |
+
for i in range(len(datapoint.frames)):
|
| 517 |
+
datapoint = random_mosaic_frame(
|
| 518 |
+
datapoint,
|
| 519 |
+
i,
|
| 520 |
+
grid_h=self.grid_h,
|
| 521 |
+
grid_w=self.grid_w,
|
| 522 |
+
target_grid_y=target_grid_y,
|
| 523 |
+
target_grid_x=target_grid_x,
|
| 524 |
+
should_hflip=should_hflip,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
return datapoint
|
| 528 |
+
'''
|
avs.code/v1m.code/dataloader/visual/visual_augmentation.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms.functional as F
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Augmentation(object):
    """Per-frame image/label augmentation for training and evaluation.

    Calling the instance with ``(x, y, split)`` runs ``train_aug`` when
    ``split == "train"`` and ``test_process`` otherwise. Both return a
    normalised float32 image tensor and a float label tensor with a leading
    axis of size one.
    """

    def __init__(self, image_mean, image_std, image_width, image_height, scale_list, ignore_index=255):
        # Target size is stored as (height, width).
        self.image_size = (image_height, image_width)
        # Both attributes below are read by random_crop_with_padding; they
        # were previously left undefined, which made that method raise
        # AttributeError.
        self.image_norm = (image_mean, image_std)
        self.get_crop_pos = transforms.RandomCrop(self.image_size)

        self.color_jitter = transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.25)
        self.gaussian_blurring = transforms.GaussianBlur((3, 3))
        self.scale_list = scale_list

        self.normalise = transforms.Normalize(mean=image_mean, std=image_std)
        self.to_tensor = transforms.ToTensor()

        self.ignore_index = ignore_index

    def resize(self, image_, label_, size=None):
        """Resize image (bicubic) and label (nearest) to ``size`` or the default ``(h, w)``."""
        h_, w_ = self.image_size if size is None else size
        image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC)
        label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST)
        return image_, label_

    def random_crop_with_padding(self, image_, label_):
        """Pad on the right/bottom when too small, then take a random crop.

        The image is padded with the mean colour (``image_mean * 255``) and
        the label with ``ignore_index``; afterwards one random window of
        ``self.image_size`` is cut from both.
        """
        w_, h_ = image_.size
        target_h_, target_w_ = self.image_size
        # The previous code compared the width against the target height and
        # vice versa, and skipped padding when only one side was too small.
        res_w_ = max(target_w_ - w_, 0)
        res_h_ = max(target_h_ - h_, 0)
        if res_w_ > 0 or res_h_ > 0:
            fill_rgb_ = tuple(int(round(float(c_) * 255.)) for c_ in self.image_norm[0])
            # padding order: left, top, right, bottom
            image_ = F.pad(image_, [0, 0, res_w_, res_h_], fill=fill_rgb_)
            label_ = F.pad(label_, [0, 0, res_w_, res_h_], fill=self.ignore_index)

        pos_ = self.get_crop_pos.get_params(image_, self.image_size)
        image_ = F.crop(image_, *pos_)
        label_ = F.crop(label_, *pos_)

        return image_, label_

    def random_scales(self, image_, label_):
        """Rescale image and label by a factor picked from ``self.scale_list``."""
        w_, h_ = image_.size
        chosen_scale = random.choice(self.scale_list)
        w_, h_ = int(w_ * chosen_scale), int(h_ * chosen_scale)
        image_ = F.resize(image_, (h_, w_), transforms.InterpolationMode.BICUBIC)
        label_ = F.resize(label_, (h_, w_), transforms.InterpolationMode.NEAREST)
        return image_, label_

    @staticmethod
    def random_flip_h(image_, label_):
        """Flip image and label together with probability 0.5."""
        chosen_flip = random.random() > 0.5
        image_ = F.hflip(image_) if chosen_flip else image_
        label_ = F.hflip(label_) if chosen_flip else label_
        return image_, label_

    def augment_entire_clip(self, x_list, y_list):
        """Apply one shared random rotation/shear to every frame of a clip.

        The angle is drawn from [-25, 25] and both shear components from
        [-20, 20]. With probability 0.1 all frames are additionally turned
        into 3-channel grayscale. Both lists are modified in place.
        """
        degree_ = float(torch.empty(1).uniform_(float(-25.), float(25.)).item())
        shear_ = [float(torch.empty(1).uniform_(float(-20.), float(20.)).item()),
                  torch.empty(1).uniform_(float(-20.), float(20.)).item()]
        dice = random.random()
        for index, single_x in enumerate(x_list):
            if dice <= 0.1:
                single_x = F.rgb_to_grayscale(single_x, num_output_channels=3)

            single_x = F.affine(single_x, angle=degree_, shear=shear_, translate=[0, 0], scale=1.,
                                interpolation=transforms.InterpolationMode.BILINEAR, fill=[0., 0., 0.])
            # Nearest interpolation keeps the label values discrete.
            single_y = F.affine(y_list[index], angle=degree_, shear=shear_, translate=[0, 0], scale=1.,
                                interpolation=transforms.InterpolationMode.NEAREST, fill=[0.])
            x_list[index] = single_x
            y_list[index] = single_y

        return x_list, y_list

    def train_aug(self, x_, y_):
        """Training pipeline: flip, resize, optional colour jitter / blur, tensor conversion."""
        x_, y_ = self.random_flip_h(x_, y_)
        x_, y_ = self.resize(x_, y_)

        # Colour jitter and blur are each applied with probability 0.5.
        if self.color_jitter is not None and random.random() < 0.5:
            x_ = self.color_jitter(x_)
        if self.gaussian_blurring is not None and random.random() < 0.5:
            x_ = self.gaussian_blurring(x_)

        x_ = self.normalise(self.to_tensor(x_)).type(torch.float32)
        # The label gets a leading axis of size one and is stored as float.
        y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float)
        return x_, y_

    def test_process(self, x_, y_):
        """Evaluation pipeline: resize to the fixed image size and convert to tensors."""
        # following AVSbench setup, we fix image size (224, 224)
        x_, y_ = self.resize(x_, y_)

        x_ = self.normalise(self.to_tensor(x_)).type(torch.float32)
        y_ = torch.tensor(numpy.array(y_)[numpy.newaxis, ...], dtype=torch.float)
        return x_, y_

    def __call__(self, x, y, split):
        return self.train_aug(x, y) if split == "train" \
            else self.test_process(x, y)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
avs.code/v1m.code/dataloader/visual/visual_dataset.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import PIL.Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy
|
| 6 |
+
import torch
|
| 7 |
+
import pandas
|
| 8 |
+
import torchvision
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Visual(torch.utils.data.Dataset):
    """Loads the frames and label masks stored under one clip directory.

    A clip directory is expected to contain a ``frames`` folder and a
    ``labels_rgb`` folder whose file names end in ``_<index>.<ext>``.
    """

    def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size):
        self.augment = augmentation
        self.directory_path = directory_path
        self.split = split
        self.image_size = image_size
        self.embedding_size = image_embedding_size

    @staticmethod
    def _frame_order(path):
        """Return the integer that follows the last ``_`` in the file stem.

        Sorting on this integer gives numeric order (1, 2, ..., 10); sorting
        on the individual digits would place ``_10`` before ``_2``.

        Raises:
            ValueError: if the trailing part of the stem is not an integer.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.rsplit("_", 1)[-1])

    def load_data(self, file_prefix):
        """Load, order and preprocess every frame/label pair of one clip.

        Args:
            file_prefix: clip directory holding ``frames`` and ``labels_rgb``.

        Returns:
            ``(images, labels, prompts)`` where ``images`` and ``labels`` are
            the per-frame results stacked along a new first axis and
            ``prompts`` is ``{'label_index': <bool tensor>}``.
        """
        frame_dir = os.path.join(file_prefix, 'frames')
        label_dir = os.path.join(file_prefix, 'labels_rgb')
        frame_path = sorted((os.path.join(frame_dir, name) for name in os.listdir(frame_dir)),
                            key=self._frame_order)
        label_path = sorted((os.path.join(label_dir, name) for name in os.listdir(label_dir)),
                            key=self._frame_order)

        frame = [PIL.Image.open(path) for path in frame_path]
        label = [PIL.Image.open(path).convert('L') for path in label_path]

        # NOTE(review): hard-coded for 5-frame clips with only the first entry
        # flagged; confirm against the consumer before generalising.
        label_idx = torch.tensor([1] + [0] * 4, dtype=torch.bool)

        is_train = self.split == 'train'
        if is_train:
            # Training augmentation is applied to the whole clip in one call.
            frame, label = self.augment(frame, label)

        image_batch = [None] * len(frame)
        label_batch = [None] * len(frame)
        for i in range(len(frame)):
            if is_train:
                curr_frame, curr_label = frame[i], label[i]
            else:
                # Every non-train split is preprocessed frame by frame; the
                # augmentation object decides what a given split name means.
                curr_frame, curr_label = self.augment(frame[i], label[i], split=self.split)
            # Labels are loaded as grey-level images; binarise them in place.
            curr_label[curr_label > 0.] = 1.
            image_batch[i], label_batch[i] = curr_frame, curr_label

        prompts = {'label_index': label_idx}
        return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts

    def receive_other_prompts(self, y_):
        """Derive a box prompt and a low-resolution mask prompt from a mask.

        Args:
            y_: mask tensor whose last two axes are (height, width); values
                greater than zero are foreground.

        Returns:
            ``(bbox_coord, low_mask)``. With at least one foreground pixel,
            ``bbox_coord`` is the tight box ``(x_min, y_min, x_max, y_max)``
            rescaled by :meth:`transform_coords` and ``low_mask`` is ``y_``
            resized (nearest) to ``embedding_size * 4`` per side. Otherwise
            both are filled with NaN.
        """
        side = self.embedding_size * 4
        foreground = y_ > 0
        # Test for foreground directly: counting unique values misclassifies
        # a mask that is entirely foreground as "pure background".
        if not foreground.any():
            bbox_coord = torch.full([4], float('nan'), dtype=torch.float)
            low_mask = torch.full([1, side, side], float('nan'), dtype=torch.float)
            return bbox_coord, low_mask

        hits = torch.where(foreground)
        xs, ys = hits[-1], hits[-2]
        # bbox prompt (left-top corner & right-bottom corner)
        corners = [xs.min(), ys.min(), xs.max(), ys.max()]
        bbox_coord = torch.tensor([float(v) for v in corners], dtype=torch.float)
        # Use the trailing two axes so size-1 spatial dims are not squeezed away.
        bbox_coord = self.transform_coords(bbox_coord, orig_hw=tuple(y_.shape[-2:]))
        # mask prompt
        low_mask = torchvision.transforms.functional.resize(
            y_.clone(), [side, side], torchvision.transforms.InterpolationMode.NEAREST)
        return bbox_coord, low_mask

    # we transfer the coords to the model input resolution (self.image_size).
    def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor:
        """Rescale one box from original-image pixels to ``self.image_size``.

        Args:
            coords: tensor with exactly four values ``(x0, y0, x1, y1)`` in
                absolute pixel coordinates of the original image.
            orig_hw: ``(height, width)`` of the original image. Required.

        Returns:
            A tensor of shape ``(4,)``: each x is divided by the width, each
            y by the height, and the result is multiplied by
            ``self.image_size``.

        Raises:
            TypeError: if ``orig_hw`` is not provided.
        """
        if orig_hw is None:
            raise TypeError("transform_coords() requires orig_hw=(height, width)")
        h, w = orig_hw
        coords = coords.clone().reshape(-1, 2, 2)
        coords[..., 0] = coords[..., 0] / w
        coords[..., 1] = coords[..., 1] / h
        coords = coords * self.image_size  # scale normalised coords to the input resolution
        return coords.reshape(4)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
|
avs.code/v1m.code/inference.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Distributed inference on the test set; runs the same three `process` modes as training validation."""
|
| 2 |
+
import os
|
| 3 |
+
import pathlib
|
| 4 |
+
import torch
|
| 5 |
+
import numpy
|
| 6 |
+
import random
|
| 7 |
+
import argparse
|
| 8 |
+
from easydict import EasyDict
|
| 9 |
+
|
| 10 |
+
# Avoid import failure when configs.config creates saved_dir without write permission.
_real_mkdir = pathlib.Path.mkdir


def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
    """Best-effort replacement for ``pathlib.Path.mkdir``.

    Delegates to the original ``mkdir``. A ``PermissionError`` is not
    propagated; a ``RuntimeWarning`` is emitted instead so the skipped
    directory is visible, and ``None`` is returned. Every other exception
    propagates unchanged.
    """
    try:
        return _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok)
    except PermissionError as err:
        import warnings

        warnings.warn(
            "Skipping creation of '{}': {}".format(self, err),
            RuntimeWarning,
            stacklevel=2,
        )
        return None


# Process-wide patch: every later Path.mkdir call goes through _safe_mkdir.
pathlib.Path.mkdir = _safe_mkdir
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def seed_it(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # The interpreter reads PYTHONHASHSEED (not "PYTHONSEED"). Setting it here
    # only affects processes started afterwards, not the running interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per run, which works
    # against the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class _DummyTensorboard:
    """No-op logger stand-in so Trainer.valid can run without wandb uploads."""

    def upload_wandb_info(self, info_dict):
        """Accept and discard a dictionary of values."""
        return None

    def upload_wandb_image(self, *args, **kwargs):
        """Accept and discard any image-upload arguments."""
        return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process entry point: evaluate a checkpoint on the test split.

    Args:
        local_rank: index of this process, used as both rank and GPU index.
        ngpus_per_node: forwarded by the spawner; not used in this function.
        hyp_param: configuration object (attribute access); it is extended
            in place with ``local_rank``, ``aural_fuser`` and
            ``contrastive_learning``.

    Raises:
        TypeError: if the loaded checkpoint is not a dictionary.
    """
    hyp_param.local_rank = local_rank
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus * 1
    )
    try:
        # Select this process's GPU before anything is moved to CUDA;
        # otherwise every rank first allocates its model on device 0.
        torch.cuda.set_device(hyp_param.local_rank)
        seed_it(local_rank + hyp_param.seed)

        import model.visual.sam2  # noqa: F401 — registers Hydra `configs`
        from hydra import compose
        from omegaconf import OmegaConf

        arch_h = compose(config_name='auralfuser/architecture.yaml')
        OmegaConf.resolve(arch_h)
        hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

        train_cfg = compose(config_name='training/sam2_training_config.yaml')
        OmegaConf.resolve(train_cfg)
        hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True)

        from model.mymodel import AVmodel
        av_model = AVmodel(hyp_param).cuda()
        ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu")
        if not isinstance(ckpt_sd, dict):
            raise TypeError("Checkpoint must be a state_dict dictionary.")
        # Same as v1s/v2: full-model ckpt vs train-only aural_fuser ckpt (e.g. keys vgg.*, f_blocks.*).
        if any(k.startswith("v_model.") or k.startswith("aural_fuser.") for k in ckpt_sd.keys()):
            av_model.load_state_dict(ckpt_sd, strict=True)
        else:
            av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True)

        # Convert BatchNorm layers first, then wrap: DDP must be built on the
        # final module tree.
        av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model)
        av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank],
                                                                         find_unused_parameters=False)
        av_model.eval()

        from dataloader.dataset import AV
        from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
        from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
        from torch.utils.data import DataLoader, Subset
        from torch.utils.data.distributed import DistributedSampler

        visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std,
                                                 hyp_param.image_size, hyp_param.image_size,
                                                 hyp_param.scale_list, ignore_index=hyp_param.ignore_index)
        audio_augmentation = AudioAugmentation(mono=True)

        dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
                     param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.inference_data_name)

        # Optional debug cap: keep only the first N batches' worth of samples.
        max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0
        if max_batches > 0:
            n_samples = min(max_batches * hyp_param.batch_size, len(dataset))
            dataset = Subset(dataset, range(n_samples))

        sampler = DistributedSampler(dataset, shuffle=False)
        test_dataloader = DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler,
                                     num_workers=hyp_param.num_workers)

        from trainer.train import Trainer
        from utils.foreground_iou import ForegroundIoU
        from utils.foreground_fscore import ForegroundFScore

        metrics = {
            "foreground_iou": ForegroundIoU(),
            "foreground_f-score": ForegroundFScore(hyp_param.local_rank),
        }
        trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics)

        # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select
        runs = [
            ("", "default (logits[:,0])"),
            ("iou_select", "iou_select"),
            ("iou_occ_select", "iou_occ_select"),
        ]
        results = []
        for process, label in runs:
            fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process)
            results.append((label, fiou, ffscore))
            torch.cuda.empty_cache()

        # Only the first rank prints the summary.
        if hyp_param.local_rank <= 0:
            print("\n========== inference (same three process flags as training valid) ==========")
            for label, fiou, ffscore in results:
                print("  {:32s}  f_iou={}  f_f-score={}".format(label, fiou, ffscore))
            print("=======================================================\n")
    finally:
        # Release the process group even when evaluation fails.
        torch.distributed.destroy_process_group()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == '__main__':
    # Command-line entry point: merge CLI flags over the project config,
    # derive repository-relative paths, then spawn one process per GPU.
    parser = argparse.ArgumentParser(description='Inference: full test set + three process modes')

    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')

    parser.add_argument("--local_rank", type=int, default=-1,
                        help='multi-process training for DDP')

    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')

    parser.add_argument('--batch_size', default=1, type=int,
                        help='Batch size (match training if needed)')

    parser.add_argument('--epochs', default=80, type=int,
                        help="unused")

    parser.add_argument('--lr', default=1e-5, type=float,
                        help="unused")

    parser.add_argument('--online', action="store_true",
                        help='unused')

    parser.add_argument(
        '--inference_ckpt', type=str, default=None,
        help='Trained AuralSAM2 checkpoint (.pth state_dict). '
             'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). '
             'Default if unset: avs.code/training_details/.../hiera_l.pth',
    )
    parser.add_argument('--inference_data_name', type=str, default='v1m',
                        help='AVSBench subset folder label (v1s|v1m|v2); must match training test split')
    parser.add_argument('--inference_max_batches', type=int, default=0,
                        help='0 = full test; >0 = first N batches only (debug)')

    args = parser.parse_args()

    from configs.config import C

    # CLI values take precedence over same-named entries of the config C.
    args = EasyDict({**C, **vars(args)})

    _repo = pathlib.Path(__file__).resolve().parent
    # Repo root: .../AuralSAM2 (parent of avs.code)
    _workspace = _repo.parent.parent
    args.data_root_path = str(_workspace / 'AVSBench')
    args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt')
    args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth')
    args.saved_dir = '/tmp/v1m_infer_ckpt'
    pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True)
    if args.inference_ckpt is None:
        args.inference_ckpt = str(
            _repo.parent / 'training_details' / 'v1m' / 'hiera_l' / 'hiera_l.pth'
        )

    # Single-node rendezvous: the address is always local. The port keeps a
    # value already present in the environment so that two jobs on one host
    # can be given different ports; 9901 is used only as the fallback.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ.setdefault('MASTER_PORT', '9901')

    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
|
avs.code/v1m.code/loss/training/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training loss modules."""
|
| 2 |
+
|
avs.code/v1m.code/loss/training/contrastive_learning.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ContrastLoss(nn.Module, ABC):
    """Audio-visual contrastive loss over sampled per-pixel embeddings.

    Pixel embeddings are sampled per mask class, labelled with that class,
    and contrasted against the audio embeddings with an InfoNCE-style
    objective (see :meth:`info_nce`).
    """

    def __init__(self, hyp_param):
        """Read settings from ``hyp_param.contrastive_learning`` if present.

        Missing keys (or a missing/empty mapping) fall back to the defaults
        listed below. ``hyp_param.image_size`` is read later, in
        :meth:`forward`.
        """
        super().__init__()
        self.param = hyp_param
        _defaults = {
            "temperature": 0.10,
            "ignore_idx": 255,
            "ood_idx": 254,
            "max_views": 512,
            "proj_dim": 512,
            "sample_limits": 128,
            "total_limits": 15240,
        }
        _raw = getattr(hyp_param, "contrastive_learning", None) or {}
        _cfg = {**_defaults, **_raw}
        self.temperature = _cfg["temperature"]
        self.ignore_idx = _cfg["ignore_idx"]
        self.ood_idx = _cfg["ood_idx"]
        self.max_views = _cfg["max_views"]
        self.proj_dim = _cfg["proj_dim"]
        self.sample_limits = _cfg["sample_limits"]
        self.total_limits = _cfg["total_limits"]

    def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx):
        """Sample pixel embeddings per mask class for one frame and one layer.

        Args:
            embeddings: 2-D tensor ``(channels, pixels)``.
            audio_embeddings: audio embedding for the same frame/layer; it is
                appended once per non-first class value (or once with label 0
                when the mask holds a single value).
            predictions: 1-D tensor ``(pixels,)`` compared element-wise with
                ``masks``.
            masks: 1-D tensor ``(pixels,)`` of class values.
            batch_idx: added to every label as ``batch_idx * 1e3``.

        Returns:
            ``(visual_samples, visual_labels, audio_samples, audio_labels)``,
            four lists of tensors. For each class the visual lists receive a
            "hard" entry (mask differs from prediction) followed by an "easy"
            entry (mask equals prediction); each sample tensor has shape
            ``(k, 1, channels)`` and its label tensor has shape ``(k,)``.
        """
        embedding_sample_list = []
        label_list = []
        embedding_sample_list_a = []
        label_list_a = []
        class_index_list = torch.unique(masks)
        label_offset = batch_idx * 1e3

        if len(class_index_list) > 1:
            for class_index in class_index_list[1:]:
                embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
                label_list_a.append(class_index.unsqueeze(0) + label_offset)
        else:
            embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
            label_list_a.append(torch.zeros([1], device=embeddings.device) + label_offset)

        sample_limits = self.sample_limits
        embeddings = embeddings.permute(1, 0)
        for class_index in class_index_list:
            in_class = masks == class_index
            # NOTE(review): masks are compared with predictions by exact
            # equality; confirm predictions are binarised before this call.
            hard_samples = embeddings[((masks != predictions) & in_class).nonzero()]
            easy_samples = embeddings[((masks == predictions) & in_class).nonzero()]

            hard_num, easy_num = hard_samples.shape[0], easy_samples.shape[0]
            selective_num_hard = min(sample_limits, hard_num)
            selective_num_easy = min(sample_limits, easy_num)

            # When one pool cannot fill its quota, the other pool's quota is
            # raised. NOTE(review): the increment is added on top of the
            # existing quota, so the larger pool may contribute more than
            # 2 * sample_limits samples; kept as is.
            if (selective_num_hard + selective_num_easy) < sample_limits * 2:
                if selective_num_hard > selective_num_easy:
                    selective_num_hard += sample_limits * 2 - selective_num_easy
                else:
                    selective_num_easy += sample_limits * 2 - selective_num_hard

            pools = (
                (hard_samples, hard_num, selective_num_hard),
                (easy_samples, easy_num, selective_num_easy),
            )
            for pool, pool_size, quota in pools:
                chosen = torch.randperm(pool_size)[:quota]
                embedding_sample_list.append(pool[chosen])
                # Every sample in this pool belongs to `class_index`. `chosen`
                # indexes the pool, not the mask, so it must not be used to
                # look labels up in `masks`.
                label_list.append(class_index.expand(chosen.shape[0]) + label_offset)
        return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a

    def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions):
        """Sample embeddings for every frame and layer, then apply InfoNCE.

        Args:
            visual_embeddings: tensor ``(frames, 3, channels, h, w)``.
            audio_embeddings: tensor indexed as ``[frame][layer]``.
            masks: tensor ``(frames, h, w)``.
            predictions: tensor ``(frames, h, w)``.

        Returns:
            The value returned by :meth:`info_nce`, or ``0.0`` if no sample
            list was produced.
        """
        masks = masks.flatten(start_dim=1)
        predictions = predictions.flatten(start_dim=1)
        visual_embeddings = visual_embeddings.flatten(start_dim=-2)

        visual_embedding_sample_list = []
        visual_label_list = []
        audio_embedding_sample_list = []
        audio_label_list = []

        for frame_idx in range(masks.shape[0]):
            current_vision_feats = visual_embeddings[frame_idx]
            current_masks = masks[frame_idx]
            current_predictions = predictions[frame_idx]
            current_audio_feats = audio_embeddings[frame_idx]
            for layer_idx in range(3):
                (
                    selected_vision_embeddings,
                    selected_vision_labels,
                    selected_audio_embeddings,
                    selected_audio_labels,
                ) = self.select_class_wise_samples(
                    current_vision_feats[layer_idx],
                    current_audio_feats[layer_idx],
                    current_predictions,
                    current_masks,
                    0,
                )
                visual_embedding_sample_list += selected_vision_embeddings
                visual_label_list += selected_vision_labels
                audio_embedding_sample_list += selected_audio_embeddings
                audio_label_list += selected_audio_labels

        if len(visual_embedding_sample_list) == 0:
            return 0.0

        # Visual samples are (k, 1, channels): drop exactly that middle axis so
        # a single sample (k == 1) keeps its 2-D shape.
        visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze(1)
        visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1)
        # Collapse any singleton axes after the sample axis to (m, channels).
        audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).flatten(start_dim=1)
        audio_label_list = torch.cat(audio_label_list).unsqueeze(1)

        total_limits = self.total_limits
        num_anchors = visual_embedding_sample_list.shape[0]
        if num_anchors > total_limits:
            # Keep a uniformly random subset of exactly `total_limits` anchors,
            # using the same indices for embeddings and labels.
            keep = torch.randperm(num_anchors)[:total_limits]
            visual_embedding_sample_list = visual_embedding_sample_list[keep]
            visual_label_list = visual_label_list[keep]
        loss = self.info_nce(
            visual_embedding_sample_list,
            visual_label_list,
            audio_embedding_sample_list,
            audio_label_list,
        )
        return loss

    @staticmethod
    def _stack_levels(levels, num_frames):
        """Stack ``levels[0..2][i]`` for each frame into ``(frames, 3, ...)``."""
        return torch.stack(
            [
                torch.stack([levels[level][i] for level in range(3)])
                for i in range(num_frames)
            ]
        )

    def forward(self, embeddings, output_dicts, masks):
        """Compute the contrastive loss for a clip.

        Args:
            embeddings: pair ``(visual_embeddings, audio_embeddings)``; each
                is indexable as ``[level][frame]`` for three levels.
            output_dicts: iterable of dicts holding ``"multistep_pred_masks"``.
            masks: ground-truth masks ``(frames, H, W)``.

        Predictions (bilinear) and masks (nearest) are resized to
        ``image_size / 16`` per side before sampling.
        """
        target_size = (int(self.param.image_size / 16), int(self.param.image_size / 16))
        predictions = torch.cat([i["multistep_pred_masks"] for i in output_dicts])
        # squeeze(1) removes only the channel axis; a bare squeeze() would
        # also remove the frame axis when there is a single frame.
        predictions = torch.nn.functional.interpolate(
            predictions,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        ).squeeze(1)
        masks = torch.nn.functional.interpolate(
            masks.unsqueeze(1),
            size=target_size,
            mode="nearest",
        ).squeeze(1)
        num_frames = masks.shape[0]
        visual_embeddings, audio_embeddings = embeddings
        visual_embeddings = self._stack_levels(visual_embeddings, num_frames)
        # Flatten everything after (frames, levels) so singleton axes vanish
        # without touching the frame axis.
        audio_embeddings = self._stack_levels(audio_embeddings, num_frames).flatten(start_dim=2)
        return self.forward_audio_visual(
            visual_embeddings, audio_embeddings, masks, predictions
        )

    @staticmethod
    def manipulate_cover_mask(a_label, current_mask):
        """Zero anchor-anchor entries whose label product is 1 or 4.

        With labels shifted by one, a product of 1 or 4 means both anchors
        carry label 0 or both carry label 1. ``current_mask`` is modified in
        place and returned.
        """
        a_label = a_label + 1
        visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1))
        current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0
        current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0
        return current_mask

    def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
        """InfoNCE-style loss of anchors against anchors plus contrast set.

        Args:
            anchors_: ``(K, channels)`` anchor embeddings.
            a_labels_: ``(K, 1)`` anchor labels.
            contras_: ``(M, channels)`` contrast embeddings.
            c_labels_: ``(M, 1)`` contrast labels.

        Returns:
            Scalar tensor: negative mean over anchors of the average
            log-probability of their positive pairs.
        """
        c_labels_ = torch.cat([a_labels_, c_labels_])
        contras_ = torch.cat([anchors_, contras_])
        mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()

        anchor_dot_contrast = torch.div(
            torch.matmul(anchors_, torch.transpose(contras_, 0, 1)),
            self.temperature,
        )

        # Subtract the row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        neg_mask = 1 - mask

        mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask)
        mask = mask.fill_diagonal_(0.0)

        neg_logits = torch.exp(logits) * neg_mask
        neg_logits = neg_logits.sum(1, keepdim=True)
        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits + neg_logits)

        mask_pos_pairs = mask.sum(1)
        # Anchors without positives divide by 1 instead of 0.
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
        assert not torch.isnan(mean_log_prob_pos).any(), (
            "NaN in contrastive loss (NaN in log_prob: {})".format(
                bool(torch.isnan(log_prob).any())
            )
        )
        return -mean_log_prob_pos.mean()
|
| 201 |
+
|
avs.code/v1m.code/loss/training/sam2_training_loss.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
CORE_LOSS_KEY = "core_loss"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """Dice loss between mask logits and targets, divided by ``num_objects``.

    Args:
        inputs: mask logits; a sigmoid is applied internally.
        targets: target masks. With ``loss_on_multimask`` both tensors must
            be 4-D and are flattened from axis 2; otherwise only ``inputs``
            is flattened from axis 1 and ``targets`` is used as given.
        num_objects: normalisation constant.
        loss_on_multimask: if True, return the un-reduced per-mask loss.

    Returns:
        Per-mask losses when ``loss_on_multimask`` is set, else their sum;
        both divided by ``num_objects``.
    """
    probs = inputs.sigmoid()
    if loss_on_multimask:
        assert probs.dim() == 4 and targets.dim() == 4
        probs = probs.flatten(2)
        targets = targets.flatten(2)
        overlap = 2 * (probs * targets).sum(-1)
    else:
        probs = probs.flatten(1)
        overlap = 2 * (probs * targets).sum(1)
    total = probs.sum(-1) + targets.sum(-1)
    # The +1 in numerator and denominator smooths the ratio for empty masks.
    per_mask = 1 - (overlap + 1) / (total + 1)
    if not loss_on_multimask:
        per_mask = per_mask.sum()
    return per_mask / num_objects
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """Sigmoid focal loss on logits, divided by ``num_objects``.

    Args:
        inputs: logits.
        targets: targets with the same shape as ``inputs``.
        num_objects: normalisation constant.
        alpha: class-balance weight for positives; a negative value disables
            the weighting.
        gamma: exponent of the ``(1 - p_t)`` modulating factor.
        loss_on_multimask: if True the element-wise loss must be 4-D and the
            mean over the flattened trailing axes is returned per mask.

    Returns:
        Per-mask means when ``loss_on_multimask`` is set, otherwise the mean
        over axis 1 summed over the rest; both divided by ``num_objects``.
    """
    probs = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability assigned to the true class of each element.
    p_true = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - p_true) ** gamma)

    if alpha >= 0:
        balance = alpha * targets + (1 - alpha) * (1 - targets)
        focal = balance * focal

    if not loss_on_multimask:
        return focal.mean(1).sum() / num_objects
    assert focal.dim() == 4
    return focal.flatten(2).mean(-1) / num_objects
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def iou_loss(
    inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
    """Regression loss between predicted IoUs and the IoUs actually achieved.

    Args:
        inputs: 4-D mask logits; values greater than zero count as predicted
            foreground.
        targets: 4-D target masks; values greater than zero count as
            foreground.
        pred_ious: predicted IoU per mask.
        num_objects: normalisation constant.
        loss_on_multimask: if True, return the un-reduced per-mask loss.
        use_l1_loss: use L1 instead of squared error.

    Returns:
        Per-mask losses when ``loss_on_multimask`` is set, else their sum;
        both divided by ``num_objects``.
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    predicted = inputs.flatten(2) > 0
    truth = targets.flatten(2) > 0
    intersection = torch.sum(predicted & truth, dim=-1).float()
    union = torch.sum(predicted | truth, dim=-1).float()
    # Clamp the union so that two empty masks give IoU 0 rather than 0 / 0.
    actual_ious = intersection / torch.clamp(union, min=1.0)

    criterion = F.l1_loss if use_l1_loss else F.mse_loss
    loss = criterion(pred_ious, actual_ious, reduction="none")
    if not loss_on_multimask:
        loss = loss.sum()
    return loss / num_objects
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MultiStepMultiMasksAndIous(nn.Module):
    """Weighted focal + dice + IoU (+ object-score) loss over multi-step outputs.

    For every step the class evaluates all candidate masks, keeps the one with
    the lowest weighted focal+dice loss for back-propagation, and accumulates
    the individual terms. The weighted sum is stored under ``CORE_LOSS_KEY``.
    """

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
        gpu_num=1,
    ):
        """Store the loss configuration.

        Args:
            weight_dict: Mapping of loss name to weight. Must contain
                ``loss_mask``, ``loss_dice`` and ``loss_iou``. A missing
                ``loss_class`` entry is added in place with weight 0.0.
            focal_alpha: Alpha passed to the mask focal loss.
            focal_gamma: Gamma passed to the mask focal loss.
            supervise_all_iou: If True the IoU loss is averaged over every
                candidate mask instead of only the selected one.
            iou_use_l1_loss: Use L1 instead of MSE for the IoU regression.
            pred_obj_scores: If True an object-score focal loss is computed
                and mask losses are zeroed for targets without foreground.
            focal_gamma_obj_score: Gamma for the object-score focal loss.
            focal_alpha_obj_score: Alpha for the object-score focal loss.
            gpu_num: Divisor applied to the all-reduced object count.
        """
        super().__init__()
        self.weight_dict = weight_dict
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.world_size = gpu_num
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

    def _average_num_objects(self, targets_batch: torch.Tensor) -> float:
        """Return the normalisation count derived from ``targets_batch.shape[1]``.

        When a ``torch.distributed`` process group is initialised the count is
        summed over ranks and divided by ``self.world_size``. Without a
        process group (single-process evaluation, unit tests) the local count
        is used directly instead of raising inside ``all_reduce``. The result
        is clamped to at least 1 so it is always a safe divisor.
        """
        num_objects = torch.tensor(
            targets_batch.shape[1], device=targets_batch.device, dtype=torch.float
        )
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_objects)
            num_objects = num_objects / self.world_size
        return torch.clamp(num_objects, min=1).item()

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """Accumulate the losses of every (output, target) pair in the batch.

        Args:
            outs_batch: One output dictionary per entry of ``targets_batch``.
            targets_batch: Target masks; ``shape[1]`` is used as the object
                count for normalisation.

        Returns:
            A ``defaultdict`` mapping each loss name (and ``CORE_LOSS_KEY``)
            to its value summed over the batch.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = self._average_num_objects(targets_batch)

        losses = defaultdict(int)
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v
        return losses

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """Compute the losses of a single sample across all of its steps.

        Reads ``multistep_pred_multimasks_high_res``, ``multistep_pred_ious``
        and ``multistep_object_score_logits`` from ``outputs``; the three
        lists must have the same length.
        """
        # Insert a singleton dim at index 1 so the target can be expanded to
        # the candidate-mask dimension of the predictions.
        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4

        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]
        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """Add one step's loss terms to ``losses`` in place."""
        target_masks = target_masks.expand_as(src_masks)
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            # No objectness head: zero class loss, every target counts.
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # A target is "present" when its first mask has any foreground.
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # Several candidates: supervise only the one with the lowest
            # weighted focal + dice loss.
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)

            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

        # Zero the mask-related terms for targets flagged as absent.
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """Return the weighted sum of ``losses`` according to ``weight_dict``.

        Raises:
            ValueError: If a key of ``weight_dict`` is missing from ``losses``.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight
        return reduced_loss
|
| 220 |
+
|
avs.code/v1m.code/main.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss."""
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import numpy
|
| 5 |
+
import random
|
| 6 |
+
import argparse
|
| 7 |
+
from easydict import EasyDict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def seed_it(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here only affects processes launched afterwards; it is exported so
        child processes inherit a fixed hash seed. The original code wrote to
        ``PYTHONSEED``, which is not a variable the interpreter recognises.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True

    # Autotuning may pick different kernels between runs; keep it off.
    torch.backends.cudnn.benchmark = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process DDP worker: build model, data and loss, then train/validate.

    Args:
        local_rank: Index of this worker; used as the NCCL rank and CUDA
            device index.
        ngpus_per_node: Passed by the launcher; not read inside this function.
        hyp_param: Attribute-style configuration object. It is mutated here
            (``local_rank``, ``contrastive_learning``, ``aural_fuser`` and,
            for ``hiera_t`` configs, the image size fields).
    """
    hyp_param.local_rank = local_rank
    # NCCL process group; world size = GPUs on this node
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus * 1,
    )
    # Each rank gets its own seed by offsetting the base seed with the rank.
    seed_it(local_rank + hyp_param.seed)

    torch.cuda.set_device(hyp_param.local_rank)

    import model.visual.sam2  # noqa: F401 — registers Hydra `configs` (initialize_config_module)

    from hydra import compose
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml)
    transform_config_path = 'training/sam2_training_config.yaml'

    if 'hiera_t' in hyp_param.sam_config_path:
        hyp_param.image_size = 224
        hyp_param.image_embedding_size = int(hyp_param.image_size / 16)
        print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True)

    train_cfg = compose(config_name=transform_config_path)
    OmegaConf.resolve(train_cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(
        train_cfg.contrastive_learning, resolve=True
    )

    arch_cfg = compose(config_name='auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_cfg)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_cfg.aural_fuser, resolve=True)

    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank)

    av_model = torch.nn.parallel.distributed.DistributedDataParallel(
        av_model,
        device_ids=[hyp_param.local_rank],
        find_unused_parameters=True,
    )

    # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone)
    from utils.utils import manipulate_params
    param_groups = manipulate_params(hyp_param, av_model.module.aural_fuser)
    optimiser = torch.optim.AdamW(param_groups, betas=(0.9, 0.999))

    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data.distributed import DistributedSampler

    train_visual_transform = instantiate(train_cfg.train_transforms, _recursive_=True)[0]

    train_audio_augmentation = AudioAugmentation(mono=True)
    train_dataset = AV(
        split='train',
        augmentation={"visual": train_visual_transform, "audio": train_audio_augmentation},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
        data_name=hyp_param.data_name,
    )

    test_visual_augmentation = VisualAugmentation(
        hyp_param.image_mean,
        hyp_param.image_std,
        hyp_param.image_size,
        hyp_param.image_size,
        hyp_param.scale_list,
        ignore_index=hyp_param.ignore_index,
    )

    # The test split gets its own audio augmentation instance.
    test_audio_augmentation = AudioAugmentation(mono=True)

    random_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hyp_param.batch_size,
        sampler=random_sampler,
        num_workers=hyp_param.num_workers,
        drop_last=True,
    )

    test_dataset = AV(
        split='test',
        augmentation={"visual": test_visual_augmentation, "audio": test_audio_augmentation},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
        data_name=hyp_param.data_name,
    )

    order_sampler = DistributedSampler(test_dataset, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        sampler=order_sampler,
        num_workers=hyp_param.num_workers,
    )

    criterion = instantiate(train_cfg.loss, _recursive_=True)['all']
    from utils.tensorboard import Tensorboard
    # Only the first rank creates a Tensorboard logger.
    tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    fscore_arg = 0 if hyp_param.local_rank <= 0 else hyp_param.local_rank
    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(fscore_arg),
    }

    trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)

    best_score = 0.  # checkpoint when IoU (iou_select mode) improves

    for epoch in range(hyp_param.epochs):
        av_model.train()
        av_model.module.freeze_sam_parameters()
        random_sampler.set_epoch(epoch)
        trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser)

        torch.distributed.barrier()
        torch.cuda.empty_cache()

        av_model.eval()
        # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate
        _first_index_score, _ = trainer.valid(
            epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index'
        )
        iou_select_score, _ = trainer.valid(
            epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select'
        )
        _iou_occ_score, _ = trainer.valid(
            epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select'
        )
        if hyp_param.local_rank <= 0 and iou_select_score > best_score:
            best_score = iou_select_score
            checkpoint_path = os.path.join(hyp_param.saved_dir, str(iou_select_score) + ".pth")
            torch.save(av_model.module.aural_fuser.state_dict(), checkpoint_path)
        # Every rank must reach the barrier, so it stays outside the rank-0 branch.
        torch.distributed.barrier()
        torch.cuda.empty_cache()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == '__main__':
    # Command-line options; values given here override the project defaults in C.
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('-n', '--nodes', type=int, default=1, metavar='N')
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help='multi-process training for DDP',
    )
    parser.add_argument(
        '-g', '--gpus',
        type=int,
        default=1,
        help='number of gpus per node',
    )
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument(
        '--epochs',
        type=int,
        default=80,
        help="total epochs that used for the training",
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-4,
        help='Default HEAD Learning rate is same as others, '
             '*Note: in ddp training, lr will automatically times by n_gpu',
    )
    parser.add_argument(
        '--online',
        action="store_true",
        help='switch on for visualization; switch off for debug',
    )
    cli_args = parser.parse_args()

    from configs.config import C

    # Merge: config defaults first, then CLI values on top.
    args = EasyDict({**C, **vars(cli_args)})

    # Rendezvous address read by init_method='env://' in main().
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '9902'

    # One worker process per GPU; spawn passes the process index as first arg.
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
|
avs.code/v1m.code/model/audio/torchvggish/mel_features.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Defines routines to compute mel spectrogram features from audio waveform."""
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def frame(data, window_length, hop_length):
    """Slice an array into successive, possibly overlapping frames.

    An array of shape (num_samples, ...) becomes an array of shape
    (num_frames, window_length, ...), where each frame starts ``hop_length``
    samples after the previous one. There is no zero-padding, so a trailing
    incomplete frame is dropped.

    The result is built with ``as_strided`` and therefore shares memory with
    ``data``; overlapping frames alias the same samples.

    Args:
        data: np.array of dimension N >= 1.
        window_length: Number of samples in each frame.
        hop_length: Advance (in samples) between consecutive frames.

    Returns:
        (N+1)-D np.array with one row per complete frame. When ``data`` is
        shorter than one window the result has zero rows.
    """
    num_samples = data.shape[0]
    # Clamp at zero: for inputs much shorter than a window the formula goes
    # negative, which would make as_strided raise on a negative dimension.
    num_frames = max(
        0, 1 + int(np.floor((num_samples - window_length) / hop_length))
    )
    shape = (num_frames, window_length) + data.shape[1:]
    strides = (data.strides[0] * hop_length,) + data.strides
    return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def periodic_hann(window_length):
    """Return a "periodic" Hann window of ``window_length`` points.

    The symmetric Hann window (as returned by ``np.hanning``) starts and ends
    on zero and spans a cosine of period N-1. For Fourier analysis a full
    cycle of a period-N raised cosine is preferable, i.e. the window stops
    one sample before the closing zero. This function computes that variant.

    Args:
        window_length: The number of points in the returned window.

    Returns:
        A 1D np.array containing the periodic Hann window.
    """
    sample_index = np.arange(window_length)
    phase = 2 * np.pi / window_length * sample_index
    return 0.5 - (0.5 * np.cos(phase))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def stft_magnitude(signal, fft_length, hop_length=None, window_length=None):
    """Compute the magnitude of the short-time Fourier transform.

    Args:
        signal: 1D np.array of the input time-domain signal.
        fft_length: Size of the FFT to apply.
        hop_length: Advance (in samples) between each frame passed to the FFT.
        window_length: Length of each block of samples passed to the FFT.

    Returns:
        2D np.array in which each row holds the magnitudes of the
        fft_length/2+1 unique FFT values for the corresponding input frame.
    """
    blocks = frame(signal, window_length, hop_length)
    # Taper each frame with a periodic Hann window (cosine of period
    # window_length) rather than the symmetric np.hanning window (period
    # window_length-1).
    taper = periodic_hann(window_length)
    spectrum = np.fft.rfft(blocks * taper, int(fft_length))
    return np.abs(spectrum)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Mel spectrum constants and functions.
|
| 96 |
+
# Parameters of the HTK mel formula: mel = Q * ln(1 + f / break_frequency).
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequencies_hertz):
    """Map frequencies in hertz onto the mel scale (HTK formula).

    Args:
        frequencies_hertz: Scalar or np.array of frequencies in hertz.

    Returns:
        Object of the same size as ``frequencies_hertz`` holding the
        corresponding mel-scale values.
    """
    relative_frequency = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * np.log(1.0 + relative_frequency)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def spectrogram_to_mel_matrix(num_mel_bins=20,
                              num_spectrogram_bins=129,
                              audio_sample_rate=8000,
                              lower_edge_hertz=125.0,
                              upper_edge_hertz=3800.0):
    """Build the matrix that maps spectrogram rows onto mel bands.

    The returned matrix A post-multiplies a spectrogram S (STFT magnitudes
    arranged as frames x bins) to give a mel spectrogram M = S A of shape
    frames x num_mel_bins. Each column is a triangular filter that is linear
    in the mel domain.

    Args:
        num_mel_bins: Number of bands in the resulting mel spectrum, i.e. the
            number of columns of the output matrix.
        num_spectrogram_bins: Number of bins in the source spectrogram,
            understood to be fft_size/2 + 1 (only the non-redundant FFT bins).
        audio_sample_rate: Samples per second of the audio fed to the
            spectrogram; needed to place each spectrogram bin in hertz.
        lower_edge_hertz: Lower bound of the frequencies included in the mel
            spectrum (lower edge of the lowest triangular band).
        upper_edge_hertz: Desired top edge of the highest frequency band.

    Returns:
        An np.array with shape (num_spectrogram_bins, num_mel_bins).

    Raises:
        ValueError: if frequency edges are incorrectly ordered or out of range.
    """
    nyquist_hertz = audio_sample_rate / 2.
    if lower_edge_hertz < 0.0:
        raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
    if lower_edge_hertz >= upper_edge_hertz:
        raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
                         (lower_edge_hertz, upper_edge_hertz))
    if upper_edge_hertz > nyquist_hertz:
        raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
                         (upper_edge_hertz, nyquist_hertz))

    bin_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
    bin_mel = hertz_to_mel(bin_hertz)
    # Band i has lower edge edges[i], centre edges[i + 1] and upper edge
    # edges[i + 2], hence num_mel_bins + 2 equally spaced mel values.
    edges = np.linspace(hertz_to_mel(lower_edge_hertz),
                        hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
    lower_edges = edges[:-2]
    centers = edges[1:-1]
    upper_edges = edges[2:]

    # Broadcast bins (rows) against bands (columns): rising and falling
    # flanks of every triangle, linear in the *mel* domain, not hertz.
    bin_column = bin_mel[:, np.newaxis]
    rising = (bin_column - lower_edges) / (centers - lower_edges)
    falling = (upper_edges - bin_column) / (upper_edges - centers)
    # Intersect the two flanks with each other and with zero.
    mel_weights_matrix = np.maximum(0.0, np.minimum(rising, falling))
    # HTK excludes the spectrogram DC bin; force its coefficients to zero.
    mel_weights_matrix[0, :] = 0.0
    return mel_weights_matrix
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def log_mel_spectrogram(data,
                        audio_sample_rate=8000,
                        log_offset=0.0,
                        window_length_secs=0.025,
                        hop_length_secs=0.010,
                        **kwargs):
    """Turn a waveform into a log-magnitude mel-frequency spectrogram.

    Args:
        data: 1D np.array of waveform data.
        audio_sample_rate: The sampling rate of ``data``.
        log_offset: Added to the mel magnitudes before the log to avoid -Inf.
        window_length_secs: Duration of each analysis window.
        hop_length_secs: Advance between successive analysis windows.
        **kwargs: Extra arguments forwarded to spectrogram_to_mel_matrix.

    Returns:
        2D np.array of (num_frames, num_mel_bins) holding the log mel
        filterbank magnitudes of successive frames.
    """
    samples_per_window = int(round(audio_sample_rate * window_length_secs))
    samples_per_hop = int(round(audio_sample_rate * hop_length_secs))
    # FFT size: the smallest power of two that holds one window.
    fft_exponent = int(np.ceil(np.log(samples_per_window) / np.log(2.0)))
    fft_length = 2 ** fft_exponent

    magnitudes = stft_magnitude(
        data,
        fft_length=fft_length,
        hop_length=samples_per_hop,
        window_length=samples_per_window)
    mel_projection = spectrogram_to_mel_matrix(
        num_spectrogram_bins=magnitudes.shape[1],
        audio_sample_rate=audio_sample_rate, **kwargs)
    mel_magnitudes = np.dot(magnitudes, mel_projection)
    return np.log(mel_magnitudes + log_offset)
|
avs.code/v1m.code/model/audio/torchvggish/vggish.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import hub
|
| 5 |
+
|
| 6 |
+
from . import vggish_input, vggish_params
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VGG(nn.Module):
    """Convolutional feature stack followed by a three-layer embedding MLP."""

    def __init__(self, features):
        """Wrap ``features`` and build the 12288 -> 4096 -> 4096 -> 128 head.

        Args:
            features: Module applied to the input before the embedding head.
        """
        super().__init__()
        self.features = features
        self.embeddings = nn.Sequential(
            nn.Linear(512 * 4 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 128),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the feature stack, flatten channels-last and embed.

        Args:
            x: Input accepted by ``self.features``; its output must be 4-D.

        Returns:
            The result of ``self.embeddings`` on the flattened features.
        """
        feature_map = self.features(x)

        # Move dim 1 to the end (dims 0, 2, 3, 1) before flattening so the
        # flattened order stays compatible with the VGGish embedding weights.
        channels_last = feature_map.permute(0, 2, 3, 1).contiguous()
        flattened = channels_last.view(channels_last.size(0), -1)

        return self.embeddings(flattened)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Postprocessor(nn.Module):
|
| 35 |
+
"""Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
|
| 36 |
+
numpy array in order to preserve the gradient.
|
| 37 |
+
|
| 38 |
+
"The initial release of AudioSet included 128-D VGGish embeddings for each
|
| 39 |
+
segment of AudioSet. These released embeddings were produced by applying
|
| 40 |
+
a PCA transformation (technically, a whitening transform is included as well)
|
| 41 |
+
and 8-bit quantization to the raw embedding output from VGGish, in order to
|
| 42 |
+
stay compatible with the YouTube-8M project which provides visual embeddings
|
| 43 |
+
in the same format for a large set of YouTube videos. This class implements
|
| 44 |
+
the same PCA (with whitening) and quantization transformations."
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self):
|
| 48 |
+
"""Constructs a postprocessor."""
|
| 49 |
+
super(Postprocessor, self).__init__()
|
| 50 |
+
# Create empty matrix, for user's state_dict to load
|
| 51 |
+
self.pca_eigen_vectors = torch.empty(
|
| 52 |
+
(vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
|
| 53 |
+
dtype=torch.float,
|
| 54 |
+
)
|
| 55 |
+
self.pca_means = torch.empty(
|
| 56 |
+
(vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
|
| 60 |
+
self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
|
| 61 |
+
|
| 62 |
+
def postprocess(self, embeddings_batch):
|
| 63 |
+
"""Applies tensor postprocessing to a batch of embeddings.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
embeddings_batch: An tensor of shape [batch_size, embedding_size]
|
| 67 |
+
containing output from the embedding layer of VGGish.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
A tensor of the same shape as the input, containing the PCA-transformed,
|
| 71 |
+
quantized, and clipped version of the input.
|
| 72 |
+
"""
|
| 73 |
+
assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
|
| 74 |
+
embeddings_batch.shape,
|
| 75 |
+
)
|
| 76 |
+
assert (
|
| 77 |
+
embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
|
| 78 |
+
), "Bad batch shape: %r" % (embeddings_batch.shape,)
|
| 79 |
+
|
| 80 |
+
# Apply PCA.
|
| 81 |
+
# - Embeddings come in as [batch_size, embedding_size].
|
| 82 |
+
# - Transpose to [embedding_size, batch_size].
|
| 83 |
+
# - Subtract pca_means column vector from each column.
|
| 84 |
+
# - Premultiply by PCA matrix of shape [output_dims, input_dims]
|
| 85 |
+
# where both are are equal to embedding_size in our case.
|
| 86 |
+
# - Transpose result back to [batch_size, embedding_size].
|
| 87 |
+
pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t()
|
| 88 |
+
|
| 89 |
+
# Quantize by:
|
| 90 |
+
# - clipping to [min, max] range
|
| 91 |
+
clipped_embeddings = torch.clamp(
|
| 92 |
+
pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
|
| 93 |
+
)
|
| 94 |
+
# - convert to 8-bit in range [0.0, 255.0]
|
| 95 |
+
quantized_embeddings = torch.round(
|
| 96 |
+
(clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
|
| 97 |
+
* (
|
| 98 |
+
255.0
|
| 99 |
+
/ (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
return torch.squeeze(quantized_embeddings)
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
    """Module entry point: delegate to :meth:`postprocess`."""
    processed = self.postprocess(x)
    return processed
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def make_layers():
    """Build the convolutional feature stack as an ``nn.Sequential``.

    The layer plan is a fixed sequence: an integer adds a 3x3 convolution
    (padding 1) with that many output channels followed by an in-place ReLU,
    and ``"M"`` adds a 2x2 max-pool with stride 2. The first convolution
    takes a single input channel.
    """
    plan = (64, "M", 128, "M", 256, 256, "M", 512, 512, "M")
    modules = []
    prev_channels = 1
    for entry in plan:
        if entry != "M":
            modules.append(nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1))
            modules.append(nn.ReLU(inplace=True))
            prev_channels = entry
            continue
        modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _vgg():
    """Return a ``VGG`` model wrapping a freshly built ``make_layers()`` stack."""
    features = make_layers()
    return VGG(features)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# def _spectrogram():
|
| 126 |
+
# config = dict(
|
| 127 |
+
# sr=16000,
|
| 128 |
+
# n_fft=400,
|
| 129 |
+
# n_mels=64,
|
| 130 |
+
# hop_length=160,
|
| 131 |
+
# window="hann",
|
| 132 |
+
# center=False,
|
| 133 |
+
# pad_mode="reflect",
|
| 134 |
+
# htk=True,
|
| 135 |
+
# fmin=125,
|
| 136 |
+
# fmax=7500,
|
| 137 |
+
# output_format='Magnitude',
|
| 138 |
+
# # device=device,
|
| 139 |
+
# )
|
| 140 |
+
# return Spectrogram.MelSpectrogram(**config)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class VGGish(VGG):
    """VGG audio backbone with optional preprocessing and PCA postprocessing.

    Behaviour is driven by attributes of ``cfg``:

    * ``FREEZE_AUDIO_EXTRACTOR``: when truthy, backbone weights are loaded
      from ``PRETRAINED_VGGISH_MODEL_PATH`` (and, if postprocessing is on,
      PCA parameters from ``PRETRAINED_PCA_PARAMS_PATH``).
    * ``PREPROCESS_AUDIO_TO_LOG_MEL``: when truthy, ``forward`` expects a WAV
      file path and converts it into input examples first.
    * ``POSTPROCESS_LOG_MEL_WITH_PCA``: when truthy, the embeddings are passed
      through a ``Postprocessor``.

    Args:
        cfg: Configuration object exposing the attributes listed above.
        device: Target ``torch.device``. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, cfg, device=None):
        super().__init__(make_layers())
        if cfg.FREEZE_AUDIO_EXTRACTOR:
            # Load onto the CPU first so a checkpoint saved from a GPU can be
            # restored on a CPU-only host; the module is moved to the target
            # device at the end of __init__.
            state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH, map_location="cpu")
            super().load_state_dict(state_dict)
            print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}')

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print("device: ", device)
        self.device = device

        self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL
        self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA
        if self.postprocess:
            self.pproc = Postprocessor()
            if cfg.FREEZE_AUDIO_EXTRACTOR:
                state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH, map_location="cpu")
                # The stored PCA parameters are converted to float tensors
                # before loading; the means are reshaped to a column vector.
                state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
                    state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
                )
                state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
                    state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
                )
                self.pproc.load_state_dict(state_dict)
        self.to(self.device)

    def forward(self, x):
        """Compute embeddings for ``x``.

        ``x`` is a WAV file path when preprocessing is enabled; otherwise it
        is passed straight to ``VGG.forward``.
        """
        if self.preprocess:
            print(">>> pre processing...")
            x = self._preprocess(x)
            x = x.to(self.device)
        x = VGG.forward(self, x)
        if self.postprocess:
            print(">>> post processing...")
            x = self._postprocess(x)
        return x

    def _preprocess(self, x):
        """Convert a WAV file path into VGGish input examples.

        Raises:
            AttributeError: if ``x`` is not a ``str`` path.
        """
        if isinstance(x, str):
            return vggish_input.wavfile_to_examples(x)
        raise AttributeError(
            "VGGish preprocessing expects a WAV file path (str), got %s" % type(x).__name__
        )

    def _postprocess(self, x):
        """Run the PCA/quantization ``Postprocessor`` on the embeddings."""
        return self.pproc(x)
|
avs.code/v1m.code/model/audio/torchvggish/vggish_input.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Compute input examples for VGGish from audio waveform."""
|
| 17 |
+
|
| 18 |
+
# Modification: Return torch tensors rather than numpy arrays
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import resampy
|
| 23 |
+
|
| 24 |
+
from . import mel_features
|
| 25 |
+
from . import vggish_params
|
| 26 |
+
|
| 27 |
+
import soundfile as sf
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def waveform_to_examples(data, sample_rate, return_tensor=True):
    """Converts audio waveform into an array of examples for VGGish.

    Args:
        data: np.array of either one dimension (mono) or more. Input with
            more than one dimension is reduced to mono by averaging over
            axis 1. Each sample is generally expected to lie in the range
            [-1.0, +1.0], although this is not required.
        sample_rate: Sample rate of data.
        return_tensor: Return data as a Pytorch tensor ready for VGGish.

    Returns:
        The framed log mel spectrogram produced by ``mel_features.frame``:
        a sequence of examples, each covering a window of
        vggish_params.EXAMPLE_WINDOW_SECONDS of log mel features whose frame
        step is vggish_params.STFT_HOP_LENGTH_SECONDS. When ``return_tensor``
        is true the result is a float32 torch tensor with an extra
        singleton dimension inserted at index 1; otherwise the array is
        returned as produced.
    """
    # Convert to mono.
    mono = np.mean(data, axis=1) if len(data.shape) > 1 else data

    # Resample to the rate assumed by VGGish.
    target_rate = vggish_params.SAMPLE_RATE
    if sample_rate != target_rate:
        mono = resampy.resample(mono, sample_rate, vggish_params.SAMPLE_RATE)

    # Compute log mel spectrogram features.
    spectrogram = mel_features.log_mel_spectrogram(
        mono,
        audio_sample_rate=vggish_params.SAMPLE_RATE,
        log_offset=vggish_params.LOG_OFFSET,
        window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
        hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
        num_mel_bins=vggish_params.NUM_MEL_BINS,
        lower_edge_hertz=vggish_params.MEL_MIN_HZ,
        upper_edge_hertz=vggish_params.MEL_MAX_HZ)

    # Frame features into examples.
    frames_per_second = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
    window_frames = int(round(vggish_params.EXAMPLE_WINDOW_SECONDS * frames_per_second))
    hop_frames = int(round(vggish_params.EXAMPLE_HOP_SECONDS * frames_per_second))
    examples = mel_features.frame(
        spectrogram,
        window_length=window_frames,
        hop_length=hop_frames)

    if not return_tensor:
        return examples
    as_tensor = torch.tensor(examples, requires_grad=True)
    return as_tensor[:, None, :, :].float()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def wavfile_to_examples(wav_file, return_tensor=True):
    """Convenience wrapper around waveform_to_examples() for a common WAV format.

    Args:
        wav_file: String path to a file, or a file-like object. The file
            is assumed to contain WAV audio data with signed 16-bit PCM samples.
        return_tensor: Return data as a Pytorch tensor ready for VGGish.

    Returns:
        See waveform_to_examples.
    """
    pcm, rate = sf.read(wav_file, dtype='int16')
    assert pcm.dtype == np.int16, 'Bad sample type: %r' % pcm.dtype
    # Scale 16-bit PCM into [-1.0, +1.0].
    normalized = pcm / 32768.0
    return waveform_to_examples(normalized, rate, return_tensor)
|
avs.code/v1m.code/model/audio/torchvggish/vggish_params.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Global parameters for the VGGish model.
|
| 17 |
+
|
| 18 |
+
See vggish_slim.py for more information.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# Architectural constants.
NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128  # Size of embedding layer.

# Hyperparameters used in feature and example generation.
# SAMPLE_RATE is the rate audio is resampled to before feature extraction
# (presumably in Hz, matching the *_HZ constants below).
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.

# Parameters used for embedding postprocessing.
# The two *_NAME constants are the state-dict keys of the PCA parameters;
# the QUANTIZE_* values bound the clipping range applied before quantization.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0

# Hyperparameters used in training.
INIT_STDDEV = 0.01  # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4  # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8  # Epsilon for the Adam optimizer.

# Names of ops, tensors, and features.
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
|
avs.code/v1m.code/model/aural_fuser.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from model.audio.torchvggish import vggish
|
| 6 |
+
from timm.models.layers import DropPath, trunc_normal_
|
| 7 |
+
|
| 8 |
+
from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ProjectionHead(nn.Module):
    """Projection head: 1x1 conv -> norm -> 1x1 conv, then L2-normalise over dim 1.

    Args:
        dim_in: Number of input channels.
        proj_dim: Number of output channels of both convolutions.
        norm_act: Normalisation layer class applied between the convolutions.
        conv_layer: Convolution class (e.g. ``nn.Conv2d`` or ``nn.Conv1d``).
    """

    def __init__(self, dim_in, proj_dim=256, norm_act=nn.BatchNorm2d, conv_layer=nn.Conv2d):
        super().__init__()
        stages = [
            conv_layer(dim_in, proj_dim, kernel_size=1),
            norm_act(proj_dim),
            conv_layer(proj_dim, proj_dim, kernel_size=1),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x`` and scale each channel vector to unit L2 norm."""
        projected = self.proj(x)
        return nn.functional.normalize(projected, p=2, dim=1)
|
| 22 |
+
|
| 23 |
+
class AuralFuser(torch.nn.Module):
    """Fuses VGGish audio with SAM2 FPN maps via patch embeds, fusion blocks, and projection heads.

    Args:
        hyp_param: Configuration object. The constructor reads
            ``hyp_param.audio`` (passed to ``VGGish``), the optional flag
            ``hyp_param.train_vggish`` and the mapping
            ``hyp_param.aural_fuser`` with keys ``patch_cfgs``, ``f_depths``,
            ``block_kw`` and ``one_d_kw``. ``forward`` additionally reads
            ``hyp_param.image_embedding_size``.

    Raises:
        ValueError: if ``hyp_param.aural_fuser`` is missing or ``None``.
    """

    def __init__(self, hyp_param):
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()
        self.hyp_param = hyp_param
        self.vgg = vggish.VGGish(self.hyp_param.audio)
        # The audio backbone stays frozen unless train_vggish is set.
        if not getattr(self.hyp_param, "train_vggish", False):
            for p in self.vgg.parameters():
                p.requires_grad = False

        self.position_encoding_func = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None,
                                                            temperature=10000)

        # Populated in main.py / inference.py via Hydra compose('auralfuser/architecture.yaml') -> hyp_param.aural_fuser
        if not hasattr(self.hyp_param, "aural_fuser") or self.hyp_param.aural_fuser is None:
            raise ValueError(
                "hyp_param.aural_fuser is missing; load it with Hydra compose before constructing AuralFuser."
            )
        arch_cfg = self.hyp_param.aural_fuser

        _patch_cfgs = [tuple(i) for i in arch_cfg["patch_cfgs"]]
        _f_depths = arch_cfg["f_depths"]
        # Copy the config mappings before injecting the norm layer class so
        # the original config objects are not mutated.
        _block_kw = dict(arch_cfg["block_kw"])
        _block_kw["norm_layer"] = nn.LayerNorm
        _one_d_kw = dict(arch_cfg["one_d_kw"])
        _one_d_kw["norm_layer"] = nn.LayerNorm

        # One (kernel, stride) patch-embedding convolution per entry of patch_cfgs.
        self.patch_embeds = nn.ModuleList(
            nn.Conv2d(256, 256, kernel_size=k, stride=s) for k, s in _patch_cfgs
        )

        # Visual blocks: one stack per entry of f_depths.
        self.f_blocks = nn.ModuleList(
            nn.ModuleList([Block(**_block_kw) for _ in range(n)]) for n in _f_depths
        )

        # Audio blocks: three stacks of three OneDBlock each.
        self.a_blocks = nn.ModuleList(
            nn.ModuleList([OneDBlock(**_one_d_kw) for _ in range(3)]) for _ in range(3)
        )

        self.fusion_modules = nn.ModuleList(
            AudioVisualFusionModule(in_channels=256, mode='dot') for _ in range(3)
        )
        # Applied at levels 1 and 2 after adding the previous level's output.
        self.smooth_convs = nn.ModuleList(
            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(2)
        )

        self.train_proj_v1 = ProjectionHead(dim_in=256, proj_dim=128)

        self.train_proj_a1 = ProjectionHead(dim_in=256, norm_act=nn.BatchNorm1d, conv_layer=nn.Conv1d, proj_dim=128)

    @staticmethod
    def positionalencoding1d(d_model, length):
        """Return a ``[length, d_model]`` sinusoidal positional encoding.

        Even columns hold sines and odd columns cosines of
        ``position * exp(-k * ln(10000) / d_model)`` for ``k = 0, 2, 4, ...``.

        Raises:
            ValueError: if ``d_model`` is odd.
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe

    def forward(self, feature_dicts, spect=None):
        """Fuse three FPN levels with the audio embedding.

        Args:
            feature_dicts: Mapping whose ``"backbone_fpn"`` entry yields at
                least three feature maps (one per fusion level).
            spect: Audio input tensor; slices 0 and 1 along dim 1 are each
                passed through VGGish. Although the default is ``None``,
                a tensor is required.

        Returns:
            A tuple ``(feature_residual, audio_out, proj_feature_out)``:
            the per-level fused visual features, the per-level fused audio
            features, and ``[visual_projections, audio_projections]`` from
            the two projection heads.
        """
        image_embed_shape = [self.hyp_param.image_embedding_size] * 2
        H, W = image_embed_shape[0], image_embed_shape[1]
        # Embed the two dim-1 slices separately and join them on the feature axis.
        d = torch.cat(
            [
                self.vgg(spect[:, 0, ...].unsqueeze(1)),
                self.vgg(spect[:, 1, ...].unsqueeze(1)),
            ],
            dim=-1,
        )
        length = d.shape[-1]
        # Encoding of a single position with width equal to the audio feature size.
        fix_audio_pos = self.positionalencoding1d(length, 1).squeeze().to(spect.device)
        fpn = list(feature_dicts["backbone_fpn"])
        patch_embeds = list(self.patch_embeds)
        f_blocks = list(self.f_blocks)
        a_blocks = list(self.a_blocks)
        tpavi = list(self.fusion_modules)
        # Level 0 has no previous level to merge, hence no smoothing conv.
        smooths = [None, self.smooth_convs[0], self.smooth_convs[1]]

        feats = [None, None, None]
        d_outputs = []

        for i in range(3):
            x = fpn[i]
            x = patch_embeds[i](x)
            x_pos = self.position_encoding_func(x)
            # [B, C, h, w] -> [B, h*w, C] token layout.
            x = x.flatten(2).permute(0, 2, 1)
            x_pos = x_pos.flatten(2).permute(0, 2, 1)

            if i == 0:
                x = x + x_pos
                d = d + fix_audio_pos
            else:
                # Merge the previous level's fused output, then smooth it.
                x = x + feats[i - 1]
                x = smooths[i](
                    x.permute(0, 2, 1).reshape(x.shape[0], 256, H, W)
                ).flatten(2).permute(0, 2, 1)
                x = x + x_pos
                d = d + fix_audio_pos

            for blks in f_blocks[i]:
                x = blks(x, H, W, x_pos)
            for blks in a_blocks[i]:
                d = blks(d, fix_audio_pos)

            x = x + x_pos
            d = d + fix_audio_pos
            x, d_out, _, _ = tpavi[i](x, H, W, x_pos, d, length)
            d = d_out
            feats[i] = x
            d_outputs.append(d_out)

        # Visual projections are computed first, then audio ones, level by level.
        proj_feature_out = [
            [
                self.train_proj_v1(f.permute(0, 2, 1).reshape(-1, 256, *image_embed_shape))
                for f in feats
            ],
            [self.train_proj_a1(a.unsqueeze(-1)) for a in d_outputs],
        ]

        return feats, d_outputs, proj_feature_out
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class AudioVisualFusionModule(nn.Module):
    """Bidirectional dot-product fusion between frame tokens and audio vectors.

    Frame features attend to audio and audio attends to frame features; each
    side adds the attended signal as a residual and applies LayerNorm.

    Args:
        in_channels: Channel width of the frame features and of the aligned
            audio features.
        inter_channels: Accepted for interface compatibility; the
            intermediate width is always ``in_channels // 2``.
        mode: Must be ``'dot'``.
        dimension: Selects the conv/batch-norm dimensionality used for ``g``
            and ``W_z`` (3 -> 3d, 2 -> 2d, otherwise 1d). ``forward`` feeds
            ``W_z`` a 5-D tensor, which matches the default of 3.
    """

    def __init__(self, in_channels, inter_channels=None, mode='dot',
                 dimension=3):
        super().__init__()
        assert mode == 'dot'
        self.mode = mode
        self.dimension = dimension

        self.in_channels = in_channels
        self.inter_channels = in_channels // 2

        # Incoming audio vectors are expected to have 256 features.
        self.align_channel = nn.Conv1d(256, in_channels, kernel_size=1)
        # Not used by forward(); kept so existing checkpoints still load.
        self.align_channel_back = nn.Conv1d(in_channels, 128, kernel_size=1)

        self.norm_layer = nn.LayerNorm(in_channels)

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            bn = nn.BatchNorm1d

        # Not used by forward(); kept so existing checkpoints still load.
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        self.W_z = nn.Sequential(
            conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
            bn(self.in_channels)
        )
        # Zero-initialised batch norm: the residual branch starts as a no-op.
        nn.init.constant_(self.W_z[1].weight, 0)
        nn.init.constant_(self.W_z[1].bias, 0)

        self.W_z2 = nn.Sequential(
            nn.Conv1d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
            nn.BatchNorm1d(self.in_channels)
        )
        nn.init.constant_(self.W_z2[1].weight, 0)
        nn.init.constant_(self.W_z2[1].bias, 0)
        self.norm_layer2 = nn.LayerNorm(self.in_channels)

        self.q_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
        self.k_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
        self.v_frame = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        self.q_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)
        self.k_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)
        self.v_audio = nn.Conv1d(self.in_channels, self.inter_channels, kernel_size=1)

    def forward(self, frame, H_x, W_x, tmp1, audio, tmp2):
        """Fuse frame tokens with audio vectors.

        Args:
            frame: Tensor of shape ``[B, H_x * W_x, in_channels]``.
            H_x, W_x: Spatial size used to fold the tokens back into a map.
            tmp1: Unused.
            audio: Tensor of shape ``[B, 256]``.
            tmp2: Unused.

        Returns:
            ``(frame, audio, frame_attn, audio_attn)`` where ``frame`` is
            ``[B, H_x * W_x, in_channels]``, ``audio`` is
            ``[B, in_channels]``, and the last two are the residual branches
            before they were added (``[B, in_channels, 1, H_x, W_x]`` and
            ``[B, in_channels, 1]``). The batch dimension is preserved even
            when ``B == 1``.
        """
        # [B, HW, C] -> [B, C, 1, H, W] so the 3-D convolutions can be applied.
        frame = frame.permute(0, 2, 1)
        frame = frame.reshape(frame.shape[0], frame.shape[1], H_x, W_x)
        frame = frame.unsqueeze(2)
        # [B, 256] -> [B, in_channels, 1]
        audio = self.align_channel(audio.unsqueeze(-1))

        batch_size, _ = frame.size(0), frame.size(1)
        # NOTE(review): the projections are reshaped (not permuted) into a
        # single [1, tokens, inter_channels] sequence spanning the whole batch.
        q_frame = self.q_frame(frame).reshape(1, -1, self.inter_channels)
        k_frame = self.k_frame(frame).reshape(1, -1, self.inter_channels)
        v_frame = self.v_frame(frame).reshape(1, -1, self.inter_channels)
        q_audio = self.q_audio(audio).reshape(1, -1, self.inter_channels)
        k_audio = self.k_audio(audio).reshape(1, -1, self.inter_channels)
        v_audio = self.v_audio(audio).reshape(1, -1, self.inter_channels)

        # Frame -> audio attention, scaled by the number of frame tokens.
        f = torch.matmul(q_frame, k_audio.mT)
        f_normalise = f / f.size(1)

        frame_attn = torch.matmul(f_normalise, v_audio)

        frame_attn = frame_attn.permute(0, 2, 1).contiguous()
        frame_attn = frame_attn.view(batch_size, self.inter_channels, *frame.size()[2:])
        frame_attn = self.W_z(frame_attn)
        frame = frame_attn + frame

        # LayerNorm acts on the last axis, so move channels there and back.
        frame = frame.permute(0, 2, 3, 4, 1)
        frame = self.norm_layer(frame)
        frame = frame.permute(0, 4, 1, 2, 3)
        # Drop only the singleton depth axis; a bare squeeze() would also
        # drop the batch axis when B == 1 and corrupt the output layout.
        frame = frame.squeeze(2).flatten(start_dim=2).permute(0, 2, 1)

        # Audio -> frame attention, scaled by the number of frame tokens.
        a = torch.matmul(q_audio, k_frame.mT)
        a_normalise = a / a.size(-1)

        audio_attn = torch.matmul(a_normalise, v_frame)
        audio_attn = audio_attn.permute(0, 2, 1).contiguous()

        audio_attn = audio_attn.view(batch_size, self.inter_channels).unsqueeze(-1)
        audio_attn = self.W_z2(audio_attn)

        audio = audio_attn + audio

        # Remove only the trailing length-1 axis so a batch of one stays 2-D.
        audio = self.norm_layer2(audio.squeeze(-1))

        return frame, audio, frame_attn, audio_attn
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class OneDBlock(nn.Module):
    """Pre-norm transformer block for 1-D (audio) features.

    Applies ``OneDAttention`` and then ``OneDMlp``, each on a normalised
    input and each added back as a residual through ``drop_path``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = OneDAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            linear=linear,
        )
        # Stochastic depth is only instantiated for a positive drop rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = OneDMlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
            linear=linear,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, _pos):
        """Return ``x`` after the attention and MLP residual steps; ``_pos`` is unused."""
        after_attn = x + self.drop_path(self.attn(self.norm1(x)))
        after_mlp = after_attn + self.drop_path(self.mlp(self.norm2(after_attn)))
        return after_mlp
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class OneDAttention(nn.Module):
    """Multi-head self-attention over a 2-D ``[N, C]`` input.

    The input is treated as one sequence of ``N`` tokens: a leading batch
    axis of size 1 is added for the attention computation and removed again
    before returning, so the output has the same ``[N, C]`` shape.

    Args:
        dim: Feature width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections use a bias.
        qk_scale: Attention scale; defaults to ``head_dim ** -0.5``.
        attn_drop: Dropout rate on the attention weights.
        proj_drop: Dropout rate after the output projection.
        sr_ratio, linear: Control which extra submodules (``sr``, ``norm``,
            ``pool``, ``act``) are created. ``forward`` does not use them;
            they are kept so parameter names stay unchanged.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[N, C]`` and return ``[N, C]``."""
        # Add a batch axis of size 1: [N, C] -> [1, N, C].
        x = x.unsqueeze(0)

        B, N, C = x.shape
        # Split features into heads: [B, heads, N, C // heads].
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # Remove only the batch axis added above; a bare squeeze() would also
        # drop the token axis when N == 1.
        x = x.squeeze(0)
        return x
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class OneDMlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> (ReLU) -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when falsy. The ReLU is only created and applied when ``linear`` is
    true. A ``DWConv`` submodule is constructed but ``forward`` does not
    call it.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)
        self.linear = linear

        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Run the feed-forward network on ``x``."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class Block(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by an MLP sub-layer.

    ``forward`` computes ``x + drop_path(attn(norm1(x), H, W))`` and then
    ``x + drop_path(mlp(norm2(x), H, W))``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            linear=linear,
        )
        # Stochastic depth is only instantiated for a positive rate.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            linear=linear,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every module via ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, _pos):
        """Run both residual sub-layers; ``_pos`` is accepted but not read here."""
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x), H, W)
        x = x + self.drop_path(mlp_out)
        return x
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class Attention(nn.Module):
    """Multi-head attention whose keys/values may come from a reduced token grid.

    Queries are always projected from the full ``(B, N, C)`` input. Keys and
    values are projected from:

    * the input itself when ``linear`` is false and ``sr_ratio <= 1``;
    * a strided-conv reduction of the ``(H, W)`` grid followed by LayerNorm
      when ``linear`` is false and ``sr_ratio > 1``;
    * an adaptive-average-pooled 7x7 grid passed through a 1x1 conv,
      LayerNorm and GELU when ``linear`` is true.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        # A falsy ``qk_scale`` selects the default 1/sqrt(head_dim) scaling.
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if linear:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        elif sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every module via ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Attend over ``x`` of shape ``(B, N, C)``; ``N`` must equal ``H * W``
        whenever the spatial-reduction or pooling path is active."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        # Choose the token sequence from which keys/values are projected.
        if self.linear:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_tokens = self.sr(self.pool(grid)).reshape(B, C, -1).permute(0, 2, 1)
            kv_tokens = self.act(self.norm(kv_tokens))
        elif self.sr_ratio > 1:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_tokens = self.norm(self.sr(grid).reshape(B, C, -1).permute(0, 2, 1))
        else:
            kv_tokens = x

        # (2, B, heads, tokens, head_dim): index 0 is keys, index 1 is values.
        kv = self.kv(kv_tokens).reshape(B, -1, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class Mlp(nn.Module):
    """Feed-forward block with a depth-wise convolution between the two linears.

    Forward order: ``fc1`` -> (ReLU when ``linear``) -> ``dwconv`` over the
    ``(H, W)`` token grid -> ``act`` -> dropout -> ``fc2`` -> dropout.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        # Falsy sizes (None or 0) fall back to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if linear:
            self.relu = nn.ReLU(inplace=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every module via ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the block to tokens ``x`` laid out on an ``H`` x ``W`` grid."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        hidden = self.dwconv(hidden, H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class DWConv(nn.Module):
    """3x3 depth-wise convolution applied to a token sequence viewed as a 2-D grid."""

    def __init__(self, dim=768):
        super().__init__()
        # groups == channels makes the convolution depth-wise; padding 1 with
        # stride 1 keeps the spatial size unchanged.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Map ``(B, N, C)`` tokens to ``(B, N, C)``; requires ``N == H * W``."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        return self.dwconv(grid).flatten(2).transpose(1, 2)
|
avs.code/v1m.code/model/mymodel.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL.Image import Image
|
| 9 |
+
|
| 10 |
+
from model.visual.sam2.modeling.sam2_base import SAM2Base
|
| 11 |
+
|
| 12 |
+
from model.visual.sam2.modeling.backbones.hieradet import Hiera
|
| 13 |
+
from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck
|
| 14 |
+
from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder
|
| 15 |
+
from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine
|
| 16 |
+
|
| 17 |
+
from model.visual.sam2.modeling.memory_attention import MemoryAttention
|
| 18 |
+
from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer
|
| 19 |
+
from model.visual.sam2.modeling.sam.transformer import RoPEAttention
|
| 20 |
+
from model.visual.sam2.modeling.memory_encoder import MemoryEncoder
|
| 21 |
+
from model.visual.sam2.modeling.memory_encoder import MaskDownSampler
|
| 22 |
+
from model.visual.sam2.modeling.memory_encoder import Fuser
|
| 23 |
+
from model.visual.sam2.modeling.memory_encoder import CXBlock
|
| 24 |
+
|
| 25 |
+
from model.visual.sam2.utils.transforms import SAM2Transforms
|
| 26 |
+
from model.visual.sam2.modeling.backbones.hieradet import do_pool
|
| 27 |
+
from model.visual.sam2.modeling.backbones.utils import (
|
| 28 |
+
PatchEmbed,
|
| 29 |
+
window_partition,
|
| 30 |
+
window_unpartition,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class AVmodel(torch.nn.Module):
    """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head."""

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, ):
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold
        # Square feature-map sizes at 1/4, 1/8 and 1/16 of the configured image size.
        self._bb_feat_sizes = [
            (int(self.param.image_size / stride), int(self.param.image_size / stride))
            for stride in (4, 8, 16)
        ]

        # Imported here rather than at module level, as in the original code.
        from model.visual.sam2.build_sam import build_sam2_visual_predictor
        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
        )
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        from model.aural_fuser import AuralFuser
        self.aural_fuser = AuralFuser(hyp_param=self.param)

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features.

        Returns a shallow copy of ``backbone_out`` together with the last
        ``num_feature_levels`` feature maps and positional encodings, each
        flattened over its trailing dims from index 2 and permuted with
        ``(2, 0, 1)``, plus the last-two-dim sizes of the positional encodings.
        """
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels

        levels = self.v_model.num_feature_levels
        feature_maps = backbone_out["backbone_fpn"][-levels:]
        pos_maps = backbone_out["vision_pos_enc"][-levels:]

        feat_sizes = [(pos.shape[-2], pos.shape[-1]) for pos in pos_maps]
        vision_feats = [fmap.flatten(2).permute(2, 0, 1) for fmap in feature_maps]
        vision_pos_embeds = [pos.flatten(2).permute(2, 0, 1) for pos in pos_maps]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def forward_frame(self, frame_):
        """Resize ``frame_`` to the configured square size and run the image encoder."""
        side = self.param.image_size
        resized = torch.nn.functional.interpolate(
            frame_, (side, side), antialias=True, align_corners=False, mode='bilinear'
        )
        return self.v_model.image_encoder(resized)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio into FPN features, then run SAM2 tracking. `sam_process` is reserved for prompt path.

        ``prompts`` and ``sam_process`` are accepted but not read by this method.
        Returns ``(outputs, proj_feats)``.
        """
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(backbone_feats, spect)

        # Both residual lists are handed to the tracker in reversed order.
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        num_frames = frames.shape[0]
        # Training conditions on the middle frame; otherwise on frame 0.
        cond_frame = int(num_frames / 2) if self.training else 0
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats, num_frames=num_frames, cond_frame=cond_frame
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the wrapped visual model."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Put the visual model in eval mode and disable gradients for all of its parameters."""
        self.v_model.eval()
        for parameter in self.v_model.parameters():
            parameter.requires_grad = False
|
avs.code/v1m.code/model/visual/sam2/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from hydra import initialize_config_module
|
| 8 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 9 |
+
|
| 10 |
+
# Register the "configs" config module with Hydra at import time; skipped when
# a global Hydra instance has already been initialized.
if not GlobalHydra.instance().is_initialized():
    initialize_config_module("configs", version_base="1.2")
|
avs.code/v1m.code/model/visual/sam2/build_sam.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from hydra import compose
|
| 12 |
+
from hydra.utils import instantiate
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
'''
|
| 15 |
+
import sam2
|
| 16 |
+
|
| 17 |
+
# Check if the user is running Python from the parent directory of the sam2 repo
|
| 18 |
+
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
| 19 |
+
# it could shadow the sam2 package and cause issues.
|
| 20 |
+
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
| 21 |
+
# If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
| 22 |
+
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
| 23 |
+
# This typically happens because the user is running Python from the parent directory
|
| 24 |
+
# that contains the sam2 repo they cloned.
|
| 25 |
+
raise RuntimeError(
|
| 26 |
+
"You're likely running Python from the parent directory of the sam2 repository "
|
| 27 |
+
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
| 28 |
+
"This is not supported since the `sam2` Python package could be shadowed by the "
|
| 29 |
+
"repository name (the repository is also named `sam2` and contains the Python package "
|
| 30 |
+
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
| 31 |
+
"rather than its parent dir, or from your home directory) after installing SAM 2."
|
| 32 |
+
)
|
| 33 |
+
'''
|
| 34 |
+
|
| 35 |
+
# Maps a Hugging Face model id to (config name, checkpoint filename).
# Generated for both model families and the four sizes, in the order
# tiny, small, base-plus, large.
HF_MODEL_ID_TO_FILENAMES = {
    f"facebook/{family}-hiera-{size}": (
        f"{family}/{family}_hiera_{abbrev}.yaml",
        f"{family}_hiera_{size.replace('-', '_')}.pt",
    )
    for family in ("sam2", "sam2.1")
    for size, abbrev in (("tiny", "t"), ("small", "s"), ("base-plus", "b+"), ("large", "l"))
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
    **kwargs,
):
    """Compose a Hydra config, instantiate ``cfg.model`` and optionally load weights.

    Args:
        config_file: Hydra config name passed to ``compose``.
        ckpt_path: Checkpoint path; nothing is loaded when ``None``.
        device: Target passed to ``model.to``.
        mode: The model is switched to eval mode only when this equals ``"eval"``.
        hydra_overrides_extra: Optional sequence of extra Hydra override
            strings. It is copied, never mutated. ``None`` means no extras.
        apply_postprocessing: When true, appends the dynamic multi-mask
            stability overrides for the mask decoder.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The instantiated model, moved to ``device``.
    """
    # ``None`` replaces the former mutable ``[]`` default; always work on a
    # private copy so neither the default nor the caller's list is touched.
    overrides = list(hydra_overrides_extra) if hydra_overrides_extra is not None else []

    if apply_postprocessing:
        overrides += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_sam2_visual_predictor(
    config_file,
    ckpt_path=None,
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
    **kwargs,
):
    """Compose a Hydra config, instantiate ``cfg.model`` and optionally load weights.

    Unlike ``build_sam2`` this builder does not move the model to a device.

    Args:
        config_file: Hydra config name passed to ``compose``.
        ckpt_path: Checkpoint path; nothing is loaded when ``None``.
        mode: The model is switched to eval mode only when this equals ``"eval"``.
        hydra_overrides_extra: Optional sequence of extra Hydra override
            strings. It is copied, never mutated. ``None`` means no extras.
        apply_postprocessing: When true, appends the
            ``binarize_mask_from_pts_for_mem_enc`` override.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The instantiated model.
    """
    # ``None`` replaces the former mutable ``[]`` default; always work on a
    # private copy so neither the default nor the caller's list is touched.
    hydra_overrides = list(hydra_overrides_extra) if hydra_overrides_extra is not None else []

    if apply_postprocessing:
        # NOTE: the dynamic multi-mask stability overrides
        # (``model.sam_mask_decoder_extra_args.dynamic_multimask_*``), the
        # ``model.fill_hole_area`` override and any ``model._target_``
        # override are not applied by this builder.
        hydra_overrides.append(
            # binarize the sigmoid mask logits on interacted frames with clicks
            # in the memory encoder so that the encoded masks are exactly as
            # what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true"
        )

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    if mode == "eval":
        model.eval()
    return model
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _hf_download(model_id):
    """Look up ``model_id`` in ``HF_MODEL_ID_TO_FILENAMES`` and fetch its checkpoint.

    Returns ``(config_name, ckpt_path)`` where ``ckpt_path`` is whatever
    ``hf_hub_download`` returns. Raises ``KeyError`` for an unknown id.
    """
    # Imported lazily so huggingface_hub is only required when this path is used.
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    return config_name, hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def build_sam2_hf(model_id, **kwargs):
    """Download the checkpoint for ``model_id`` and build the model via ``build_sam2``.

    Extra keyword arguments are forwarded to ``build_sam2`` unchanged.
    """
    resolved_config, local_ckpt = _hf_download(model_id)
    return build_sam2(config_file=resolved_config, ckpt_path=local_ckpt, **kwargs)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# def build_sam2_video_predictor_hf(model_id, **kwargs):
|
| 155 |
+
# config_name, ckpt_path = _hf_download(model_id)
|
| 156 |
+
# return build_sam2_video_predictor(
|
| 157 |
+
# config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
| 158 |
+
# )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _load_checkpoint(model, ckpt_path):
|
| 162 |
+
if ckpt_path is not None:
|
| 163 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
| 164 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
| 165 |
+
if missing_keys:
|
| 166 |
+
logging.error(missing_keys)
|
| 167 |
+
raise RuntimeError()
|
| 168 |
+
if unexpected_keys:
|
| 169 |
+
logging.error(unexpected_keys)
|
| 170 |
+
raise RuntimeError()
|
| 171 |
+
logging.info("Loaded checkpoint sucessfully")
|
avs.code/v1m.code/model/visual/sam2/modeling/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
avs.code/v1m.code/model/visual/sam2/modeling/backbones/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
avs.code/v1m.code/model/visual/sam2/modeling/backbones/hieradet.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from iopath.common.file_io import g_pathmgr
|
| 15 |
+
|
| 16 |
+
from model.visual.sam2.modeling.backbones.utils import (
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
window_partition,
|
| 19 |
+
window_unpartition,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from model.visual.sam2.modeling.sam2_utils import DropPath, MLP
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Apply a channels-first ``pool`` to a channels-last tensor.

    ``x`` is returned untouched when ``pool`` is ``None``. Otherwise it is
    permuted (B, H, W, C) -> (B, C, H, W), pooled, permuted back to
    (B, H', W', C), and passed through ``norm`` when ``norm`` is truthy.
    """
    if pool is None:
        return x
    pooled = pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    return norm(pooled) if norm else pooled
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) grid with optional query pooling.

    When ``q_pool`` is set, the queries are pooled through ``do_pool`` before
    attention, so the output grid takes the pooled spatial size.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        # One projection yields queries, keys and values.
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return attended features of shape (B, H', W', dim_out)."""
        batch, height, width, _ = x.shape
        heads = self.num_heads

        # (B, H * W, 3, nHead, C) split along dim 2 into q, k, v,
        # each (B, H * W, nHead, C).
        q, k, v = self.qkv(x).reshape(batch, height * width, 3, heads, -1).unbind(2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(batch, height, width, -1), self.q_pool)
            height, width = q.shape[1:3]  # downsampled shape
            q = q.reshape(batch, height * width, heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C], hence the transposes.
        attended = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        attended = attended.transpose(1, 2).reshape(batch, height, width, -1)
        return self.proj(attended)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiScaleBlock(nn.Module):
    """Transformer block with optional windowed attention and query-stride pooling.

    ``forward`` takes and returns channels-last tensors (B, H, W, C). When
    ``dim != dim_out`` the skip path is projected and pooled via ``do_pool``.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        # A string names a class in ``torch.nn``; it is built with eps=1e-6.
        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.q_stride = q_stride
        self.pool = (
            nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
            if q_stride
            else None
        )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        # The skip projection exists only when the channel width changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sub-layers with residual connections."""
        residual = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection: project (and pool) when the width changes.
        if self.dim != self.dim_out:
            residual = do_pool(self.proj(x), self.pool)

        # Window partition
        win = self.window_size
        windowed = win > 0
        if windowed:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, win)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling: recompute the window size
            # and the padded grid size from the (pooled) skip tensor.
            win = self.window_size // self.q_stride[0]
            H, W = residual.shape[1:3]
            pad_h = -H % win
            pad_w = -W % win
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if windowed:
            x = window_unpartition(x, win, pad_hw, (H, W))

        x = residual + self.drop_path(x)
        # MLP
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Hiera(nn.Module):
    """
    Hierarchical transformer trunk built from a patch embedding followed by
    a stack of MultiScaleBlock layers grouped into stages.

    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        """
        Build the patch embedding, positional embeddings and all blocks.

        If ``weights_path`` is given, the checkpoint at that path is loaded
        non-strictly and the load result is logged.
        """
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of every stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # First block of each following stage, limited to the first q_pool stages.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        # Output widths, listed from the last stage to the first when
        # intermediate layers are returned.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            # The load result must be a %-style argument: passing it without a
            # placeholder makes logging fail to format the record.
            logging.info(
                "loading Hiera: %s", self.load_state_dict(chkpt, strict=False)
            )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """
        Return the positional embedding for a grid of size ``hw``.

        The background embedding is bicubically resized to ``hw``, the window
        embedding is tiled on top of it, and the result is permuted so the
        channel axis comes last.
        """
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Embed ``x``, run every block and collect stage outputs.

        Returns:
            A list of feature maps permuted to channels-first. It holds the
            output of every stage end when ``return_interm_layers`` is set,
            otherwise only the output of the final block.
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        """
        Map a parameter name to a layer index.

        Positional and patch embeddings map to 0, ``blocks.<k>`` maps to
        ``k + 1``, and everything else (including ``rel_pos`` names) maps to
        ``num_layers + 1``.
        """
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        """Return the number of blocks in the trunk."""
        return len(self.blocks)
|
avs.code/v1m.code/model/visual/sam2/modeling/backbones/image_encoder.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ImageEncoder(nn.Module):
    """Run a trunk and a neck in sequence and package the neck outputs."""

    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        """
        :param trunk: module exposing ``channel_list``; called on the input sample
        :param neck: module exposing ``backbone_channel_list``; called on the
            trunk output and expected to return ``(features, pos)``
        :param scalp: number of trailing entries to drop from both neck outputs
        """
        super().__init__()
        self.trunk, self.neck, self.scalp = trunk, neck, scalp
        channels_agree = self.trunk.channel_list == self.neck.backbone_channel_list
        assert (
            channels_agree
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        """
        Forward ``sample`` through the trunk and then the neck.

        Returns a dict with the last kept feature map ("vision_features"),
        the positional encodings ("vision_pos_enc") and all kept feature
        maps ("backbone_fpn").
        """
        backbone_out = self.trunk(sample)
        features, pos = self.neck(backbone_out)

        # Discard the lowest resolution features
        if self.scalp > 0:
            keep_until = -self.scalp
            features = features[:keep_until]
            pos = pos[:keep_until]

        return {
            "vision_features": features[-1],
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (we remove output conv and also do bicubic interpolation similar to ViT
    pos embed interpolation)
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model (output channels of every lateral conv)
        :param backbone_channel_list: input channels of each lateral conv, in order
        :param kernel_size: kernel size of the lateral convs
        :param stride: stride of the lateral convs
        :param padding: padding of the lateral convs
        :param fpn_interp_model: interpolation mode used to upsample top-down features
        :param fuse_type: "sum" or "avg"; "avg" halves the lateral + top-down sum
        :param fpn_top_down_levels: levels that receive top-down features (all levels if None)
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        self.d_model = d_model
        for in_channels in backbone_channel_list:
            lateral = nn.Sequential()
            lateral.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(lateral)

        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        """
        Fuse the backbone feature maps and compute their positional encodings.

        Returns ``(out, pos)``: two lists with one entry per level, indexed
        like ``xs``.
        """
        num_levels = len(self.convs)
        out = [None] * num_levels
        pos = [None] * num_levels
        assert len(xs) == num_levels
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        align_corners = None if self.fpn_interp_model == "nearest" else False
        prev_features = None
        last = num_levels - 1
        # forward in top-down order (from low to high resolution)
        for level in reversed(range(num_levels)):
            # convs are stored in the opposite order to xs
            lateral_features = self.convs[last - level](xs[level])
            takes_top_down = (
                level in self.fpn_top_down_levels and prev_features is not None
            )
            if not takes_top_down:
                prev_features = lateral_features
            else:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=align_corners,
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            out[level] = prev_features
            pos[level] = self.position_encoding(prev_features).to(prev_features.dtype)

        return out, pos
|
avs.code/v1m.code/model/visual/sam2/modeling/backbones/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Some utilities for backbones, in particular for windowing"""
|
| 8 |
+
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def window_partition(x, window_size):
    """
    Split a token grid into non-overlapping square windows, zero-padding the
    bottom/right edges when the grid is not a multiple of the window size.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    extra_h = (window_size - height % window_size) % window_size
    extra_w = (window_size - width % window_size) % window_size
    if extra_h > 0 or extra_w > 0:
        # F.pad takes pads from the last dim backwards: (C, C, W, W, H, H)
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    grid = x.view(
        batch,
        padded_h // window_size,
        window_size,
        padded_w // window_size,
        window_size,
        channels,
    )
    # Bring the two window-index axes next to each other, then flatten them
    # into the batch axis.
    regrouped = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = regrouped.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reassemble windows into the original token grid and strip any padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    # Interleave window rows/cols back into full image rows/cols.
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    was_padded = padded_h > height or padded_w > width
    if not was_padded:
        return x
    return x[:, :height, :width, :].contiguous()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding, implemented as a single strided convolution.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the image and move the channel axis last (B C H W -> B H W C)."""
        return self.proj(x).permute(0, 2, 3, 1)
|
avs.code/v1m.code/model/visual/sam2/modeling/memory_attention.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn, Tensor
|
| 11 |
+
|
| 12 |
+
from model.visual.sam2.modeling.sam.transformer import RoPEAttention
|
| 13 |
+
|
| 14 |
+
from model.visual.sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MemoryAttentionLayer(nn.Module):
    """
    One pre-norm layer: self-attention on the target, cross-attention from
    the target to a memory, then a two-layer feed-forward network. Every
    sub-block is added back to its input as a residual.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm and one residual dropout per sub-block
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Flags selecting where positional encodings are added
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt, query_pos):
        """Self-attention sub-block; queries and keys share the same tensor."""
        normed = self.norm1(tgt)
        qk = normed + query_pos if self.pos_enc_at_attn else normed
        attended = self.self_attn(qk, qk, v=normed)
        return tgt + self.dropout1(attended)

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        """Cross-attention sub-block from ``tgt`` (queries) to ``memory``."""
        extra = {}
        if num_k_exclude_rope > 0:
            # the extra keyword is only passed to RoPE cross-attention
            assert isinstance(self.cross_attn_image, RoPEAttention)
            extra = {"num_k_exclude_rope": num_k_exclude_rope}

        normed = self.norm2(tgt)
        queries = normed + query_pos if self.pos_enc_at_cross_attn_queries else normed
        keys = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        attended = self.cross_attn_image(q=queries, k=keys, v=memory, **extra)
        return tgt + self.dropout2(attended)

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention and the MLP in that order."""
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)

        # MLP
        hidden = self.activation(self.linear1(self.norm3(tgt)))
        mlp_out = self.linear2(self.dropout(hidden))
        return tgt + self.dropout3(mlp_out)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MemoryAttention(nn.Module):
    """
    Stack of cloned attention layers that refines the current tokens by
    attending to a memory, followed by a final LayerNorm.
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        """
        Args:
            d_model: width of the final LayerNorm.
            pos_enc_at_input: add ``0.1 * curr_pos`` to the input tokens.
            layer: prototype layer, cloned ``num_layers`` times.
            num_layers: number of clones.
            batch_first: transpose the first two axes of all inputs before
                the layers and of the output afterwards.
        """
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        """
        Run every layer on ``curr`` with ``memory`` as cross-attention input.

        ``curr`` and ``curr_pos`` may each be given as a single-element list.
        Both positional encodings are optional; when omitted, ``None`` is
        forwarded to the layers unchanged.

        Returns:
            The normalised output, in the same axis order as ``curr``.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = (
                curr[0],
                curr_pos[0],
            )

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # The encodings default to None, so only transpose what was given.
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            # Only RoPE cross-attention accepts the exclusion count.
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)

        return normed_output
|
avs.code/v1m.code/model/visual/sam2/modeling/memory_encoder.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from model.visual.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by total_stride, each time by stride.
    Note that LayerNorm is applied per *token*, like in ViT.

    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
    In the end, we linearly project to embed_dim channels.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        """
        Args:
            embed_dim: number of output channels of the final 1x1 projection.
            kernel_size: kernel size of every downsampling conv.
            stride: stride of every downsampling conv.
            padding: padding of every downsampling conv.
            total_stride: overall downsampling factor; must be an exact
                power of ``stride``.
            activation: callable returning the activation module placed
                after each norm.

        Raises:
            AssertionError: if ``total_stride`` is not a power of ``stride``.
        """
        super().__init__()
        # Count how often `stride` divides `total_stride` with integer
        # arithmetic. A float `log2(a) // log2(b)` can round the quotient just
        # below the true integer and divides by zero when stride == 1.
        num_layers = 0
        remaining = total_stride
        while stride > 1 and remaining > 1 and remaining % stride == 0:
            remaining //= stride
            num_layers += 1
        assert stride**num_layers == total_stride

        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans

        # Final linear (1x1 conv) projection to embed_dim channels
        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

    def forward(self, x):
        """Run the mask through the downsampling stack."""
        return self.encoder(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
|
| 62 |
+
class CXBlock(nn.Module):
    r"""ConvNeXt-style residual block.

    The forward pass is: conv -> LayerNorm2d (still in (N, C, H, W)) ->
    permute to (N, H, W, C) -> Linear -> GELU -> Linear -> optional
    per-channel scale -> permute back -> drop path -> residual add.

    Args:
        dim (int): Number of input channels.
        kernel_size (int): Kernel size of the first conv. Default: 7
        padding (int): Padding of the first conv. Default: 3
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        use_dwconv (bool): Make the first conv depthwise (groups=dim). Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        conv_groups = dim if use_dwconv else 1
        # depthwise conv when use_dwconv is set
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=conv_groups,
        )
        self.norm = LayerNorm2d(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        hidden_dim = 4 * dim
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        # Layer scale is disabled when the init value is not positive.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and add the result to ``x``."""
        residual = x
        y = self.norm(self.dwconv(x))
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return residual + self.drop_path(y)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Fuser(nn.Module):
    """Optional 1x1 input projection followed by a stack of cloned layers."""

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        """
        Args:
            layer: prototype layer, cloned ``num_layers`` times.
            num_layers: number of clones to stack.
            dim: channel count of the 1x1 projection; required when
                ``input_projection`` is set.
            input_projection: replace the identity projection with a 1x1 conv.
        """
        super().__init__()
        # Identity unless an input projection is requested below.
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)

        if not input_projection:
            return
        assert dim is not None
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Project ``x`` and pass it through every layer in order."""
        # normally x: (N, C, H, W)
        fused = self.proj(x)
        for block in self.layers:
            fused = block(fused)
        return fused
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class MemoryEncoder(nn.Module):
    """
    Fuse pixel features with a downsampled mask and return the fused
    features together with their positional encoding.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        """
        Args:
            out_dim: channels of the returned features; a 1x1 conv is added
                only when it differs from ``in_dim``.
            mask_downsampler: module applied to the (sigmoided) masks.
            fuser: module applied to the sum of projected features and masks.
            position_encoding: module producing the positional encoding of
                the output features.
            in_dim: channels of the incoming pixel features.
        """
        super().__init__()

        self.mask_downsampler = mask_downsampler

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        """
        Encode ``pix_feat`` and ``masks`` into memory features.

        Args:
            pix_feat: pixel features; moved to the device of the masks.
            masks: mask tensor; passed through a sigmoid unless
                ``skip_mask_sigmoid`` is set.
            skip_mask_sigmoid: skip the sigmoid on ``masks``.

        Returns:
            A dict with "vision_features" (the fused features) and
            "vision_pos_enc" (a one-element list holding their positional
            encoding).
        """
        ## Process masks
        # sigmoid, so that less domain shift from gt masks which are bool
        if not skip_mask_sigmoid:
            masks = torch.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # the visual features may live on another device (e.g. CPU), so move
        # them to wherever the masks are
        pix_feat = pix_feat.to(masks.device)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
|
avs.code/v1m.code/model/visual/sam2/modeling/position_encoding.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Any, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PositionEmbeddingSine(nn.Module):
    """
    Sinusoidal position embedding for 2-D feature maps, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.

    The first half of the output channels encodes the row (y) position and the
    second half the column (x) position, each as interleaved sin/cos pairs.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
    ):
        """
        Arguments:
          num_pos_feats: total number of output channels; must be even because
            it is split evenly between the y and x axes.
          temperature: base of the geometric frequency progression.
          normalize: if True, grid positions are rescaled to (0, scale].
          scale: multiplier applied to normalized positions; defaults to 2*pi.

        Raises:
          ValueError: if ``scale`` is given while ``normalize`` is False.
        """
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        # Maps (height, width, device) -> encoding of shape (C, H, W).
        # The device is part of the key so a grid built on one device is never
        # handed back for an input living on another device.
        self.cache = {}

    def _dim_t(self, device):
        """Return the per-channel divisors (temperature ** (2 * (i // 2) / F))."""
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        return self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

    @staticmethod
    def _interleave_sin_cos(angles):
        """Apply sin to even channels and cos to odd channels, interleaved on the last dim."""
        return torch.stack(
            (angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    def _encode_xy(self, x, y):
        """Encode two 1-D tensors of (already normalized) positions; returns (pos_x, pos_y)."""
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = self._dim_t(x.device)

        pos_x = self._interleave_sin_cos(x_embed[:, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, None] / dim_t)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Encode box centers (x, y) and append the raw height and width as two extra channels."""
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Encode (batch, n) point coordinates and append the label as one extra channel."""
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    def _build_grid_encoding(self, height, width, device):
        """Compute the (C, H, W) encoding for one feature map of the given size."""
        y_embed = (
            torch.arange(1, height + 1, dtype=torch.float32, device=device)
            .view(-1, 1)
            .repeat(1, width)
        )
        x_embed = (
            torch.arange(1, width + 1, dtype=torch.float32, device=device)
            .view(1, -1)
            .repeat(height, 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale

        dim_t = self._dim_t(device)
        pos_x = self._interleave_sin_cos(x_embed[:, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, None] / dim_t)
        return torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """
        Return the position encoding for ``x``.

        Only the last two dimensions (height, width), the leading batch
        dimension and the device of ``x`` are used; its values are ignored.
        The result has shape (x.shape[0], C, H, W) and dtype float32.
        """
        height, width = x.shape[-2], x.shape[-1]
        cache_key = (height, width, x.device)
        pos = self.cache.get(cache_key)
        if pos is None:
            # The encoding is identical for every batch element, so it is
            # built once and repeated below (also valid for an empty batch).
            pos = self._build_grid_encoding(height, width, x.device)
            self.cache[cache_key] = pos
        return pos[None].repeat(x.shape[0], 1, 1, 1)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    A fixed (2, num_pos_feats) Gaussian projection matrix is stored as a buffer;
    encodings are the sin and cos of the projected coordinates, giving
    2 * num_pos_feats output channels.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """
        Arguments:
          num_pos_feats: number of random frequencies (output has twice as many channels).
          scale: standard deviation of the random frequencies; ``None`` or a
            non-positive value falls back to 1.0.
        """
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified (h, w) size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cumulative sums minus 0.5 give pixel-center coordinates 0.5, 1.5, ...
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Positionally encode points that are not normalized to [0,1].

        Arguments:
          coords_input: tensor indexed as [batch, point, (x, y)] in pixel units.
            It is never modified.
          image_size: (height, width) used to normalize y and x respectively.

        Returns:
          torch.Tensor: encodings of shape B x N x C.
        """
        if coords_input.is_floating_point():
            coords = coords_input.clone()
        else:
            # Integer coordinates must become float *before* the in-place
            # division below; otherwise the normalized values are truncated
            # when written back into an integer tensor.
            coords = coords_input.to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Rotary Positional Encoding, adapted from:
|
| 162 |
+
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 163 |
+
# 2. https://github.com/naver-ai/rope-vit
|
| 164 |
+
# 3. https://github.com/lucidrains/rotary-embedding-torch
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def init_t_xy(end_x: int, end_y: int):
    """
    Return the (x, y) grid coordinates of every cell of an end_x-wide,
    end_y-tall grid, enumerated in row-major order, as two float32 vectors of
    length end_x * end_y.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    # Column index is the remainder, row index the floored quotient.
    col = torch.remainder(flat_index, end_x).float()
    row = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return col, row
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """
    Build the complex rotary factors for an end_x by end_y grid.

    The x and y axes use the same dim // 4 inverse frequencies; the result
    concatenates the x factors and the y factors along the last dimension,
    giving a complex tensor of shape (end_x * end_y, 2 * (dim // 4)).
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    angles_x = torch.outer(t_x, inv_freq)
    angles_y = torch.outer(t_y, inv_freq)
    # Unit-magnitude complex numbers exp(i * angle).
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([cis_x, cis_y], dim=-1)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View ``freqs_cis`` so it broadcasts against ``x``: the last two dimensions
    are kept and every leading dimension of ``x`` becomes size 1.

    ``freqs_cis`` must have exactly the shape of the last two dimensions of ``x``.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    broadcast_shape = [1] * (ndim - 2) + list(x.shape[-2:])
    return freqs_cis.view(*broadcast_shape)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """
    Rotate queries and keys by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last dimension are treated as (real, imag)
    parts. Returns (xq_out, xk_out) cast back to the dtype/device of the
    inputs. When ``xk`` is empty along dim -2 it is returned untouched. With
    ``repeat_freqs_k`` the factors are tiled along the sequence dimension so
    they cover a key sequence that is an integer multiple of the query length.
    """

    def _pairs_as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    xq_ = _pairs_as_complex(xq)
    xk_ = _pairs_as_complex(xk) if xk.shape[-2] != 0 else None

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xq_out = xq_out.type_as(xq).to(xq.device)
    if xk_ is None:
        # no keys to rotate, due to dropout
        return xq_out, xk

    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        if freqs_cis.is_cuda:
            leading_ones = [1] * (freqs_cis.ndim - 2)
            freqs_cis = freqs_cis.repeat(*leading_ones, r, 1)
        else:
            # torch.repeat on complex numbers may not be supported on non-CUDA devices
            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
            expanded = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1)
            freqs_cis = expanded.flatten(2, 3)

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out, xk_out.type_as(xk).to(xk.device)
|
avs.code/v1m.code/model/visual/sam2/modeling/sam/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
avs.code/v1m.code/model/visual/sam2/modeling/sam/mask_decoder.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from model.visual.sam2.modeling.sam2_utils import LayerNorm2d, MLP
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MaskDecoder(nn.Module):
    """
    Predicts masks given an image and prompt embeddings, using a transformer
    architecture. One extra mask token (index 0) is reserved for single-mask
    output; the remaining ``num_multimask_outputs`` tokens are the multimask
    outputs.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid=False,
        dynamic_multimask_via_stability=False,
        dynamic_multimask_stability_delta=0.05,
        dynamic_multimask_stability_thresh=0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
        """
        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
          use_high_res_features (bool): add 1x1 convs (conv_s0/conv_s1) and fuse
            the two ``high_res_features`` maps during upscaling
          iou_prediction_use_sigmoid: passed to the IoU head as ``sigmoid_output``
          dynamic_multimask_via_stability: at eval time, fall back from the
            single-mask token to the best multimask token when the single mask
            is unstable
          dynamic_multimask_stability_delta: logit offset used for the
            stability score
          dynamic_multimask_stability_thresh: minimum stability score for the
            single-mask output to be kept
          pred_obj_scores (bool): add an object-score token and head
          pred_obj_scores_mlp (bool): use a 3-layer MLP instead of a linear
            layer for the object-score head
          use_multimask_token_for_obj_ptr (bool): return the multimask tokens
            (instead of the single-mask token) when ``multimask_output`` is set
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )
            self.conv_s1 = nn.Conv2d(
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )

        # One hypernetwork MLP per mask token.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
        audio_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
          repeat_image (bool): whether to repeat the image embeddings once per
            prompt along the batch dimension
          high_res_features: pair (feat_s0, feat_s1) fused during upscaling;
            required when the decoder was built with ``use_high_res_features``
          audio_res_features: forwarded unchanged to the transformer as its
            fourth positional argument

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
          torch.Tensor: batched SAM token for mask output
          torch.Tensor: batched object score logits
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
            audio_res_features_=audio_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
        audio_res_features_: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predicts masks for every mask token. See 'forward' for more details.

        Returns (masks, iou_pred, mask_tokens_out, object_score_logits) for
        all ``num_mask_tokens`` tokens; selection happens in ``forward``.

        Raises:
          ValueError: if the decoder uses high-res features but none were given.
        """
        if self.use_high_res_features and high_res_features is None:
            # Fail early with a clear message instead of an opaque unpacking
            # error after the transformer has already run.
            raise ValueError(
                "high_res_features must be provided when use_high_res_features=True"
            )

        # Concatenate output tokens
        s = 0
        if self.pred_obj_scores:
            output_tokens = torch.cat(
                [
                    self.obj_score_token.weight,
                    self.iou_token.weight,
                    self.mask_tokens.weight,
                ],
                dim=0,
            )
            # The object-score token occupies slot 0, shifting the others by one.
            s = 1
        else:
            output_tokens = torch.cat(
                [self.iou_token.weight, self.mask_tokens.weight], dim=0
            )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert (
            image_pe.size(0) == 1
        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens, audio_res_features_)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)

        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        # Each mask token is mapped by its own MLP to per-channel weights.
        hyper_in = torch.stack(
            [
                mlp(mask_tokens_out[:, i, :])
                for i, mlp in enumerate(self.output_hypernetworks_mlps)
            ],
            dim=1,
        )
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds. A mask with an empty lower-threshold area scores 1.0.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(
            multimask_iou_scores.size(0), device=all_iou_scores.device
        )
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
|
avs.code/v1m.code/model/visual/sam2/modeling/sam/prompt_encoder.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from model.visual.sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
+
|
| 14 |
+
from model.visual.sam2.modeling.sam2_utils import LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PromptEncoder(nn.Module):
|
| 18 |
+
def __init__(
    self,
    embed_dim: int,
    image_embedding_size: Tuple[int, int],
    input_image_size: Tuple[int, int],
    mask_in_chans: int,
    activation: Type[nn.Module] = nn.GELU,
) -> None:
    """
    Encodes prompts for input to SAM's mask decoder.

    Arguments:
      embed_dim (int): The prompts' embedding dimension
      image_embedding_size (tuple(int, int)): The spatial size of the
        image embedding, as (H, W).
      input_image_size (tuple(int, int)): The padded size of the image as
        input to the image encoder, as (H, W).
      mask_in_chans (int): The number of hidden channels used for
        encoding input masks.
      activation (nn.Module): The activation to use when encoding
        input masks.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.input_image_size = input_image_size
    self.image_embedding_size = image_embedding_size
    self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

    # Learned type embeddings: pos/neg point + 2 box corners
    self.num_point_embeddings: int = 4
    self.point_embeddings = nn.ModuleList(
        [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
    )
    self.not_a_point_embed = nn.Embedding(1, embed_dim)

    # Input masks are expected at 4x the embedding resolution, matching the
    # two stride-2 convolutions below.
    embed_h = image_embedding_size[0]
    embed_w = image_embedding_size[1]
    self.mask_input_size = (4 * embed_h, 4 * embed_w)

    quarter_chans = mask_in_chans // 4
    self.mask_downscaling = nn.Sequential(
        nn.Conv2d(1, quarter_chans, kernel_size=2, stride=2),
        LayerNorm2d(quarter_chans),
        activation(),
        nn.Conv2d(quarter_chans, mask_in_chans, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans),
        activation(),
        nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
    )
    self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 67 |
+
|
| 68 |
+
def get_dense_pe(self) -> torch.Tensor:
    """
    Returns the positional encoding used to encode point prompts,
    applied to a dense set of points the shape of the image encoding.

    Returns:
      torch.Tensor: Positional encoding with shape
        1x(embed_dim)x(embedding_h)x(embedding_w)
    """
    dense_pe = self.pe_layer(self.image_embedding_size)
    # Add the leading batch dimension of size 1.
    return dense_pe.unsqueeze(0)
|
| 78 |
+
|
| 79 |
+
def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    """
    Embeds point prompts.

    Each point's positional encoding is offset by the learned embedding for
    its label (0..3). Points labelled -1 have their encoding replaced by the
    "not a point" embedding. With ``pad`` set, one extra point at the origin
    labelled -1 is appended to every batch element.
    """
    points = points + 0.5  # Shift to center of pixel
    if pad:
        extra_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
        extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
        points = torch.cat([points, extra_point], dim=1)
        labels = torch.cat([labels, extra_label], dim=1)

    point_embedding = self.pe_layer.forward_with_coords(
        points, self.input_image_size
    )

    # Padding entries carry no position: wipe the encoding, then add the
    # dedicated embedding.
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    for label_value in range(4):
        selected = labels == label_value
        point_embedding[selected] += self.point_embeddings[label_value].weight
    return point_embedding
|
| 102 |
+
|
| 103 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """
    Embeds box prompts.

    Every box is viewed as two corner points; the first corner is offset by
    point embedding 2 and the second by point embedding 3.
    """
    # Shift to center of pixel, then split each box into its two corners.
    corners = (boxes + 0.5).reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(
        corners, self.input_image_size
    )
    for corner_idx in (0, 1):
        corner_embedding[:, corner_idx, :] += self.point_embeddings[
            corner_idx + 2
        ].weight
    return corner_embedding
|
| 113 |
+
|
| 114 |
+
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    """Embeds mask inputs by running them through the mask downscaling stack."""
    return self.mask_downscaling(masks)
|
| 118 |
+
|
| 119 |
+
def _get_batch_size(
|
| 120 |
+
self,
|
| 121 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 122 |
+
boxes: Optional[torch.Tensor],
|
| 123 |
+
masks: Optional[torch.Tensor],
|
| 124 |
+
) -> int:
|
| 125 |
+
"""
|
| 126 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
| 127 |
+
"""
|
| 128 |
+
if points is not None:
|
| 129 |
+
return points[0].shape[0]
|
| 130 |
+
elif boxes is not None:
|
| 131 |
+
return boxes.shape[0]
|
| 132 |
+
elif masks is not None:
|
| 133 |
+
return masks.shape[0]
|
| 134 |
+
else:
|
| 135 |
+
return 1
|
| 136 |
+
|
| 137 |
+
def _get_device(self) -> torch.device:
|
| 138 |
+
return self.point_embeddings[0].weight.device
|
| 139 |
+
|
| 140 |
+
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return placeholder sparse and dense prompt embeddings.

    Only the sounding prompt is used by this model, so point, box and mask
    embedding is disabled here: the three arguments are consulted solely to
    determine the batch size, and their contents are never embedded.

    Arguments:
      points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
        and labels; used only for the batch size.
      boxes (torch.Tensor or none): boxes; used only for the batch size.
      masks (torch.Tensor or none): masks; used only for the batch size.

    Returns:
      torch.Tensor: empty sparse embeddings with shape Bx0x(embed_dim).
      torch.Tensor: dense embeddings obtained by broadcasting the
        "no mask" embedding weight to Bx(C)x(embed_H)x(embed_W), where C is
        the number of elements in ``no_mask_embed.weight``. The result is an
        ``expand`` view of that weight, not a copy.
    """
    # we only utilise sounding as prompt.
    bs = self._get_batch_size(points, boxes, masks)

    # No point/box tokens are produced: the token axis has length zero.
    sparse_embeddings = torch.empty(
        (bs, 0, self.embed_dim), device=self._get_device()
    )

    # The "no mask" embedding is always used, even when `masks` is given.
    dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
        bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
    )
    return sparse_embeddings, dense_embeddings
|
| 188 |
+
|
avs.code/v1m.code/model/visual/sam2/modeling/sam/transformer.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import contextlib
|
| 8 |
+
import math
|
| 9 |
+
import warnings
|
| 10 |
+
from functools import partial
|
| 11 |
+
from typing import Tuple, Type
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
|
| 17 |
+
from model.visual.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 18 |
+
from model.visual.sam2.modeling.sam2_utils import MLP
|
| 19 |
+
from model.visual.sam2.utils.misc import get_sdpa_settings
|
| 20 |
+
|
| 21 |
+
# Ignore every FutureWarning. Note that simplefilter edits the interpreter-wide
# warning filters at import time, not just the filters for this module.
warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails.
# The attention layers in this file set it to True after a failed attention call.
ALLOW_ALL_KERNELS = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def sdp_kernel_context(dropout_p):
    """
    Return the context manager under which scaled dot-product attention runs.

    While ALLOW_ALL_KERNELS is False the kernels are restricted according to
    the module-level SDPA settings (Flash Attention by default). Once
    ALLOW_ALL_KERNELS is True a no-op context is returned, so every available
    kernel may be used.
    """
    if not ALLOW_ALL_KERNELS:
        # if Flash attention kernel is off, then math kernel needs to be enabled
        use_math_kernel = (OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON
        return torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,
            enable_math=use_math_kernel,
            enable_mem_efficient=OLD_GPU,
        )
    return contextlib.nullcontext()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): forwarded as the downsample rate
            of the cross-attention layers in every block and of the final
            token-to-image attention
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    # Only the first block skips the positional encoding in
                    # its self-attention step.
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
        audio_res: [],
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.
            NOTE: this tensor is modified in place (see the loop below).
          audio_res: a pair ``(visual_res, audio_res)``. Both members are
            indexed by layer number and by ``-1``: ``visual_res[i]`` is added
            to the image tokens and ``audio_res[i]`` to query tokens 2..5
            before layer ``i``, and the ``[-1]`` entries are added once more
            after the last layer. (The ``[]`` annotation carries no type
            information; element shapes are presumably broadcastable to the
            tensors they are added to -- TODO confirm against the caller.)

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # The parameter bundles two per-layer residual collections; from here
        # on `audio_res` names only the second one.
        visual_res, audio_res = audio_res

        # Prepare queries
        # NOTE(review): `queries` is the same tensor object as
        # `point_embedding` (no copy is made).
        queries = point_embedding
        keys = image_embedding
        # Apply transformer blocks and final layernorm
        for i, layer in enumerate(self.layers):
            keys = keys + visual_res[i]
            # In-place slice update. On the first iteration `queries` still
            # aliases `point_embedding`, so this also changes the tensor that
            # is passed as `query_pe` below and the caller's input tensor.
            queries[:, 2:6] = queries[:, 2:6] + audio_res[i]
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # One more residual injection after the last block, using the final
        # entry of each collection.
        queries[:, 2:6] = queries[:, 2:6] + audio_res[-1]
        keys = keys + visual_res[-1]

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        One decoder block made of four steps, each followed by a LayerNorm:
        (1) self-attention among the sparse tokens, (2) cross-attention from
        the sparse tokens to the dense (image) tokens, (3) an MLP on the
        sparse tokens, and (4) cross-attention from the dense tokens back to
        the sparse tokens.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          attention_downsample_rate (int): downsample rate of both
            cross-attention layers
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()

        # Step 1: token self-attention (no channel downsampling).
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # Step 2: tokens attend to the image.
        self.cross_attn_token_to_image = Attention(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Step 3: two-layer MLP on the tokens.
        self.mlp = MLP(
            embedding_dim,
            mlp_dim,
            embedding_dim,
            num_layers=2,
            activation=activation,
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # Step 4: image attends to the tokens.
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Run the four steps and return the updated ``(queries, keys)``."""
        # Step 1. Without PE the attention output replaces the tokens;
        # otherwise it is added to them as a residual.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            tokens_with_pe = queries + query_pe
            queries = queries + self.self_attn(
                q=tokens_with_pe, k=tokens_with_pe, v=queries
            )
        queries = self.norm1(queries)

        # Step 2. Tokens attending to the image embedding.
        token_update = self.cross_attn_token_to_image(
            q=queries + query_pe, k=keys + key_pe, v=keys
        )
        queries = self.norm2(queries + token_update)

        # Step 3. MLP with residual connection.
        queries = self.norm3(queries + self.mlp(queries))

        # Step 4. Image embedding attending to the tokens; note the swapped
        # roles: image tokens are the attention queries here.
        image_update = self.cross_attn_image_to_token(
            q=keys + key_pe, k=queries + query_pe, v=queries
        )
        keys = self.norm4(keys + image_update)

        return queries, keys
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class Attention(nn.Module):
    """
    Multi-head attention whose projected queries, keys and values may use a
    smaller channel width than the input embedding
    (``embedding_dim // downsample_rate``); the output is projected back to
    ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        # Keys/values share the query width unless a separate one is given.
        self.kv_in_dim = embedding_dim if kv_in_dim is None else kv_in_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """B x N_tokens x C -> B x N_heads x N_tokens x C_per_head."""
        batch, tokens, channels = x.shape
        per_head = x.reshape(batch, tokens, num_heads, channels // num_heads)
        return per_head.transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """B x N_heads x N_tokens x C_per_head -> B x N_tokens x C."""
        batch, heads, tokens, channels_per_head = x.shape
        return x.transpose(1, 2).reshape(batch, tokens, heads * channels_per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Project q/k/v, run scaled dot-product attention, project back."""
        global ALLOW_ALL_KERNELS

        # Input projections
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        # Separate into heads
        heads = self.num_heads
        q = self._separate_heads(q, heads)
        k = self._separate_heads(k, heads)
        v = self._separate_heads(v, heads)

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            # Remember the failure module-wide, then retry unrestricted.
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        return self.out_proj(self._recombine_heads(out))
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Factory for the rotary frequencies, fixed to the per-head channel
        # width; only the grid extent varies between calls.
        channels_per_head = self.internal_dim // self.num_heads
        self.compute_cis = partial(
            compute_axial_cis, dim=channels_per_head, theta=rope_theta
        )
        self.freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        """
        Same as ``Attention.forward`` but with rotary encoding applied to q
        and to all but the last ``num_k_exclude_rope`` key tokens.
        """
        global ALLOW_ALL_KERNELS

        # Input projections
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        # Separate into heads
        heads = self.num_heads
        q = self._separate_heads(q, heads)
        k = self._separate_heads(k, heads)
        v = self._separate_heads(v, heads)

        # Apply rotary position encoding
        num_q_tokens = q.shape[-2]
        side = math.sqrt(num_q_tokens)
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != num_q_tokens:
            # The cached frequencies do not match the token count: rebuild
            # them with both extents set to sqrt(token count).
            self.freqs_cis = self.compute_cis(end_x=side, end_y=side).to(q.device)
        if num_q_tokens != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, rotated_k = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )
        # Written back in place so the excluded trailing keys stay untouched.
        k[:, :, :num_k_rope] = rotated_k

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            # Remember the failure module-wide, then retry unrestricted.
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        return self.out_proj(self._recombine_heads(out))
|
avs.code/v1m.code/model/visual/sam2/modeling/sam2_base.py
ADDED
|
@@ -0,0 +1,940 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from torch.nn.init import trunc_normal_
|
| 12 |
+
|
| 13 |
+
from model.visual.sam2.modeling.sam.mask_decoder import MaskDecoder
|
| 14 |
+
from model.visual.sam2.modeling.sam.prompt_encoder import PromptEncoder
|
| 15 |
+
from model.visual.sam2.modeling.sam.transformer import TwoWayTransformer
|
| 16 |
+
from model.visual.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
|
| 17 |
+
|
| 18 |
+
# A large negative value used as a placeholder score for missing objects.
NO_OBJ_SCORE = -1024.0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SAM2Base(torch.nn.Module):
|
| 23 |
+
def __init__(
    self,
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    use_signed_tpos_enc_to_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    no_obj_embed_spatial: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
):
    """
    Assemble the image backbone, memory modules and SAM heads.

    Args:
      - image_encoder, memory_attention, memory_encoder: the three sub-modules
        stored on the model as-is.
      - num_maskmem: number of accessible memories (default: 1 input frame +
        6 previous frames).
      - image_size, backbone_stride: input resolution and stride of the image
        backbone output; their quotient is the SAM image-embedding size.
      - sigmoid_scale_for_mem_enc, sigmoid_bias_for_mem_enc: scale / bias applied
        to the mask sigmoid probability before the memory encoder.
      - binarize_mask_from_pts_for_mem_enc: during evaluation, whether to binarize
        the sigmoid mask logits on interacted frames with clicks.
      - use_mask_input_as_output_without_sam: stored, but unconditionally reset to
        False at the end of this constructor.
      - max_cond_frames_in_attn: maximum number of conditioning frames taking part
        in memory attention (-1 means no limit). When exceeded, only the temporally
        closest conditioning frames are cross-attended, which gives temporal
        locality and avoids GPU OOM with many annotated frames.
      - directly_add_no_mem_embed: on the first frame, add the no-memory embedding
        directly to the image feature instead of using the transformer encoder.
      - use_high_res_features_in_sam: use high-resolution feature maps (levels
        0, 1, 2 instead of only level 2) in the SAM mask decoder.
      - multimask_output_in_sam: output multiple (3) masks for the first click on
        initial conditioning frames.
      - multimask_min_pt_num, multimask_max_pt_num: click-count range in which
        multimask output is used (a box counts as two points); only relevant
        when `multimask_output_in_sam=True`.
      - multimask_output_for_tracking: also use multimask output for tracking;
        only relevant when `multimask_output_in_sam=True`.
      - use_multimask_token_for_obj_ptr: use multimask tokens for the object
        pointer; only relevant when both `use_obj_ptrs_in_encoder=True` and
        `multimask_output_for_tracking=True`.
      - iou_prediction_use_sigmoid: restrict IoU predictions to [0, 1] with a sigmoid.
      - memory_temporal_stride_for_eval: temporal stride `r` of the memory bank
        during evaluation. For r > 1 the (num_maskmem - 1) non-conditioning
        memory frames are the (num_maskmem - 2) nearest frames from every r-th
        frame, plus the last frame.
      - non_overlap_masks_for_mem_enc: apply non-overlapping constraints on object
        masks in the memory encoder during evaluation.
      - use_obj_ptrs_in_encoder: cross-attend to object pointers from other frames
        (based on SAM output tokens) in the encoder.
      - max_obj_ptrs_in_encoder: maximum number of such object pointers.
      - add_tpos_enc_to_obj_ptrs: add temporal positional encoding to the object
        pointers in the encoder.
      - proj_tpos_enc_in_obj_ptrs: add an extra linear projection for that temporal
        positional encoding; requires `add_tpos_enc_to_obj_ptrs`.
      - use_signed_tpos_enc_to_obj_ptrs: use signed (instead of absolute) distance
        in that temporal positional encoding.
      - only_obj_ptrs_in_the_past_for_eval: during evaluation, only attend to
        object pointers from before the current frame.
      - pred_obj_scores: predict whether an object is present in the frame.
      - pred_obj_scores_mlp: use an MLP to predict object scores.
      - fixed_no_obj_ptr: use a fixed no-object pointer when no object is present
        (rather than an additive embedding); requires `pred_obj_scores` and
        `use_obj_ptrs_in_encoder`.
      - soft_no_obj_ptr: mix the no-object pointer in softly.
      - use_mlp_for_obj_ptr_proj: project object pointers with an MLP.
      - no_obj_embed_spatial: create a learned no-object embedding for spatial frames.
      - sam_mask_decoder_extra_args: optional dict of extra kwargs forwarded to
        `MaskDecoder`.
      - compile_image_encoder: wrap the image encoder's forward in `torch.compile`.
    """
    super().__init__()

    def _trunc_normal_param(*shape):
        # Zero-filled parameter, then truncated-normal initialised (std=0.02).
        param = torch.nn.Parameter(torch.zeros(*shape))
        trunc_normal_(param, std=0.02)
        return param

    # ---- Part 1: image backbone -------------------------------------------
    self.image_encoder = image_encoder
    self.use_high_res_features_in_sam = use_high_res_features_in_sam
    # three feature levels in the high-res setting, otherwise only the last one
    self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
    self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
    self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
    if use_obj_ptrs_in_encoder:
        # Stride-4 conv that brings a mask prompt down to the resolution of the
        # low-res SAM mask logits and rescales it from 0~1 to SAM logit scale,
        # so it can be fed to the SAM mask decoder to generate a pointer.
        self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
    self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
    if proj_tpos_enc_in_obj_ptrs:
        # the projection is only meaningful when the encoding itself is enabled
        assert add_tpos_enc_to_obj_ptrs
    self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
    self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
    self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

    # ---- Part 2: memory attention -----------------------------------------
    # Conditions the current frame's visual features on memories (and object
    # pointers) from past frames.
    self.memory_attention = memory_attention
    # Fixed at 256 here rather than read from `memory_attention.d_model`
    # (Version 2.0) or `image_encoder.neck.d_model` (Version 2.1).
    self.hidden_dim = 256

    # ---- Part 3: memory encoder for previous frames' outputs --------------
    self.memory_encoder = memory_encoder
    self.mem_dim = self.hidden_dim
    if hasattr(self.memory_encoder, "out_proj") and hasattr(
        self.memory_encoder.out_proj, "weight"
    ):
        # memories are compressed along the channel dimension
        self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
    self.num_maskmem = num_maskmem
    # temporal encoding of the memories
    self.maskmem_tpos_enc = _trunc_normal_param(num_maskmem, 1, 1, self.mem_dim)
    # single token (and its positional encoding) meaning "no memory from previous frames"
    self.no_mem_embed = _trunc_normal_param(1, 1, self.hidden_dim)
    self.no_mem_pos_enc = _trunc_normal_param(1, 1, self.hidden_dim)
    self.directly_add_no_mem_embed = directly_add_no_mem_embed
    # A sigmoid maps raw mask logits from (-inf, +inf) to (0, 1) before they are
    # passed to the memory encoder; these are its scale and bias.
    self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
    self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
    self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
    self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
    self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
    # whether frames with a mask input bypass the SAM prompt encoder + mask decoder
    self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
    self.multimask_output_in_sam = multimask_output_in_sam
    self.multimask_min_pt_num = multimask_min_pt_num
    self.multimask_max_pt_num = multimask_max_pt_num
    self.multimask_output_for_tracking = multimask_output_for_tracking
    self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
    self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

    # ---- Part 4: SAM-style prompt encoder and mask decoder ----------------
    self.image_size = image_size
    self.backbone_stride = backbone_stride
    self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
    self.pred_obj_scores = pred_obj_scores
    self.pred_obj_scores_mlp = pred_obj_scores_mlp
    self.fixed_no_obj_ptr = fixed_no_obj_ptr
    self.soft_no_obj_ptr = soft_no_obj_ptr
    if self.fixed_no_obj_ptr:
        assert self.pred_obj_scores
        assert self.use_obj_ptrs_in_encoder
    if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
        self.no_obj_ptr = _trunc_normal_param(1, self.hidden_dim)
    self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
    self.no_obj_embed_spatial = None
    if no_obj_embed_spatial:
        self.no_obj_embed_spatial = _trunc_normal_param(1, self.mem_dim)

    self._build_sam_heads()
    self.max_cond_frames_in_attn = max_cond_frames_in_attn

    # ---- Optional compilation of the image encoder ------------------------
    if compile_image_encoder:
        # Only the forward function is compiled (not the whole module) so that
        # checkpoints can still be loaded.
        print(
            "Image encoder compilation is enabled. First forward pass will be slow."
        )
        self.image_encoder.forward = torch.compile(
            self.image_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

    # The mask-input-as-output shortcut is always switched off in this variant,
    # regardless of the constructor argument.
    self.use_mask_input_as_output_without_sam = False
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@property
def device(self):
    """Device of the first parameter yielded by `self.parameters()`."""
    first_param = next(self.parameters())
    return first_param.device
|
| 209 |
+
|
| 210 |
+
def forward(self, *args, **kwargs):
    """
    Always raise: the base model is not callable directly.

    Raises:
      NotImplementedError: unconditionally, pointing the caller to the
        predictor / training wrappers named in the message.
    """
    # The sentences are joined by implicit string concatenation, so each piece
    # must carry its own trailing separator.
    raise NotImplementedError(
        "Please use the corresponding methods in SAM2VideoPredictor for inference "
        "or SAM2Train for training/fine-tuning. "
        "See notebooks/video_predictor_example.ipynb for an inference example."
    )
|
| 215 |
+
|
| 216 |
+
def _build_sam_heads(self):
    """
    Create the SAM-style prompt encoder, mask decoder and object-pointer projections.

    Sets `sam_prompt_embed_dim`, `sam_image_embedding_size`, `sam_prompt_encoder`,
    `sam_mask_decoder`, `obj_ptr_proj` and `obj_ptr_tpos_proj` on the model.
    """
    embed_dim = self.hidden_dim
    grid_size = self.image_size // self.backbone_stride
    self.sam_prompt_embed_dim = embed_dim
    self.sam_image_embedding_size = grid_size

    # Hyperparameters such as `mask_in_chans=16` are taken from the SAM code.
    self.sam_prompt_encoder = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=(grid_size, grid_size),
        input_image_size=(self.image_size, self.image_size),
        mask_in_chans=16,
    )

    decoder_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=embed_dim,
        mlp_dim=2048,
        num_heads=8,
    )
    extra_decoder_kwargs = self.sam_mask_decoder_extra_args or {}
    self.sam_mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=decoder_transformer,
        transformer_dim=embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        use_high_res_features=self.use_high_res_features_in_sam,
        iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
        pred_obj_scores=self.pred_obj_scores,
        pred_obj_scores_mlp=self.pred_obj_scores_mlp,
        use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
        **extra_decoder_kwargs,
    )

    # Projection that turns SAM output tokens into object pointers.
    if not self.use_obj_ptrs_in_encoder:
        self.obj_ptr_proj = torch.nn.Identity()
    else:
        # The linear layer is always constructed first and then replaced when
        # the MLP variant is requested; the construction order is kept as-is.
        self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        if self.use_mlp_for_obj_ptr_proj:
            self.obj_ptr_proj = MLP(
                self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
            )

    # Optional linear projection on the temporal positional encoding of object
    # pointers, to avoid potential interference with spatial positional encoding.
    self.obj_ptr_tpos_proj = (
        torch.nn.Linear(self.hidden_dim, self.mem_dim)
        if self.proj_tpos_enc_in_obj_ptrs
        else torch.nn.Identity()
    )
|
| 265 |
+
|
| 266 |
+
def _forward_sam_heads(
|
| 267 |
+
self,
|
| 268 |
+
backbone_features,
|
| 269 |
+
point_inputs=None,
|
| 270 |
+
mask_inputs=None,
|
| 271 |
+
high_res_features=None,
|
| 272 |
+
multimask_output=False,
|
| 273 |
+
audio_res=None
|
| 274 |
+
):
|
| 275 |
+
"""
|
| 276 |
+
Forward SAM prompt encoders and mask heads.
|
| 277 |
+
|
| 278 |
+
Inputs:
|
| 279 |
+
- backbone_features: image features of [B, C, H, W] shape
|
| 280 |
+
- point_inputs: a dictionary with "point_coords" and "point_labels", where
|
| 281 |
+
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
|
| 282 |
+
absolute pixel-unit coordinate in (x, y) format of the P input points
|
| 283 |
+
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
|
| 284 |
+
positive clicks, 0 means negative clicks, and -1 means padding
|
| 285 |
+
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
|
| 286 |
+
same spatial size as the image.
|
| 287 |
+
- high_res_features: either 1) None or 2) or a list of length 2 containing
|
| 288 |
+
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
|
| 289 |
+
which will be used as high-resolution feature maps for SAM decoder.
|
| 290 |
+
- multimask_output: if it's True, we output 3 candidate masks and their 3
|
| 291 |
+
corresponding IoU estimates, and if it's False, we output only 1 mask and
|
| 292 |
+
its corresponding IoU estimate.
|
| 293 |
+
|
| 294 |
+
Outputs:
|
| 295 |
+
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
|
| 296 |
+
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
|
| 297 |
+
output mask logits (before sigmoid) for the low-resolution masks, with 4x
|
| 298 |
+
the resolution (1/4 stride) of the input backbone_features.
|
| 299 |
+
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
|
| 300 |
+
if `multimask_output=True` and M = 1 if `multimask_output=False`),
|
| 301 |
+
upsampled from the low-resolution masks, with shape size as the image
|
| 302 |
+
(stride is 1 pixel).
|
| 303 |
+
- ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
|
| 304 |
+
if `multimask_output=False`), the estimated IoU of each output mask.
|
| 305 |
+
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
|
| 306 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 307 |
+
If `multimask_output=False`, it's the same as `low_res_multimasks`.
|
| 308 |
+
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
|
| 309 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 310 |
+
If `multimask_output=False`, it's the same as `high_res_multimasks`.
|
| 311 |
+
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
|
| 312 |
+
based on the output token from the SAM mask decoder.
|
| 313 |
+
"""
|
| 314 |
+
B = backbone_features.size(0)
|
| 315 |
+
device = backbone_features.device
|
| 316 |
+
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
| 317 |
+
assert backbone_features.size(2) == self.sam_image_embedding_size
|
| 318 |
+
assert backbone_features.size(3) == self.sam_image_embedding_size
|
| 319 |
+
|
| 320 |
+
'''
|
| 321 |
+
# a) Handle point prompts
|
| 322 |
+
if point_inputs is not None:
|
| 323 |
+
sam_point_coords = point_inputs["point_coords"]
|
| 324 |
+
sam_point_labels = point_inputs["point_labels"]
|
| 325 |
+
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
| 326 |
+
raise NotImplementedError
|
| 327 |
+
else:
|
| 328 |
+
# If no points are provide, pad with an empty point (with label -1)
|
| 329 |
+
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
| 330 |
+
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
| 331 |
+
|
| 332 |
+
# b) Handle mask prompts
|
| 333 |
+
if mask_inputs is not None:
|
| 334 |
+
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
| 335 |
+
# and feed it as a dense mask prompt into the SAM mask encoder
|
| 336 |
+
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
| 337 |
+
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
| 338 |
+
sam_mask_prompt = F.interpolate(
|
| 339 |
+
mask_inputs.float(),
|
| 340 |
+
size=self.sam_prompt_encoder.mask_input_size,
|
| 341 |
+
align_corners=False,
|
| 342 |
+
mode="bilinear",
|
| 343 |
+
antialias=True, # use antialias for downsampling
|
| 344 |
+
)
|
| 345 |
+
else:
|
| 346 |
+
sam_mask_prompt = mask_inputs
|
| 347 |
+
raise NotImplementedError
|
| 348 |
+
else:
|
| 349 |
+
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
| 350 |
+
# a learned `no_mask_embed` to indicate no mask input in this case).
|
| 351 |
+
sam_mask_prompt = None
|
| 352 |
+
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 353 |
+
points=(sam_point_coords, sam_point_labels),
|
| 354 |
+
boxes=None,
|
| 355 |
+
masks=sam_mask_prompt,
|
| 356 |
+
)
|
| 357 |
+
'''
|
| 358 |
+
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 359 |
+
points=None,
|
| 360 |
+
boxes=None,
|
| 361 |
+
masks=None,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
(
|
| 365 |
+
low_res_multimasks,
|
| 366 |
+
ious,
|
| 367 |
+
sam_output_tokens,
|
| 368 |
+
object_score_logits,
|
| 369 |
+
) = self.sam_mask_decoder(
|
| 370 |
+
image_embeddings=backbone_features,
|
| 371 |
+
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
| 372 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 373 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 374 |
+
multimask_output=multimask_output,
|
| 375 |
+
repeat_image=False, # the image is already batched
|
| 376 |
+
high_res_features=high_res_features,
|
| 377 |
+
audio_res_features=audio_res
|
| 378 |
+
)
|
| 379 |
+
'''
|
| 380 |
+
if self.pred_obj_scores:
|
| 381 |
+
is_obj_appearing = object_score_logits > 0
|
| 382 |
+
|
| 383 |
+
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
| 384 |
+
# consistent with the actual mask prediction
|
| 385 |
+
low_res_multimasks = torch.where(
|
| 386 |
+
is_obj_appearing[:, None, None],
|
| 387 |
+
low_res_multimasks,
|
| 388 |
+
NO_OBJ_SCORE,
|
| 389 |
+
)
|
| 390 |
+
'''
|
| 391 |
+
# convert masks from possibly bfloat16 (or float16) to float32
|
| 392 |
+
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
| 393 |
+
low_res_multimasks = low_res_multimasks.float()
|
| 394 |
+
high_res_multimasks = F.interpolate(
|
| 395 |
+
low_res_multimasks,
|
| 396 |
+
size=(self.image_size, self.image_size),
|
| 397 |
+
mode="bilinear",
|
| 398 |
+
align_corners=False,
|
| 399 |
+
)
|
| 400 |
+
sam_output_token = sam_output_tokens[:, 0]
|
| 401 |
+
if multimask_output:
|
| 402 |
+
# comment this line temporarily.
|
| 403 |
+
# take the best mask prediction (with the highest IoU estimation)
|
| 404 |
+
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 405 |
+
batch_inds = torch.arange(B, device=device)
|
| 406 |
+
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 407 |
+
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 408 |
+
if sam_output_tokens.size(1) > 1:
|
| 409 |
+
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 410 |
+
else:
|
| 411 |
+
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
| 412 |
+
|
| 413 |
+
# Extract object pointer from the SAM output token (with occlusion handling)
|
| 414 |
+
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
| 415 |
+
|
| 416 |
+
# don't train occlusion at the moment, command temporarily.
|
| 417 |
+
if self.pred_obj_scores:
|
| 418 |
+
is_obj_appearing = object_score_logits > 0
|
| 419 |
+
# Allow *soft* no obj ptr, unlike for masks
|
| 420 |
+
if self.soft_no_obj_ptr:
|
| 421 |
+
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
| 422 |
+
else:
|
| 423 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
| 424 |
+
|
| 425 |
+
if self.fixed_no_obj_ptr:
|
| 426 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
| 427 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
| 428 |
+
return (
|
| 429 |
+
low_res_multimasks,
|
| 430 |
+
high_res_multimasks,
|
| 431 |
+
ious,
|
| 432 |
+
low_res_masks,
|
| 433 |
+
high_res_masks,
|
| 434 |
+
obj_ptr,
|
| 435 |
+
object_score_logits,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
| 439 |
+
"""
|
| 440 |
+
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
|
| 441 |
+
(same input and output shapes as in _forward_sam_heads above).
|
| 442 |
+
"""
|
| 443 |
+
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
|
| 444 |
+
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
|
| 445 |
+
mask_inputs_float = mask_inputs.float()
|
| 446 |
+
high_res_masks = mask_inputs_float * out_scale + out_bias
|
| 447 |
+
low_res_masks = F.interpolate(
|
| 448 |
+
high_res_masks,
|
| 449 |
+
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
|
| 450 |
+
align_corners=False,
|
| 451 |
+
mode="bilinear",
|
| 452 |
+
antialias=True, # use antialias for downsampling
|
| 453 |
+
)
|
| 454 |
+
# a dummy IoU prediction of all 1's under mask input
|
| 455 |
+
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 456 |
+
if not self.use_obj_ptrs_in_encoder:
|
| 457 |
+
# all zeros as a dummy object pointer (of shape [B, C])
|
| 458 |
+
obj_ptr = torch.zeros(
|
| 459 |
+
mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
|
| 460 |
+
)
|
| 461 |
+
else:
|
| 462 |
+
# produce an object pointer using the SAM decoder from the mask input
|
| 463 |
+
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
| 464 |
+
backbone_features=backbone_features,
|
| 465 |
+
mask_inputs=self.mask_downsample(mask_inputs_float),
|
| 466 |
+
high_res_features=high_res_features,
|
| 467 |
+
)
|
| 468 |
+
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
|
| 469 |
+
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
|
| 470 |
+
# on the object_scores from the SAM decoder.
|
| 471 |
+
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
|
| 472 |
+
is_obj_appearing = is_obj_appearing[..., None]
|
| 473 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
| 474 |
+
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
|
| 475 |
+
if self.pred_obj_scores:
|
| 476 |
+
if self.fixed_no_obj_ptr:
|
| 477 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
| 478 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
| 479 |
+
|
| 480 |
+
return (
|
| 481 |
+
low_res_masks,
|
| 482 |
+
high_res_masks,
|
| 483 |
+
ious,
|
| 484 |
+
low_res_masks,
|
| 485 |
+
high_res_masks,
|
| 486 |
+
obj_ptr,
|
| 487 |
+
object_score_logits,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
def precompute_high_res_features(self, backbone_out):
    """
    Project FPN levels 0 and 1 with the SAM decoder's `conv_s0` / `conv_s1`.

    The projection is written back into `backbone_out["backbone_fpn"]` in place
    so it does not have to be repeated on every SAM click. When high-resolution
    features are disabled, `backbone_out` is returned untouched.
    """
    if not self.use_high_res_features_in_sam:
        return backbone_out

    fpn = backbone_out["backbone_fpn"]
    decoder = self.sam_mask_decoder
    fpn[0] = decoder.conv_s0(fpn[0])
    fpn[1] = decoder.conv_s1(fpn[1])
    return backbone_out
|
| 501 |
+
|
| 502 |
+
def forward_image(self, img_batch: torch.Tensor, pre_compute=True):
    """
    Run the image encoder on `img_batch`.

    With `pre_compute` truthy, the encoder output is additionally passed
    through `precompute_high_res_features` before being returned.
    """
    encoder_out = self.image_encoder(img_batch)
    if pre_compute:
        return self.precompute_high_res_features(encoder_out)
    return encoder_out
|
| 506 |
+
|
| 507 |
+
def _prepare_backbone_features(self, backbone_out):
|
| 508 |
+
"""Prepare and flatten visual features."""
|
| 509 |
+
backbone_out = backbone_out.copy()
|
| 510 |
+
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
| 511 |
+
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
| 512 |
+
|
| 513 |
+
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
|
| 514 |
+
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
|
| 515 |
+
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
| 516 |
+
# flatten NxCxHxW to HWxNxC
|
| 517 |
+
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
| 518 |
+
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
|
| 519 |
+
|
| 520 |
+
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
|
| 521 |
+
|
| 522 |
+
def _prepare_memory_conditioned_features(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
):
    """Fuse the current frame's visual feature map with previous memory.

    Args:
        frame_idx: index of the frame being processed.
        is_init_cond_frame: when true, no previous memory is used.
        current_vision_feats: list of feature tensors; only the last entry is
            reshaped here, from (HW, B, C) to (B, C, H, W).
        current_vision_pos_embeds: positional encodings forwarded to
            `self.memory_attention` as `curr_pos`.
        feat_sizes: list of (H, W) pairs; the last entry gives H and W.
        output_dict: holds "cond_frame_outputs" and "non_cond_frame_outputs",
            both looked up by frame index. Entries are read for
            "maskmem_features", "maskmem_pos_enc" and "obj_ptr".
        num_frames: caps the number of object pointers and bounds the frame
            indices visited when collecting them.
        track_in_reverse: look at later frames instead of earlier ones.

    Returns:
        A tensor viewed as (B, C, H, W):
          - the top-level feature unchanged when `self.num_maskmem == 0`;
          - the top-level feature plus `self.no_mem_embed` on an initial
            conditioning frame when `self.directly_add_no_mem_embed` is set;
          - otherwise the output of `self.memory_attention` over the
            concatenated spatial memories and object-pointer tokens.

    Raises:
        NotImplementedError: on an initial conditioning frame when
            `self.directly_add_no_mem_embed` is false.
    """
    B = current_vision_feats[-1].size(1)  # batch size on this frame
    C = self.hidden_dim
    H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
    device = current_vision_feats[-1].device
    # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
    # In this case, we skip the fusion with any memory.
    if self.num_maskmem == 0:  # Disable memory and skip fusion
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        return pix_feat

    num_obj_ptr_tokens = 0
    # Sign applied to (frame_idx - t) for signed temporal encoding of object pointers.
    tpos_sign_mul = -1 if track_in_reverse else 1
    # Step 1: condition the visual features of the current frame on previous memories
    if not is_init_cond_frame:
        # Retrieve the memories encoded with the maskmem backbone
        to_cat_memory, to_cat_memory_pos_embed = [], []
        # Add the conditioning frames' outputs first (all cond frames have t_pos=0
        # when getting the temporal positional embedding below)
        assert len(output_dict["cond_frame_outputs"]) > 0
        # Select a maximum number of temporally closest cond frames for cross attention
        cond_outputs = output_dict["cond_frame_outputs"]
        selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
            frame_idx, cond_outputs, self.max_cond_frames_in_attn
        )
        t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
        # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
        # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
        # We also allow taking the memory frame non-consecutively (with stride>1), in which case
        # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
        stride = 1 if self.training else self.memory_temporal_stride_for_eval

        for t_pos in range(1, self.num_maskmem):
            t_rel = self.num_maskmem - t_pos  # how many frames before current frame
            if t_rel == 1:
                # for t_rel == 1, we take the last frame (regardless of the stride)
                if not track_in_reverse:
                    # the frame immediately before this frame (i.e. frame_idx - 1)
                    prev_frame_idx = frame_idx - t_rel
                else:
                    # the frame immediately after this frame (i.e. frame_idx + 1)
                    prev_frame_idx = frame_idx + t_rel
            else:
                # for t_rel >= 2, we take the memory frame from every stride-th frame
                if not track_in_reverse:
                    # first find the nearest frame among every stride-th frame before this frame
                    # for stride=1, this would be (frame_idx - 2)
                    prev_frame_idx = ((frame_idx - 2) // stride) * stride
                    # then seek further among every stride-th frame
                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
                else:
                    # first find the nearest frame among every stride-th frame after this frame
                    # for stride=1, this would be (frame_idx + 2); -(-x // s) is a ceiling division
                    prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
                    # then seek further among every stride-th frame
                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
            out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
            if out is None:
                # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                # frames, we still attend to it as if it's a non-conditioning frame.
                out = unselected_cond_outputs.get(prev_frame_idx, None)
            # `out` may still be None here; such entries are skipped in the loop below.
            t_pos_and_prevs.append((t_pos, out))

        for t_pos, prev in t_pos_and_prevs:
            if prev is None:
                continue  # skip padding frames
            # "maskmem_features" might have been offloaded to CPU in demo use cases,
            # so we load it back to GPU (it's a no-op if it's already on GPU).
            feats = prev["maskmem_features"].to(device, non_blocking=True)
            to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
            # Spatial positional encoding (it might have been offloaded to CPU in eval)
            maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
            maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
            # Temporal positional encoding
            maskmem_enc = (
                maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
            )
            to_cat_memory_pos_embed.append(maskmem_enc)
        # Construct the list of past object pointers
        if self.use_obj_ptrs_in_encoder:
            max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
            # First add those object pointers from selected conditioning frames
            # (optionally, only include object pointers in the past during evaluation)
            if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                ptr_cond_outputs = {
                    t: out
                    for t, out in selected_cond_outputs.items()
                    if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                }
            else:
                ptr_cond_outputs = selected_cond_outputs
            pos_and_ptrs = [
                # Temporal pos encoding contains how far away each pointer is from current frame
                (
                    (
                        (frame_idx - t) * tpos_sign_mul
                        if self.use_signed_tpos_enc_to_obj_ptrs
                        else abs(frame_idx - t)
                    ),
                    out["obj_ptr"],
                )
                for t, out in ptr_cond_outputs.items()
            ]
            # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
            for t_diff in range(1, max_obj_ptrs_in_encoder):
                t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                if t < 0 or (num_frames is not None and t >= num_frames):
                    break
                out = output_dict["non_cond_frame_outputs"].get(
                    t, unselected_cond_outputs.get(t, None)
                )
                if out is not None:
                    pos_and_ptrs.append((t_diff, out["obj_ptr"]))
            # If we have at least one object pointer, add them to the cross attention
            if len(pos_and_ptrs) > 0:
                pos_list, ptrs_list = zip(*pos_and_ptrs)
                # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                obj_ptrs = torch.stack(ptrs_list, dim=0)
                # a temporal positional embedding based on how far each object pointer is from
                # the current frame (sine embedding normalized by the max pointer num).
                # Existing note from the authors: this flag defaults to false.
                if self.add_tpos_enc_to_obj_ptrs:
                    t_diff_max = max_obj_ptrs_in_encoder - 1
                    tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                    obj_pos = torch.tensor(pos_list, device=device)
                    obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                    obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                    obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                else:
                    obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                if self.mem_dim < C:
                    # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                    obj_ptrs = obj_ptrs.reshape(
                        -1, B, C // self.mem_dim, self.mem_dim
                    )
                    obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                    obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                to_cat_memory.append(obj_ptrs)
                to_cat_memory_pos_embed.append(obj_pos)
                num_obj_ptr_tokens = obj_ptrs.shape[0]
            else:
                num_obj_ptr_tokens = 0
    else:
        # for initial conditioning frames, encode them without using any previous memory
        if self.directly_add_no_mem_embed:
            # directly add no-mem embedding (instead of using the transformer encoder)
            pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
            pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
            return pix_feat_with_mem
        # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder).
        # Per the authors' note, the two assignments after the raise are never executed.
        # NOTE(review): the raise is assumed to sit at this (else-branch) level - confirm against the
        # original file's indentation.
        raise NotImplementedError
        to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
        to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

    # Step 2: Concatenate the memories and forward through the transformer encoder
    memory = torch.cat(to_cat_memory, dim=0)
    memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

    pix_feat_with_mem = self.memory_attention(
        curr=current_vision_feats,
        curr_pos=current_vision_pos_embeds,
        memory=memory,
        memory_pos=memory_pos_embed,
        num_obj_ptr_tokens=num_obj_ptr_tokens,
    )
    # reshape the output (HW)BC => BCHW
    pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
    return pix_feat_with_mem
|
| 705 |
+
|
| 706 |
+
def _encode_new_memory(
    self,
    current_vision_feats,
    feat_sizes,
    pred_masks_high_res,
    object_score_logits,
    is_mask_from_pts,
):
    """Encode the current image and its predicted mask into a memory feature.

    The top-level visual feature is reshaped from (HW, B, C) to (B, C, H, W),
    the mask logits are turned into either a hard 0/1 mask (eval only, when
    the mask comes from points and binarization is enabled) or sigmoid
    probabilities, optionally scaled and shifted, and both are passed to
    `self.memory_encoder`.

    Returns:
        (maskmem_features, maskmem_pos_enc) taken from the encoder output keys
        "vision_features" and "vision_pos_enc".

    Raises:
        NotImplementedError: in eval mode when
            `self.non_overlap_masks_for_mem_enc` is set.
    """
    top_feat = current_vision_feats[-1]
    height, width = feat_sizes[-1]  # top-level (lowest-resolution) feature size
    # top-level feature, (HW)BC => BCHW
    pix_feat = top_feat.permute(1, 2, 0).view(
        top_feat.size(1), self.hidden_dim, height, width
    )

    if self.non_overlap_masks_for_mem_enc and not self.training:
        # Non-overlapping constraints act on the batch dimension and are only
        # meant for eval; this path is not supported in this code base.
        pred_masks_high_res = self._apply_non_overlapping_constraints(
            pred_masks_high_res
        )
        raise NotImplementedError

    use_hard_mask = (
        self.binarize_mask_from_pts_for_mem_enc
        and is_mask_from_pts
        and not self.training
    )
    if use_hard_mask:
        mask_for_mem = (pred_masks_high_res > 0).float()
    else:
        # squash the raw mask logits into the (0, 1) range
        mask_for_mem = torch.sigmoid(pred_masks_high_res)

    # optional affine adjustment of the mask values
    scale = self.sigmoid_scale_for_mem_enc
    if scale != 1.0:
        mask_for_mem = mask_for_mem * scale
    bias = self.sigmoid_bias_for_mem_enc
    if bias != 0.0:
        mask_for_mem = mask_for_mem + bias

    encoded = self.memory_encoder(
        pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
    )
    maskmem_features = encoded["vision_features"]
    maskmem_pos_enc = encoded["vision_pos_enc"]

    # Where the object score says "no object", add the no-object embedding to
    # the spatial memory (in place).
    no_obj_embed = self.no_obj_embed_spatial
    if no_obj_embed is not None:
        is_obj_appearing = (object_score_logits > 0).float()
        maskmem_features += (1 - is_obj_appearing[..., None, None]) * no_obj_embed[
            ..., None, None
        ].expand(*maskmem_features.shape)

    return maskmem_features, maskmem_pos_enc
|
| 758 |
+
|
| 759 |
+
def _track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse,
    prev_sam_mask_logits,
):
    """Produce the SAM outputs for one frame.

    Either the mask input is used directly as the output (when a mask is given
    and `self.use_mask_input_as_output_without_sam` is set), or the top-level
    feature is fused with memory and passed through the SAM heads.

    Returns:
        (current_out, sam_outputs, high_res_features, pix_feat), where
        `current_out` records the original point and mask inputs.
    """
    current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

    # High-resolution feature maps for the SAM head, reshaped (HW)BC => BCHW.
    high_res_features = None
    if len(current_vision_feats) > 1:
        high_res_features = []
        for feat, size in zip(current_vision_feats[:-1], feat_sizes[:-1]):
            high_res_features.append(
                feat.permute(1, 2, 0).view(feat.size(1), feat.size(2), *size)
            )

    if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
        # Treat the mask input as the output (as if it were a GT mask), bypassing
        # the SAM prompt encoder and mask decoder.
        pix_feat = current_vision_feats[-1].permute(1, 2, 0)
        pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
        sam_outputs = self._use_mask_as_output(
            pix_feat, high_res_features, mask_inputs
        )
        return current_out, sam_outputs, high_res_features, pix_feat

    # Fuse the visual feature with previous memory features in the memory bank.
    pix_feat = self._prepare_memory_conditioned_features(
        frame_idx=frame_idx,
        is_init_cond_frame=is_init_cond_frame,
        current_vision_feats=current_vision_feats[-1:],
        current_vision_pos_embeds=current_vision_pos_embeds[-1:],
        feat_sizes=feat_sizes[-1:],
        output_dict=output_dict,
        num_frames=num_frames,
        track_in_reverse=track_in_reverse,
    )
    # Previously predicted low-res SAM mask logits may be fed to the mask decoder
    # (e.g. in the demo, where they come from an earlier interaction). Real
    # `mask_inputs` never reach this point in that case; they take the branch above.
    if prev_sam_mask_logits is not None:
        assert point_inputs is not None and mask_inputs is None
        mask_inputs = prev_sam_mask_logits
    use_multimask = self._use_multimask(is_init_cond_frame, point_inputs)
    sam_outputs = self._forward_sam_heads(
        backbone_features=pix_feat,
        point_inputs=point_inputs,
        mask_inputs=mask_inputs,
        high_res_features=high_res_features,
        multimask_output=use_multimask,
    )
    return current_out, sam_outputs, high_res_features, pix_feat
|
| 819 |
+
|
| 820 |
+
def _encode_memory_in_output(
    self,
    current_vision_feats,
    feat_sizes,
    point_inputs,
    run_mem_encoder,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Store the memory encoding of this frame's prediction in `current_out`.

    When `run_mem_encoder` is truthy and `self.num_maskmem > 0`, the
    high-resolution masks are encoded via `self._encode_new_memory` and the
    result is written to `current_out["maskmem_features"]` and
    `current_out["maskmem_pos_enc"]`; otherwise both keys are set to None.
    """
    if not (run_mem_encoder and self.num_maskmem > 0):
        current_out["maskmem_features"] = None
        current_out["maskmem_pos_enc"] = None
        return

    features, pos_enc = self._encode_new_memory(
        current_vision_feats=current_vision_feats,
        feat_sizes=feat_sizes,
        pred_masks_high_res=high_res_masks,
        object_score_logits=object_score_logits,
        is_mask_from_pts=(point_inputs is not None),
    )
    current_out["maskmem_features"] = features
    current_out["maskmem_pos_enc"] = pos_enc
|
| 844 |
+
|
| 845 |
+
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    # Whether to run the memory encoder on the predicted masks. It can be skipped
    # with `run_mem_encoder=False`, e.g. in the demo, where `track_step` may be
    # called once per user click and memory is only encoded once the clicks are
    # finalized, or in ablations such as SAM training on static images.
    run_mem_encoder=True,
    # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
    prev_sam_mask_logits=None,
):
    """Run one tracking step and return the per-frame output dict.

    The dict produced by `self._track_step` is extended with "pred_masks",
    "pred_masks_high_res" and "obj_ptr", with "object_score_logits" outside of
    training, and with the memory entries written by
    `self._encode_memory_in_output`.
    """
    step_result = self._track_step(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    )
    current_out, sam_outputs, _, _ = step_result

    # The first three SAM outputs are not needed here.
    _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

    current_out.update(
        pred_masks=low_res_masks,
        pred_masks_high_res=high_res_masks,
        obj_ptr=obj_ptr,
    )
    if not self.training:
        # Only added at inference (to avoid an unused param in activation
        # checkpointing; mainly used in the demo to encode spatial memories
        # with consolidated masks).
        current_out["object_score_logits"] = object_score_logits

    # Finally, encode the predicted mask into a new memory feature that future
    # frames can attend to.
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    )

    return current_out
|
| 911 |
+
|
| 912 |
+
def _use_multimask(self, is_init_cond_frame, point_inputs):
    """Decide whether the SAM head should produce multiple masks.

    Multimask output is used only when it is enabled in SAM, the frame is an
    initial conditioning frame (or multimask is enabled for tracking), and the
    number of input points lies within the configured
    [multimask_min_pt_num, multimask_max_pt_num] range. With no point inputs
    the point count is taken as 0.
    """
    if point_inputs is None:
        point_count = 0
    else:
        point_count = point_inputs["point_labels"].size(1)
    return (
        self.multimask_output_in_sam
        and (is_init_cond_frame or self.multimask_output_for_tracking)
        and (self.multimask_min_pt_num <= point_count <= self.multimask_max_pt_num)
    )
|
| 921 |
+
|
| 922 |
+
def _apply_non_overlapping_constraints(self, pred_masks):
    """
    Keep only the highest-scoring object at each spatial location.

    Objects are the slices of `pred_masks` along dim 0. Wherever an object is
    not the arg-max at a location, its score is clamped to at most -10.0 so
    that foreground regions do not overlap (sigmoid(-10.0)=4.5398e-05). A
    single-object input is returned unchanged.
    """
    num_objects = pred_masks.size(0)
    if num_objects == 1:
        return pred_masks

    # index of the winning object at every location
    winning_obj = torch.argmax(pred_masks, dim=0, keepdim=True)
    # index of each slice along dim 0, shaped to broadcast against `winning_obj`
    slice_ids = torch.arange(num_objects, device=pred_masks.device).reshape(
        num_objects, 1, 1, 1
    )
    is_winner = winning_obj == slice_ids
    suppressed = torch.clamp(pred_masks, max=-10.0)
    return torch.where(is_winner, pred_masks, suppressed)
|
avs.code/v1m.code/model/visual/sam2/modeling/sam2_utils.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
from model.visual.sam2.utils.misc import mask_to_box
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick at most `max_cond_frame_num` conditioning frames nearest to `frame_idx`.

    Selection order:
    - a) the closest conditioning frame strictly before `frame_idx` (if any);
    - b) the closest conditioning frame at or after `frame_idx` (if any);
    - c) the remaining frames by increasing distance to `frame_idx`, until the
      limit is reached.

    A limit of -1, or a limit that is not exceeded, selects everything; in that
    case `cond_frame_outputs` itself is returned as the selection.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    unlimited = max_cond_frame_num == -1
    if unlimited or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
    chosen = {}

    # a) nearest conditioning frame before the current one
    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]

    # b) nearest conditioning frame at or after the current one
    later = [t for t in cond_frame_outputs if t >= frame_idx]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # c) fill the remaining budget with the next-closest frames
    budget = max_cond_frame_num - len(chosen)
    leftovers = [t for t in cond_frame_outputs if t not in chosen]
    leftovers.sort(key=lambda t: abs(t - frame_idx))
    for t in leftovers[:budget]:
        chosen[t] = cond_frame_outputs[t]

    not_chosen = {}
    for t, value in cond_frame_outputs.items():
        if t not in chosen:
            not_chosen[t] = value

    return chosen, not_chosen
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Build a 1D sinusoidal positional embedding (original Transformer style).

    The last axis of the result has size 2 * (dim // 2): the sine components
    followed by the cosine components.
    """
    half_dim = dim // 2
    channel_idx = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    exponent = 2 * (channel_idx // 2) / half_dim
    divisor = temperature**exponent

    scaled = pos_inds.unsqueeze(-1) / divisor
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_activation_fn(activation):
    """Return the functional activation matching the given name.

    Args:
        activation: one of "relu", "gelu" or "glu".

    Returns:
        The corresponding function from `torch.nn.functional`.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    # The message lists every accepted name (the previous text omitted "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_clones(module, N):
    """Return an `nn.ModuleList` holding N independent deep copies of `module`."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples along dim 0 during training."""

    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Return `x` unchanged in eval mode or when `drop_prob` is 0; otherwise
        multiply it by a per-sample Bernoulli keep mask (rescaled by the keep
        probability when `scale_by_keep` is set)."""
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # one mask value per sample, broadcast over all remaining dims
        mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
        keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            keep_mask.div_(keep_prob)
        return x * keep_mask
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Lightly adapted from
|
| 111 |
+
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
| 112 |
+
class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    `activation` is called with no arguments to build the activation module.
    The last layer has no activation; a sigmoid is applied to the final
    output when `sigmoid_output` is set.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_widths = [input_dim] + hidden_widths
        out_widths = hidden_widths + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(in_widths, out_widths)]
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        """Apply every layer in order, activating all but the last one."""
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = self.act(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
| 140 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
| 141 |
+
class LayerNorm2d(nn.Module):
    """Layer normalization over dim 1 with a per-channel learnable scale and bias."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` to zero mean / unit variance along dim 1, then apply
        the affine parameters (broadcast over the two trailing dims)."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a noised version of the top-left and bottom-right corners of the
    bounding box of each mask.

    Inputs:
    - masks: [B, 1, H, W] masks, torch.Tensor (boxes are derived via `mask_to_box`)
    - noise: noise as a fraction of box width and height, dtype=float;
      no noise (and no clamping) is applied when it is <= 0
    - noise_bound: maximum amount of noise (in pure pixels), int or tensor
    - top_left_label / bottom_right_label: labels assigned to the two corners

    Returns (both are torch tensors, not numpy arrays):
    - box_coords: [B, 2, 2], (x, y) coordinates of the top-left and bottom-right box corners
    - box_labels: [B, 2], `top_left_label` (default 2) for the top-left and
      `bottom_right_label` (default 3) for the bottom-right corner, dtype=torch.int32
    """
    device = masks.device
    # presumably [B, 1, 4] boxes as (x_min, y_min, x_max, y_max) - verify against mask_to_box
    box_coords = mask_to_box(masks)
    B, _, H, W = masks.shape
    box_labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        # per-box noise amplitude: a fraction of the box size, capped at noise_bound
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        # uniform noise in [-1, 1), scaled per coordinate
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)

        box_coords = box_coords + box_noise
        img_bounds = (
            torch.tensor([W, H, W, H], device=device) - 1
        )  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # In place clamping

    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Draw `num_pt` random correction clicks per mask from the prediction errors.

    Every pixel gets two random scores (one per click polarity); scores outside
    the relevant error region are zeroed and the arg-max entry decides both the
    click location and its label.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool, or None
      (None is treated as an all-False prediction)
    - num_pt: int, number of points sampled independently for each of the B masks

    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each click
    - labels: [B, num_pt], dtype=torch.int32, 1 = positive click, 0 = negative click
    """
    if pred_masks is None:
        # A missing prediction is equivalent to predicting background everywhere.
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    batch, _, height, width = gt_masks.shape

    # False positives call for a negative click, false negatives for a positive one.
    false_pos = pred_masks & ~gt_masks
    false_neg = gt_masks & ~pred_masks

    # Per-mask flag: prediction matches the ground truth on every pixel.
    perfect = (gt_masks == pred_masks).flatten(2).all(dim=2)[..., None, None]

    # With nothing to correct, fall back to a negative click on the background.
    neg_region = false_pos | (perfect & ~gt_masks)

    # Last axis: channel 0 scores negative clicks, channel 1 scores positive clicks.
    region = torch.stack([neg_region, false_neg], dim=-1)
    scores = torch.rand(batch, num_pt, height, width, 2, device=gt_masks.device)
    scores = scores * region

    # The flattened arg-max interleaves (pixel, channel): parity is the label,
    # the quotient is the row-major pixel index.
    flat_idx = scores.flatten(2).argmax(dim=2)
    labels = (flat_idx % 2).to(torch.int32)
    pixel_idx = flat_idx // 2
    points = torch.stack([pixel_idx % width, pixel_idx // width], dim=2).to(
        torch.float
    )
    return points, labels
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Pick one correction click per mask at the "center" of the error regions,
    i.e. the error pixel that lies farthest from the boundary of its region
    (RITM-style sampling, see
    https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py).

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool, or None
      (None is treated as an all-False prediction)
    - padding: if True, surround each error map with a 1 px zero border before
      the distance transform (the border is cropped away again afterwards)

    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each click
    - labels: [B, 1], dtype=torch.int32, 1 = positive click, 0 = negative click
    """
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    num_masks, _, _, width = gt_masks.shape

    def _distance_to_boundary(region):
        # Distance of every pixel of `region` to the nearest zero pixel.
        if padding:
            region = np.pad(region, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(region.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist.reshape(-1)

    # False negatives need a positive click, false positives a negative one.
    fn_np = (gt_masks & ~pred_masks).cpu().numpy()
    fp_np = (~gt_masks & pred_masks).cpu().numpy()

    # Results are filled on the CPU and moved to the input device at the end.
    points = torch.zeros(num_masks, 1, 2, dtype=torch.float)
    labels = torch.ones(num_masks, 1, dtype=torch.int32)
    for idx in range(num_masks):
        fn_dist = _distance_to_boundary(fn_np[idx, 0])
        fp_dist = _distance_to_boundary(fp_np[idx, 0])
        fn_best = np.argmax(fn_dist)
        fp_best = np.argmax(fp_dist)

        # Click inside whichever error region has the deeper interior point;
        # ties (including "no error at all") resolve to a negative click.
        positive = fn_dist[fn_best] > fp_dist[fp_best]
        chosen = fn_best if positive else fp_best
        row, col = divmod(int(chosen), width)
        points[idx, 0, 0] = col  # x
        points[idx, 0, 1] = row  # y
        labels[idx, 0] = int(positive)

    target_device = gt_masks.device
    return points.to(target_device), labels.to(target_device)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_next_point(gt_masks, pred_masks, method):
    """
    Sample the next click for each mask using the requested strategy.

    - method == "uniform": delegate to `sample_random_points_from_errors`
    - method == "center": delegate to `sample_one_point_from_error_center`

    Both helpers are called with their default optional arguments. Any other
    `method` raises ValueError.
    """
    if method not in ("uniform", "center"):
        raise ValueError(f"unknown sampling method {method}")
    if method == "center":
        return sample_one_point_from_error_center(gt_masks, pred_masks)
    return sample_random_points_from_errors(gt_masks, pred_masks)
|
avs.code/v1m.code/model/visual/sam2/organised_sam2_train.py
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed
|
| 12 |
+
from model.visual.sam2.modeling.sam2_base import SAM2Base
|
| 13 |
+
from model.visual.sam2.modeling.sam2_utils import (
|
| 14 |
+
get_1d_sine_pe,
|
| 15 |
+
get_next_point,
|
| 16 |
+
sample_box_points,
|
| 17 |
+
select_closest_cond_frames,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from utils.misc import concat_points
|
| 21 |
+
|
| 22 |
+
from utils.data_utils import BatchedVideoDatapoint
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SAM2Train(SAM2Base):
|
| 26 |
+
def __init__(
    self,
    image_encoder,
    memory_attention=None,
    memory_encoder=None,
    prob_to_use_pt_input_for_train=0.0,
    prob_to_use_pt_input_for_eval=0.0,
    prob_to_use_box_input_for_train=0.0,
    prob_to_use_box_input_for_eval=0.0,
    # When greater than 1, interactive point sampling is done on the first frame
    # and on other randomly selected frames.
    num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
    num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
    rand_frames_to_correct_for_train=False,
    rand_frames_to_correct_for_eval=False,
    # Number of initial conditioning frames (for both point and mask input; the
    # first frame is always an initial conditioning frame).
    # - if `rand_init_cond_frames` is True, 1~num_init_cond_frames frames are sampled
    # - otherwise exactly num_init_cond_frames frames are sampled
    # For point input, correction points are sampled on all initial conditioning
    # frames, which requires `num_frames_to_correct` >= `num_init_cond_frames`.
    # They are "initial" because more conditioning frames may be added while
    # tracking when `add_all_frames_to_correct_as_cond=True`.
    num_init_cond_frames_for_train=1,  # default: only the first frame
    num_init_cond_frames_for_eval=1,  # default: only the first frame
    rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames
    rand_init_cond_frames_for_eval=False,
    # True: any frame that later receives a correction click also becomes a
    # conditioning frame. False: only the initial conditioning frames are used.
    add_all_frames_to_correct_as_cond=False,
    # Number of additional correction points per corrected frame (the first
    # frame also receives an initial input click on top of these).
    num_correction_pt_per_frame=7,
    # Point sampling method for evaluation: "uniform" (uniform over the error
    # region) or "center" (point farthest from the error region boundary).
    # "center" matches the evaluation protocol of the SAM paper.
    pt_sampling_for_eval="center",
    # During training, correction points may be sampled from the GT region
    # (instead of the error region) with this probability, which may reduce
    # overfitting to the error regions of the training data.
    prob_to_sample_from_gt_for_train=0.0,
    use_act_ckpt_iterative_pt_sampling=False,
    # Forward image features frame by frame during evaluation instead of all at
    # once; avoids backbone OOM on very long videos at a small speed cost.
    forward_backbone_per_frame_for_eval=False,
    freeze_image_encoder=False,
    **kwargs,
):
    """
    Store the prompt-sampling and conditioning-frame settings used for training
    and evaluation on top of the base model.

    `image_encoder`, `memory_attention`, `memory_encoder` and any extra `kwargs`
    are forwarded unchanged to the parent constructor. When
    `freeze_image_encoder` is True, gradients are disabled for every parameter
    of `self.image_encoder`.
    """
    super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)

    # Backbone / checkpointing behaviour
    self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
    self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling

    # Probabilities of using point / box prompts
    self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
    self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
    self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
    self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval

    uses_point_prompts = (
        prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0
    )
    if uses_point_prompts:
        logging.info(
            f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
        )
        # Every initial conditioning frame must also be a frame to correct.
        assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
        assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

    # Frames that receive correction clicks
    self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
    self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
    self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
    self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval

    # Initial conditioning frames
    self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
    self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
    self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
    self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval

    # Correction-click behaviour
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
    self.num_correction_pt_per_frame = num_correction_pt_per_frame
    self.pt_sampling_for_eval = pt_sampling_for_eval
    self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train

    # Random number generator with a fixed initial seed (same on every GPU)
    self.rng = np.random.default_rng(seed=42)

    if freeze_image_encoder:
        for param in self.image_encoder.parameters():
            param.requires_grad = False
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def forward(self, input: BatchedVideoDatapoint):
    """
    Run the full pipeline on a batched video datapoint: image features,
    prompt preparation, then tracking. Returns whatever `forward_tracking`
    returns.
    """
    # Only in eval mode, and only when explicitly requested, is the backbone
    # deferred so that each frame's features are computed when it is tracked.
    defer_backbone = not self.training and self.forward_backbone_per_frame_for_eval
    if defer_backbone:
        backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
    else:
        # Precompute the image features of all frames before tracking.
        backbone_out = self.forward_image(input.flat_img_batch)

    backbone_out = self.prepare_prompt_inputs(backbone_out, input)
    return self.forward_tracking(backbone_out, input)
|
| 118 |
+
|
| 119 |
+
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
|
| 120 |
+
"""Compute the image backbone features on the fly for the given img_ids."""
|
| 121 |
+
# Only forward backbone on unique image ids to avoid repetitive computation
|
| 122 |
+
# (if `img_ids` has only one element, it's already unique so we skip this step).
|
| 123 |
+
if img_ids.numel() > 1:
|
| 124 |
+
unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
|
| 125 |
+
else:
|
| 126 |
+
unique_img_ids, inv_ids = img_ids, None
|
| 127 |
+
|
| 128 |
+
# Compute the image features on those unique image ids
|
| 129 |
+
image = img_batch[unique_img_ids]
|
| 130 |
+
backbone_out = self.forward_image(image)
|
| 131 |
+
(
|
| 132 |
+
_,
|
| 133 |
+
vision_feats,
|
| 134 |
+
vision_pos_embeds,
|
| 135 |
+
feat_sizes,
|
| 136 |
+
) = self._prepare_backbone_features(backbone_out)
|
| 137 |
+
'''
|
| 138 |
+
vision_feats
|
| 139 |
+
torch.Size([65536, 5, 32])
|
| 140 |
+
torch.Size([16384, 5, 64])
|
| 141 |
+
torch.Size([4096, 5, 256])
|
| 142 |
+
'''
|
| 143 |
+
# Inverse-map image features for `unique_img_ids` to the final image features
|
| 144 |
+
# for the original input `img_ids`.
|
| 145 |
+
if inv_ids is not None:
|
| 146 |
+
image = image[inv_ids]
|
| 147 |
+
vision_feats = [x[:, inv_ids] for x in vision_feats]
|
| 148 |
+
vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
|
| 149 |
+
|
| 150 |
+
return image, vision_feats, vision_pos_embeds, feat_sizes
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def dont_prepare_prompt_inputs(backbone_out, num_frames=5, cond_frame=0):
|
| 154 |
+
backbone_out["gt_masks_per_frame"] = {}
|
| 155 |
+
backbone_out["num_frames"] = num_frames
|
| 156 |
+
backbone_out["use_pt_input"] = False
|
| 157 |
+
# always start from the first frame.
|
| 158 |
+
backbone_out["init_cond_frames"] = [cond_frame]
|
| 159 |
+
backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames) if i != cond_frame]
|
| 160 |
+
# backbone_out["init_cond_frames"] = []
|
| 161 |
+
# backbone_out["frames_not_in_init_cond"] = [i for i in range(0, num_frames)]
|
| 162 |
+
|
| 163 |
+
backbone_out["mask_inputs_per_frame"] = {}
|
| 164 |
+
backbone_out["point_inputs_per_frame"] = {}
|
| 165 |
+
backbone_out["frames_to_add_correction_pt"] = []
|
| 166 |
+
return backbone_out
|
| 167 |
+
|
| 168 |
+
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
    """
    Decide which prompts (mask, point or box) each frame receives and record
    the decision in `backbone_out`, which is also returned.

    Tracking may start from a custom `start_frame_idx` and run to the end of
    the video (for evaluation purposes).

    Keys written: "gt_masks_per_frame", "num_frames", "use_pt_input",
    "init_cond_frames", "frames_not_in_init_cond", "mask_inputs_per_frame",
    "point_inputs_per_frame", "frames_to_add_correction_pt".
    """
    # Ground-truth masks of every frame, kept so that correction points can be
    # sampled from them later. Each entry gets an extra axis at position 1
    # (-> [B, 1, H_im, W_im] per the original note).
    gt_masks_per_frame = {}
    for stage_id, masks in enumerate(input.masks):
        gt_masks_per_frame[stage_id] = masks.unsqueeze(1)
    backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
    num_frames = input.num_frames
    backbone_out["num_frames"] = num_frames

    # Pick the train- or eval-specific settings.
    mode = "train" if self.training else "eval"
    prob_to_use_pt_input = getattr(self, f"prob_to_use_pt_input_for_{mode}")
    prob_to_use_box_input = getattr(self, f"prob_to_use_box_input_for_{mode}")
    num_frames_to_correct = getattr(self, f"num_frames_to_correct_for_{mode}")
    rand_frames_to_correct = getattr(self, f"rand_frames_to_correct_for_{mode}")
    num_init_cond_frames = getattr(self, f"num_init_cond_frames_for_{mode}")
    rand_init_cond_frames = getattr(self, f"rand_init_cond_frames_for_{mode}")

    if num_frames == 1:
        # Special case for mixing video + SAM-on-image training: static images
        # always use point input on their only frame.
        prob_to_use_pt_input = 1.0
        num_frames_to_correct = 1
        num_init_cond_frames = 1
    assert num_init_cond_frames >= 1

    # `self.rng.random()` lies in [0.0, 1.0)
    use_pt_input = self.rng.random() < prob_to_use_pt_input
    if rand_init_cond_frames and num_init_cond_frames > 1:
        # between 1 and `num_init_cond_frames` initial conditioning frames
        num_init_cond_frames = self.rng.integers(
            1, num_init_cond_frames, endpoint=True
        )
    randomize_correction_count = (
        use_pt_input
        and rand_frames_to_correct
        and num_frames_to_correct > num_init_cond_frames
    )
    if randomize_correction_count:
        # between `num_init_cond_frames` and `num_frames_to_correct` frames get
        # correction clicks (point input only)
        num_frames_to_correct = self.rng.integers(
            num_init_cond_frames, num_frames_to_correct, endpoint=True
        )
    backbone_out["use_pt_input"] = use_pt_input

    # Initial conditioning frames: the starting frame, plus (optionally) other
    # frames drawn without replacement from the rest of the video.
    init_cond_frames = [start_frame_idx]
    if num_init_cond_frames != 1:
        later_frames = self.rng.choice(
            range(start_frame_idx + 1, num_frames),
            num_init_cond_frames - 1,
            replace=False,
        )
        init_cond_frames.extend(later_frames.tolist())
    backbone_out["init_cond_frames"] = init_cond_frames
    backbone_out["frames_not_in_init_cond"] = [
        frame
        for frame in range(start_frame_idx, num_frames)
        if frame not in init_cond_frames
    ]

    # Mask or point prompts on the initial conditioning frames
    mask_inputs = {}  # {frame_idx: <input_masks>}
    point_inputs = {}  # {frame_idx: <input_points>}
    backbone_out["mask_inputs_per_frame"] = mask_inputs
    backbone_out["point_inputs_per_frame"] = point_inputs
    for frame in init_cond_frames:
        if not use_pt_input:
            mask_inputs[frame] = gt_masks_per_frame[frame]
            continue
        # Overall P(box) = prob_to_use_pt_input * prob_to_use_box_input
        if self.rng.random() < prob_to_use_box_input:
            points, labels = sample_box_points(
                gt_masks_per_frame[frame],
            )
        else:
            # Only **one initial point** is sampled here from the ground-truth
            # mask; further correction points may be sampled on the fly.
            sampling_method = (
                "uniform" if self.training else self.pt_sampling_for_eval
            )
            points, labels = get_next_point(
                gt_masks=gt_masks_per_frame[frame],
                pred_masks=None,
                method=sampling_method,
            )
        point_inputs[frame] = {"point_coords": points, "point_labels": labels}

    # Frames on which correction clicks are added on the fly, based on the
    # error between the prediction and the ground truth.
    if not use_pt_input:
        # mask input never gets correction points
        frames_to_add_correction_pt = []
    elif num_frames_to_correct == num_init_cond_frames:
        frames_to_add_correction_pt = init_cond_frames
    else:
        assert num_frames_to_correct > num_init_cond_frames
        # initial cond frames + extra frames drawn without replacement
        extra_num = num_frames_to_correct - num_init_cond_frames
        extra_frames = self.rng.choice(
            backbone_out["frames_not_in_init_cond"], extra_num, replace=False
        )
        frames_to_add_correction_pt = init_cond_frames + extra_frames.tolist()
    backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

    return backbone_out
|
| 290 |
+
|
| 291 |
+
def forward_tracking_wo_prompt(self, backbone_out, audio_res=None, return_dict=False):
    """
    Track every frame of the video without point/mask prompts, feeding
    per-frame audio-visual features to each tracking step.

    The image features are expected to be already computed and present in
    `backbone_out`, together with the keys "num_frames", "init_cond_frames",
    "frames_not_in_init_cond" and "frames_to_add_correction_pt".

    Args:
        backbone_out: dict of backbone outputs plus the tracking metadata above.
        audio_res: pair `(av_v_feats, av_a_feats)`; each element is an iterable
            of tensors that are indexed by frame along their first axis.
            Required despite the `None` default.
        return_dict: if True, return the raw dict with "cond_frame_outputs" and
            "non_cond_frame_outputs" instead of the per-frame list.

    Returns:
        Either the output dict (`return_dict=True`) or a list with one output
        dict per frame, ordered by frame index, with the "obj_ptr" entry removed.

    Raises:
        TypeError: if `audio_res` is None.
    """
    # Fail early with a clear message: the parameter defaults to None but the
    # features are unpacked unconditionally below.
    if audio_res is None:
        raise TypeError(
            "forward_tracking_wo_prompt() requires `audio_res` as a pair "
            "(av_v_feats, av_a_feats), got None"
        )
    av_v_feats, av_a_feats = audio_res

    # Prepare the backbone features
    # - vision_feats and vision_pos_embeds are in (HW)BC format
    (
        _,
        vision_feats,
        vision_pos_embeds,
        feat_sizes,
    ) = self._prepare_backbone_features(backbone_out)

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
    # First process all the initial conditioning frames to encode them as
    # memory, then condition on them to track the remaining frames.
    processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
    output_dict = {
        "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
    }

    for stage_id in processing_order:
        img_ids = stage_id
        # Select this frame's features; `unsqueeze(1)` restores the axis that
        # integer indexing drops, so each tensor still holds a single sample.
        current_vision_feats = [x[:, img_ids].unsqueeze(1) for x in vision_feats]
        current_vision_pos_embeds = [
            x[:, img_ids].unsqueeze(1) for x in vision_pos_embeds
        ]
        current_av_v_feats = [x[img_ids] for x in av_v_feats]
        current_av_a_feats = [x[img_ids] for x in av_a_feats]

        # Predict this frame's masks from previous memory only: prompts,
        # ground-truth masks and correction frames are deliberately not passed.
        current_out = self.track_step_wo_prompt(
            frame_idx=stage_id,
            is_init_cond_frame=stage_id in init_cond_frames,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            mask_inputs=None,
            gt_masks=None,
            frames_to_add_correction_pt=None,
            output_dict=output_dict,
            num_frames=num_frames,
            audio_res=(current_av_v_feats, current_av_a_feats),
        )

        # File the output depending on whether it is a conditioning frame
        add_output_as_cond_frame = stage_id in init_cond_frames or (
            self.add_all_frames_to_correct_as_cond
            and stage_id in frames_to_add_correction_pt
        )
        if add_output_as_cond_frame:
            output_dict["cond_frame_outputs"][stage_id] = current_out
        else:
            output_dict["non_cond_frame_outputs"][stage_id] = current_out

    if return_dict:
        return output_dict

    # Turn `output_dict` into a list (ordered by frame index) for the loss
    all_frame_outputs = {}
    all_frame_outputs.update(output_dict["cond_frame_outputs"])
    all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
    all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
    # Make DDP happy with activation checkpointing by removing unused keys
    all_frame_outputs = [
        {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
    ]

    return all_frame_outputs
|
| 364 |
+
|
| 365 |
+
def track_step_wo_prompt(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
    prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
    frames_to_add_correction_pt=None,
    gt_masks=None,
    audio_res=None,
):
    """Run one tracking step for a frame without click-based correction.

    The heavy lifting is delegated to ``self._track_step_wo_prompt``; its
    SAM outputs are unpacked and stored in the per-frame output dict, and
    ``self._encode_memory_in_output`` is then called on the predicted mask.

    Args:
        frame_idx: Index of the frame being processed.
        is_init_cond_frame: Whether this frame is an initial conditioning frame.
        current_vision_feats: Per-level vision features for this frame.
        current_vision_pos_embeds: Positional embeddings matching the features.
        feat_sizes: Spatial size of each feature level.
        point_inputs: Forwarded to ``_track_step_wo_prompt`` and recorded
            under ``"multistep_point_inputs"``.
        mask_inputs: Forwarded to ``_track_step_wo_prompt``.
        output_dict: Outputs of previously processed frames.
        num_frames: Total number of frames in the clip.
        track_in_reverse: Forwarded to ``_track_step_wo_prompt``.
        run_mem_encoder: Forwarded to ``_encode_memory_in_output``.
        prev_sam_mask_logits: Forwarded to ``_track_step_wo_prompt``.
        frames_to_add_correction_pt: Accepted for signature parity with
            ``track_step``; not read here because iterative correction
            sampling is disabled in this variant.
        gt_masks: Accepted for signature parity with ``track_step``; not
            read here for the same reason.
        audio_res: Audio-related inputs forwarded to ``_track_step_wo_prompt``.

    Returns:
        The per-frame output dict produced by ``_track_step_wo_prompt``,
        extended with the ``multistep_*`` entries, ``"pred_masks"``,
        ``"pred_masks_high_res"`` and ``"obj_ptr"``.
    """
    # high_res_features / pix_feat were only needed by the (disabled)
    # correction-point sampling, so they are intentionally discarded.
    current_out, sam_outputs, _, _ = self._track_step_wo_prompt(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
        audio_res,
    )

    (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    ) = sam_outputs

    # Single-step bookkeeping, using the same keys that `track_step` writes.
    current_out["multistep_pred_masks"] = low_res_masks
    current_out["multistep_pred_masks_high_res"] = high_res_masks
    current_out["multistep_pred_multimasks"] = [low_res_multimasks]
    current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
    current_out["multistep_pred_ious"] = [ious]
    current_out["multistep_point_inputs"] = [point_inputs]
    current_out["multistep_object_score_logits"] = [object_score_logits]

    # No iterative correction-point sampling here: the first prediction is
    # also the final prediction used for output and eval.
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Placeholder handed over in the `point_inputs` slot of the memory encoder.
    # NOTE(review): presumably only its being non-None matters to
    # `_encode_memory_in_output` -- confirm against that method before changing.
    point_inputs_placeholder = 666.

    # Finally run the memory encoder on the predicted mask to encode
    # it into a new memory feature (that can be used in future frames).
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs_placeholder,
        run_mem_encoder,
        # we follow SAM2 predictor, if we have multiple masks output, we only
        # utilise the first one to perform the memory rope attention.
        high_res_masks,
        object_score_logits,
        current_out,
    )
    return current_out
|
| 465 |
+
|
| 466 |
+
def _track_step_wo_prompt(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse,
    prev_sam_mask_logits,
    audio_res=None
):
    """Compute the SAM outputs for one frame, prompted only by audio.

    Args:
        frame_idx: Index of the frame being processed.
        is_init_cond_frame: Whether this frame is an initial conditioning frame.
        current_vision_feats: Per-level vision features; the last entry is
            the one fused with memory, the earlier ones (if any) become the
            high-resolution features.
        current_vision_pos_embeds: Positional embeddings matching the features.
        feat_sizes: Spatial size of each feature level.
        point_inputs: Recorded in the output dict and forwarded to
            ``_forward_sam_heads``.
        mask_inputs: Recorded in the output dict; used directly as output
            when ``self.use_mask_input_as_output_without_sam`` is set.
        output_dict: Outputs of previously processed frames.
        num_frames: Total number of frames in the clip.
        track_in_reverse: Forwarded to ``_prepare_memory_conditioned_features``.
        prev_sam_mask_logits: Not read in this variant (feeding previous
            mask logits back into the decoder is disabled).
        audio_res: Audio-related inputs forwarded to ``_forward_sam_heads``.

    Returns:
        Tuple ``(current_out, sam_outputs, high_res_features, pix_feat)``.
    """
    current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

    # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
    if len(current_vision_feats) > 1:
        high_res_features = [
            x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
            for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
        ]
    else:
        high_res_features = None

    if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
        # When use_mask_input_as_output_without_sam=True, we directly output the mask input
        # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
        pix_feat = current_vision_feats[-1].permute(1, 2, 0)
        pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
        sam_outputs = self._use_mask_as_output(
            pix_feat, high_res_features, mask_inputs
        )
    else:
        # Fuse the visual feature with previous memory features in the memory bank.
        pix_feat = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )

        # We do not apply any prompts except audio: previous mask logits are
        # not fed back, and multimask output is always requested instead of
        # being decided from the point inputs.
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=True,
            audio_res=audio_res
        )

    return current_out, sam_outputs, high_res_features, pix_feat
|
| 538 |
+
|
| 539 |
+
def forward_tracking(
    self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
):
    """Track every frame of the clip in turn, sampling correction clicks.

    Initial conditioning frames are processed first so that they are
    available as memory when the remaining frames are tracked.

    Args:
        backbone_out: Dict carrying the backbone results and the per-frame
            prompt tables read below.
        input: Batched video datapoint providing ``flat_obj_to_img_idx``
            and ``flat_img_batch``.
        return_dict: When true, return the raw dict of conditioning and
            non-conditioning frame outputs instead of a per-frame list.

    Returns:
        Either the output dict (``return_dict=True``) or a list with one
        output dict per frame, in frame order, with ``"obj_ptr"`` removed.
    """
    precomputed = backbone_out["backbone_fpn"] is not None
    if precomputed:
        # Prepare the backbone features once for all frames;
        # vision_feats and vision_pos_embeds are in (HW)BC format.
        _, vision_feats, vision_pos_embeds, feat_sizes = (
            self._prepare_backbone_features(backbone_out)
        )

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
    # Conditioning frames go first, then everything else.
    processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]

    # Both entries map {frame_idx: <out>}.
    output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}

    for stage_id in processing_order:
        img_ids = input.flat_obj_to_img_idx[stage_id]
        if precomputed:
            # Slice the already computed features down to this frame's images.
            current_vision_feats = [feat[:, img_ids] for feat in vision_feats]
            current_vision_pos_embeds = [
                pos[:, img_ids] for pos in vision_pos_embeds
            ]
        else:
            # Compute the image features on the fly for the given img_ids
            # (this might be used for evaluation on long videos to avoid backbone OOM).
            _, current_vision_feats, current_vision_pos_embeds, feat_sizes = (
                self._prepare_backbone_features_per_frame(
                    input.flat_img_batch, img_ids
                )
            )

        is_init = stage_id in init_cond_frames
        frame_points = backbone_out["point_inputs_per_frame"].get(stage_id, None)
        frame_masks = backbone_out["mask_inputs_per_frame"].get(stage_id, None)
        frame_gt = backbone_out["gt_masks_per_frame"].get(stage_id, None)

        # Get output masks based on this frame's prompts and previous memory.
        current_out = self.track_step(
            frame_idx=stage_id,
            is_init_cond_frame=is_init,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=frame_points,
            mask_inputs=frame_masks,
            gt_masks=frame_gt,
            frames_to_add_correction_pt=frames_to_add_correction_pt,
            output_dict=output_dict,
            num_frames=num_frames,
        )

        # File the output under the conditioning or non-conditioning bucket.
        treat_as_cond = is_init or (
            self.add_all_frames_to_correct_as_cond
            and stage_id in frames_to_add_correction_pt
        )
        bucket = "cond_frame_outputs" if treat_as_cond else "non_cond_frame_outputs"
        output_dict[bucket][stage_id] = current_out

    if return_dict:
        return output_dict

    # Turn `output_dict` into a frame-ordered list for the loss function.
    merged = {
        **output_dict["cond_frame_outputs"],
        **output_dict["non_cond_frame_outputs"],
    }
    per_frame = [merged[t] for t in range(num_frames)]
    # Make DDP happy with activation checkpointing by removing unused keys.
    return [
        {key: val for key, val in frame_out.items() if key != "obj_ptr"}
        for frame_out in per_frame
    ]
|
| 622 |
+
|
| 623 |
+
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
    prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
    frames_to_add_correction_pt=None,
    gt_masks=None,
):
    """Run one tracking step, optionally refining it with correction points.

    Calls ``self._track_step`` for the initial prediction, records it in
    the per-frame output dict, runs ``self._iter_correct_pt_sampling`` when
    ``frame_idx`` is listed in ``frames_to_add_correction_pt``, and finally
    calls ``self._encode_memory_in_output`` on the resulting mask.

    Returns:
        The per-frame output dict, including the ``multistep_*`` entries,
        ``"pred_masks"``, ``"pred_masks_high_res"`` and ``"obj_ptr"``.
    """
    correction_frames = (
        [] if frames_to_add_correction_pt is None else frames_to_add_correction_pt
    )

    current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    )
    (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    ) = sam_outputs

    # Record the first (possibly only) prediction step.
    current_out["multistep_pred_masks"] = low_res_masks
    current_out["multistep_pred_masks_high_res"] = high_res_masks
    for key, value in (
        ("multistep_pred_multimasks", low_res_multimasks),
        ("multistep_pred_multimasks_high_res", high_res_multimasks),
        ("multistep_pred_ious", ious),
        ("multistep_point_inputs", point_inputs),
        ("multistep_object_score_logits", object_score_logits),
    ):
        current_out[key] = [value]

    # Optionally, sample correction points iteratively to correct the mask.
    if frame_idx in correction_frames:
        point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
            is_init_cond_frame,
            point_inputs,
            gt_masks,
            high_res_features,
            pix_feat,
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        # Only the final single-mask results are carried forward.
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = final_sam_outputs

    # Use the final prediction (after all correction steps) for output and eval.
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Encode the predicted mask into a new memory feature
    # (that can be used in future frames).
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    )
    return current_out
|
| 717 |
+
|
| 718 |
+
def _iter_correct_pt_sampling(
    self,
    is_init_cond_frame,
    point_inputs,
    gt_masks,
    high_res_features,
    pix_feat_with_mem,
    low_res_multimasks,
    high_res_multimasks,
    ious,
    low_res_masks,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Iteratively sample correction points and re-run the SAM heads.

    For each of ``self.num_correction_pt_per_frame`` rounds, a new point is
    sampled via ``get_next_point``, appended to ``point_inputs`` via
    ``concat_points``, and ``self._forward_sam_heads`` is re-run with the
    previous low-res mask as ``mask_inputs``. The predictions of every
    round (including the initial one passed in) are collected into the
    ``multistep_*`` entries of ``current_out``.

    Args:
        is_init_cond_frame: Forwarded to ``self._use_multimask``.
        point_inputs: Point prompts accumulated so far.
        gt_masks: Ground-truth masks used for point sampling; must not be None.
        high_res_features: Forwarded to ``self._forward_sam_heads``.
        pix_feat_with_mem: Backbone features forwarded to the SAM heads.
        low_res_multimasks, high_res_multimasks, ious, low_res_masks,
        high_res_masks, object_score_logits: Outputs of the initial
            prediction step, used as the first entry of each history list.
        current_out: Per-frame output dict, updated in place.

    Returns:
        Tuple ``(point_inputs, sam_outputs)`` holding the accumulated point
        prompts and the SAM outputs of the last correction round.

    Raises:
        ValueError: If ``self.num_correction_pt_per_frame`` is not positive;
            without at least one round there are no final SAM outputs to
            return.
    """

    assert gt_masks is not None
    num_rounds = self.num_correction_pt_per_frame
    if num_rounds <= 0:
        # With zero rounds the loop below never assigns `sam_outputs`, which
        # previously surfaced as an UnboundLocalError at the return statement
        # after `current_out` had already been modified.
        raise ValueError(
            "num_correction_pt_per_frame must be positive to sample "
            f"correction points, got {num_rounds}"
        )

    all_pred_masks = [low_res_masks]
    all_pred_high_res_masks = [high_res_masks]
    all_pred_multimasks = [low_res_multimasks]
    all_pred_high_res_multimasks = [high_res_multimasks]
    all_pred_ious = [ious]
    all_point_inputs = [point_inputs]
    all_object_score_logits = [object_score_logits]
    for _ in range(num_rounds):
        # sample a new point from the error between prediction and ground-truth
        # (with a small probability, directly sample from GT masks instead of errors)
        if self.training and self.prob_to_sample_from_gt_for_train > 0:
            sample_from_gt = (
                self.rng.random() < self.prob_to_sample_from_gt_for_train
            )
        else:
            sample_from_gt = False
        # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
        pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
        new_points, new_labels = get_next_point(
            gt_masks=gt_masks,
            pred_masks=pred_for_new_pt,
            method="uniform" if self.training else self.pt_sampling_for_eval,
        )
        point_inputs = concat_points(point_inputs, new_points, new_labels)
        # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
        # For tracking, this means that when the user adds a correction click, we also feed
        # the tracking output mask logits along with the click as input to the SAM decoder.
        mask_inputs = low_res_masks
        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
        if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
            # Activation checkpointing is only applied to single-mask rounds.
            sam_outputs = torch.utils.checkpoint.checkpoint(
                self._forward_sam_heads,
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                use_reentrant=False,
            )
        else:
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            _,
            object_score_logits,
        ) = sam_outputs
        all_pred_masks.append(low_res_masks)
        all_pred_high_res_masks.append(high_res_masks)
        all_pred_multimasks.append(low_res_multimasks)
        all_pred_high_res_multimasks.append(high_res_multimasks)
        all_pred_ious.append(ious)
        all_point_inputs.append(point_inputs)
        all_object_score_logits.append(object_score_logits)

    # Concatenate the masks along channel (to compute losses on all of them,
    # using `MultiStepIteractiveMasks`)
    current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
    current_out["multistep_pred_masks_high_res"] = torch.cat(
        all_pred_high_res_masks, dim=1
    )
    current_out["multistep_pred_multimasks"] = all_pred_multimasks
    current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
    current_out["multistep_pred_ious"] = all_pred_ious
    current_out["multistep_point_inputs"] = all_point_inputs
    current_out["multistep_object_score_logits"] = all_object_score_logits

    return point_inputs, sam_outputs
|
avs.code/v1m.code/model/visual/sam2/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|