File size: 7,907 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""Distributed inference on the test set; runs the same three `process` modes as training validation."""
import os
import pathlib
import torch
import numpy
import random
import argparse
from easydict import EasyDict

# configs.config creates ``saved_dir`` when it is imported; if that location is not
# writable the import would abort. Path.mkdir is therefore patched process-wide so
# that permission failures become warnings instead of exceptions.
_real_mkdir = pathlib.Path.mkdir


def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
    """Drop-in replacement for ``pathlib.Path.mkdir`` that tolerates permission failures.

    Behaves like the original method, except that a directory which cannot be
    created because of missing permissions is reported with a RuntimeWarning
    and skipped instead of raising.

    Returns:
        None, like ``pathlib.Path.mkdir``.

    Raises:
        Any other OSError raised by ``pathlib.Path.mkdir`` (e.g. FileExistsError
        when ``exist_ok`` is false, FileNotFoundError when ``parents`` is false).
    """
    try:
        _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok)
    except OSError as err:
        # With parents=True, pathlib creates missing ancestors via
        # ``self.parent.mkdir`` -- i.e. through this patched method. An ancestor's
        # PermissionError is downgraded there, and the retry on ``self`` then fails
        # with FileNotFoundError; treat that as the same permission problem.
        ancestor_skipped = parents and isinstance(err, FileNotFoundError)
        if not (isinstance(err, PermissionError) or ancestor_skipped):
            raise
        import warnings

        # Previously swallowed silently; surface it so a later write into this
        # directory fails with a traceable cause.
        warnings.warn(
            "pathlib.Path.mkdir skipped for {!r}: {}".format(str(self), err),
            RuntimeWarning,
            stacklevel=2,
        )


pathlib.Path.mkdir = _safe_mkdir


def seed_it(seed, *, cudnn_benchmark=False):
    """Seed Python, NumPy and torch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and torch
            (CPU generator and every CUDA device).
        cudnn_benchmark: keyword-only. cuDNN autotuning may pick different
            algorithms from run to run, which defeats ``cudnn.deterministic``;
            it is therefore off by default. Pass True to trade exact
            reproducibility for speed.

    Note:
        ``PYTHONHASHSEED`` is exported so that Python interpreters started
        afterwards (e.g. spawn-based worker processes) use a fixed hash seed; it
        cannot change hash randomization of the already-running interpreter.
    """
    random.seed(seed)
    # CPython reads PYTHONHASHSEED; the previously used "PYTHONSEED" is ignored.
    os.environ["PYTHONHASHSEED"] = str(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # benchmark=True contradicts deterministic=True (see PyTorch reproducibility notes).
    torch.backends.cudnn.benchmark = cudnn_benchmark


class _DummyTensorboard:
    """No-op logger handed to ``Trainer`` so ``Trainer.valid`` runs without wandb.

    Each upload method accepts its arguments, discards them and returns None.
    """

    def upload_wandb_info(self, info_dict):
        """Ignore ``info_dict``; nothing is logged."""
        return None

    def upload_wandb_image(self, *args, **kwargs):
        """Ignore any positional/keyword arguments; nothing is logged."""
        return None


def _load_hydra_configs(hyp_param):
    """Copy the Hydra config sections AVmodel needs into ``hyp_param``.

    Sets ``hyp_param.aural_fuser`` (from ``auralfuser/architecture.yaml``) and
    ``hyp_param.contrastive_learning`` (from ``training/sam2_training_config.yaml``)
    as fully resolved plain Python containers.
    """
    import model.visual.sam2  # noqa: F401 — registers Hydra `configs`
    from hydra import compose
    from omegaconf import OmegaConf

    arch_cfg = compose(config_name='auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_cfg)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_cfg.aural_fuser, resolve=True)

    train_cfg = compose(config_name='training/sam2_training_config.yaml')
    OmegaConf.resolve(train_cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True)


def _build_av_model(hyp_param):
    """Build AVmodel on the current CUDA device, load ``hyp_param.inference_ckpt``, wrap in DDP, set eval.

    Raises:
        TypeError: if the loaded checkpoint is not a dict.
    """
    from model.mymodel import AVmodel

    # `.cuda()` targets the current device, which main() selects with set_device beforehand.
    av_model = AVmodel(hyp_param).cuda()
    # NOTE: torch.load unpickles the file; only point --inference_ckpt at trusted checkpoints.
    ckpt_sd = torch.load(hyp_param.inference_ckpt, map_location="cpu")
    if not isinstance(ckpt_sd, dict):
        raise TypeError("Checkpoint must be a state_dict dictionary.")
    # Same as v1s/v2: full-model ckpt vs train-only aural_fuser ckpt (e.g. keys vgg.*, f_blocks.*).
    if any(k.startswith(("v_model.", "aural_fuser.")) for k in ckpt_sd):
        av_model.load_state_dict(ckpt_sd, strict=True)
    else:
        av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True)

    # Convert BatchNorm -> SyncBatchNorm *before* wrapping in DDP, the order PyTorch documents.
    av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model)
    av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank],
                                                                     find_unused_parameters=False)
    av_model.eval()
    return av_model


def _build_test_dataloader(hyp_param):
    """Build the non-shuffled, distributed test DataLoader for ``hyp_param.inference_data_name``."""
    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data import DataLoader, Subset
    from torch.utils.data.distributed import DistributedSampler

    visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std,
                                             hyp_param.image_size, hyp_param.image_size,
                                             hyp_param.scale_list, ignore_index=hyp_param.ignore_index)
    audio_augmentation = AudioAugmentation(mono=True)

    dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
                 param=hyp_param, root_path=hyp_param.data_root_path, data_name=hyp_param.inference_data_name)

    # Debug mode: keep only the first N * batch_size samples in total. DistributedSampler
    # then splits them across ranks, so each rank sees roughly N / world_size batches.
    max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0
    if max_batches > 0:
        n_samples = min(max_batches * hyp_param.batch_size, len(dataset))
        dataset = Subset(dataset, range(n_samples))

    # NOTE(review): with drop_last=False DistributedSampler pads to a multiple of world_size
    # by repeating indices, so with >1 GPU a few test items may be scored twice — confirm
    # whether Trainer.valid compensates.
    sampler = DistributedSampler(dataset, shuffle=False)
    return DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler,
                      num_workers=hyp_param.num_workers)


def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process entry point launched by ``torch.multiprocessing.spawn``.

    Evaluates the checkpoint on the test split with the three ``process`` modes used
    during training validation and prints the metrics on rank 0.

    Args:
        local_rank: index of this process/GPU on the node (supplied by spawn).
        ngpus_per_node: unused; the world size is taken from ``hyp_param.gpus``.
        hyp_param: EasyDict of config + CLI options; mutated in place
            (``local_rank`` and the Hydra-derived sections are added).
    """
    hyp_param.local_rank = local_rank
    # Bind this process to its GPU before NCCL init and before any `.cuda()` call;
    # otherwise every rank builds the model on cuda:0 and DDP(device_ids=[local_rank])
    # rejects it for every rank other than 0.
    torch.cuda.set_device(hyp_param.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus * 1  # single node: global rank == local rank
    )
    seed_it(local_rank + hyp_param.seed)

    _load_hydra_configs(hyp_param)
    av_model = _build_av_model(hyp_param)
    test_dataloader = _build_test_dataloader(hyp_param)

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore

    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(hyp_param.local_rank),
    }
    trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics)

    # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select
    runs = [
        ("", "default (logits[:,0])"),
        ("iou_select", "iou_select"),
        ("iou_occ_select", "iou_occ_select"),
    ]
    results = []
    for process, label in runs:
        fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader, model=av_model, process=process)
        results.append((label, fiou, ffscore))
        torch.cuda.empty_cache()

    if hyp_param.local_rank <= 0:
        print("\n========== inference (same three process flags as training valid) ==========")
        for label, fiou, ffscore in results:
            print("  {:32s}  f_iou={}  f_f-score={}".format(label, fiou, ffscore))
        print("=======================================================\n")

    # Release NCCL resources explicitly instead of relying on interpreter teardown.
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Inference: full test set + three process modes')

    # NOTE: only single-node runs are supported (rank == local_rank, MASTER_ADDR is localhost).
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')

    parser.add_argument("--local_rank", type=int, default=-1,
                        help='multi-process training for DDP')

    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')

    parser.add_argument('--batch_size', default=1, type=int,
                        help='Batch size (match training if needed)')

    parser.add_argument('--epochs', default=80, type=int,
                        help="unused")

    parser.add_argument('--lr', default=1e-5, type=float,
                        help="unused")

    parser.add_argument('--online', action="store_true",
                        help='unused')

    parser.add_argument(
        '--inference_ckpt', type=str, default=None,
        help='Trained AuralSAM2 checkpoint (.pth state_dict). '
             'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). '
             'Default if unset: avs.code/training_details/<inference_data_name>/hiera_l/hiera_l.pth',
    )
    parser.add_argument('--inference_data_name', type=str, default='v1m',
                        help='AVSBench subset folder label (v1s|v1m|v2); must match training test split')
    parser.add_argument('--inference_max_batches', type=int, default=0,
                        help='0 = full test; >0 = first N batches only (debug)')

    args = parser.parse_args()

    from configs.config import C

    # CLI values override same-named entries from the project config.
    args = EasyDict({**C, **vars(args)})

    _repo = pathlib.Path(__file__).resolve().parent
    # Repo root: .../AuralSAM2 (parent of avs.code)
    _workspace = _repo.parent.parent
    args.data_root_path = str(_workspace / 'AVSBench')
    args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt')
    args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth')
    args.saved_dir = '/tmp/{}_infer_ckpt'.format(args.inference_data_name)
    pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True)
    if args.inference_ckpt is None:
        # Follow the requested subset instead of always loading the v1m checkpoint.
        args.inference_ckpt = str(
            _repo.parent / 'training_details' / args.inference_data_name / 'hiera_l' / 'hiera_l.pth'
        )
    # Fail once, here, rather than with a FileNotFoundError in every spawned worker.
    if not pathlib.Path(args.inference_ckpt).is_file():
        parser.error('checkpoint not found: {}'.format(args.inference_ckpt))

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # Allow overriding the rendezvous port (e.g. to run two inferences side by side).
    os.environ.setdefault('MASTER_PORT', '9901')

    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))