File size: 5,986 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""DDP training: frozen SAM2 + text, trainable AuralFuser (Ref-AVS)."""
import os
import argparse
import random

import numpy
import torch
from easydict import EasyDict


def seed_it(seed):
    """Seed every RNG used by training and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
            (CPU generator plus the current and all CUDA devices).
    """
    # The interpreter's hash-randomisation variable is PYTHONHASHSEED; the
    # previous "PYTHONSEED" name is not read by anything. Exporting it here
    # cannot change the hash seed of the already-running interpreter, but it
    # is inherited by Python processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python-level, NumPy and PyTorch generators.
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Keep cuDNN on, but pick deterministic kernels and disable the
    # auto-tuner so kernel selection does not vary between runs.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(local_rank, ngpus_per_node, hyp_param):
    """Run one DDP worker: build model and data, then train/evaluate per epoch.

    Used as the target of ``torch.multiprocessing.spawn``, which passes the
    worker index as the first positional argument.

    Args:
        local_rank: Index of this worker. Used both as the distributed rank
            and as the CUDA device index.
        ngpus_per_node: Number of spawned workers. Not read in this function
            (the world size comes from ``hyp_param.gpus``); kept so the
            signature matches the ``spawn`` call.
        hyp_param: Attribute-style configuration object. It is mutated in
            place: ``local_rank``, ``contrastive_learning``, ``aural_fuser``,
            ``image_size`` and ``image_embedding_size`` are written onto it.
    """
    hyp_param.local_rank = local_rank
    torch.distributed.init_process_group(
        backend='nccl', init_method='env://',
        rank=local_rank, world_size=hyp_param.gpus,
    )
    # Offset the seed by the rank so each worker draws a different stream.
    seed_it(local_rank + hyp_param.seed)
    torch.cuda.set_device(local_rank)

    import model.visual.sam2  # noqa: F401 — registers Hydra config store

    from hydra import compose
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Training config: resolved in place, then the contrastive-learning
    # section is copied onto hyp_param as a plain container.
    cfg = compose(config_name='configs/training/sam2_training_config.yaml')
    OmegaConf.resolve(cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True)

    # Architecture config for the fuser, handled the same way.
    arch_h = compose(config_name='configs/auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

    hyp_param.image_size = 1024
    # Integer division keeps this an exact int (1024 // 16 == 64).
    hyp_param.image_embedding_size = hyp_param.image_size // 16

    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(local_rank)
    av_model = torch.nn.parallel.DistributedDataParallel(
        av_model, device_ids=[local_rank], find_unused_parameters=True,
    )

    from utils.utils import manipulate_params
    # Only the aural_fuser submodule is handed to the optimiser.
    # NOTE(review): no lr is passed to AdamW here; presumably
    # manipulate_params returns param groups carrying their own lr — confirm.
    optimiser = torch.optim.AdamW(manipulate_params(hyp_param, av_model.module.aural_fuser), betas=(0.9, 0.999))

    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data.distributed import DistributedSampler

    # Training uses the Hydra-instantiated visual transforms (first entry).
    compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0]
    audio_aug = AudioAugmentation(mono=True)
    train_dataset = AV(
        split='train',
        augmentation={"visual": compose_api, "audio": audio_aug},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
    )

    # Evaluation splits use the project's own visual augmentation instead.
    visual_aug = VisualAugmentation(
        hyp_param.image_mean, hyp_param.image_std,
        hyp_param.image_size, hyp_param.image_size,
        hyp_param.scale_list, ignore_index=hyp_param.ignore_index,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hyp_param.batch_size,
        sampler=DistributedSampler(train_dataset, shuffle=True),
        num_workers=hyp_param.num_workers,
        drop_last=True,
    )

    def _test_loader(split):
        """Build an unshuffled, rank-sharded loader (batch size 4) for ``split``."""
        ds = AV(split=split, augmentation={"visual": visual_aug, "audio": audio_aug},
                param=hyp_param, root_path=hyp_param.data_root_path)
        return torch.utils.data.DataLoader(
            ds, batch_size=4,
            sampler=DistributedSampler(ds, shuffle=False),
            num_workers=hyp_param.num_workers,
        )

    test_s_loader = _test_loader('test_s')
    test_u_loader = _test_loader('test_u')
    test_n_loader = _test_loader('test_n')

    criterion = instantiate(cfg.loss, _recursive_=True)['all']

    # Only the first worker creates a tensorboard logger.
    is_main = local_rank <= 0
    from utils.tensorboard import Tensorboard
    tensorboard = Tensorboard(config=hyp_param) if is_main else None

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    from utils.foreground_s import ForegroundS
    metrics = {
        "foreground_iou": ForegroundIoU(),
        # Clamp a negative rank to 0; otherwise pass the rank through.
        "foreground_f-score": ForegroundFScore(0 if is_main else local_rank),
        "foreground_s": ForegroundS(),
    }
    trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)

    # A checkpoint is written only once a split's score exceeds 0.2, and
    # thereafter whenever either split beats its own best so far.
    test_s_best, test_u_best = 0.2, 0.2
    for epoch in range(hyp_param.epochs + 1):
        av_model.train()
        # Re-applied every epoch, right after switching to train mode.
        av_model.module.freeze_sam_parameters()
        # Gives the sampler the epoch index so shuffling differs per epoch.
        train_loader.sampler.set_epoch(epoch)
        trainer.train(epoch=epoch, dataloader=train_loader, model=av_model, optimiser=optimiser)

        torch.distributed.barrier()
        torch.cuda.empty_cache()

        av_model.eval()
        test_s, _ = trainer.valid(epoch=epoch, dataloader=test_s_loader, model=av_model, process='test_s')
        test_u, _ = trainer.valid(epoch=epoch, dataloader=test_u_loader, model=av_model, process='test_u')
        trainer.valid_null(epoch=epoch, dataloader=test_n_loader, model=av_model, process='test_n')

        if is_main and (test_s > test_s_best or test_u > test_u_best):
            test_s_best = max(test_s, test_s_best)
            test_u_best = max(test_u, test_u_best)
            # Make sure the target directory exists so the save cannot fail
            # on a fresh run.
            os.makedirs(hyp_param.saved_dir, exist_ok=True)
            # Only the fuser's weights are stored.
            torch.save(
                av_model.module.aural_fuser.state_dict(),
                os.path.join(
                    hyp_param.saved_dir,
                    f's({float(test_s)})_u({float(test_u)}).pth',
                ),
            )

        torch.distributed.barrier()
        torch.cuda.empty_cache()

    # All ranks have passed the final barrier; release the process group
    # so workers shut down cleanly.
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    # Command-line flags for the launcher.
    parser = argparse.ArgumentParser(description='Ref-AVS training')
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('-g', '--gpus', default=1, type=int)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--epochs', default=80, type=int)
    parser.add_argument('--lr', default=5e-4, type=float)
    args = parser.parse_args()

    from configs.config import C
    # Later keys win in the merge, so command-line values (including their
    # defaults) override same-named entries of the project config C.
    args = EasyDict({**C, **vars(args)})

    # All workers are spawned on this machine, so rendezvous is on loopback.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # 9901 stays the default port, but a MASTER_PORT already exported by the
    # caller is respected so several jobs can share one host.
    os.environ.setdefault('MASTER_PORT', '9901')

    # One process per GPU; spawn passes the process index as main's first
    # argument, followed by the tuple given in ``args``.
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))