| """DDP training: frozen SAM2 + text, trainable AuralFuser (Ref-AVS).""" |
| import os |
| import argparse |
| import random |
|
|
| import numpy |
| import torch |
| from easydict import EasyDict |
|
|
|
|
def seed_it(seed):
    """Seed every RNG used in training and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG
            and torch (CPU plus every CUDA device).

    Note:
        ``PYTHONHASHSEED`` is only honoured by interpreters started *after* it
        is set (e.g. spawned worker processes); it cannot change the hash
        randomisation of the interpreter that is already running.
    """
    # The interpreter reads PYTHONHASHSEED; the previous "PYTHONSEED" key is
    # not a recognised variable but is still exported for backward compatibility.
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["PYTHONSEED"] = str(seed)
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible CUDA devices (including the current one), which makes a
    # separate torch.cuda.manual_seed call redundant.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process DDP worker: build model, data and metrics, then train/evaluate.

    Started once per GPU by ``torch.multiprocessing.spawn``. Only the
    AuralFuser's parameters are given to the optimiser; SAM2 is re-frozen at
    the start of every epoch.

    Args:
        local_rank: Index of this process. It is used both as the global rank
            and as the CUDA device index (single-node setup).
        ngpus_per_node: Number of spawned processes, passed by the launcher.
            Unused here: ``hyp_param.gpus`` is used as the world size.
        hyp_param: ``EasyDict`` of hyper-parameters. It is mutated in place
            with values derived from the Hydra configs loaded below.
    """
    hyp_param.local_rank = local_rank
    # Bind this process to its GPU *before* creating the NCCL group, so that
    # NCCL communicators and barriers use this device rather than cuda:0.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(
        backend='nccl', init_method='env://',
        rank=local_rank, world_size=hyp_param.gpus,
    )
    # Different seed per rank so random augmentations differ across processes.
    seed_it(local_rank + hyp_param.seed)

    # Imported before compose() is called; presumably it registers/initialises
    # Hydra's config search path as a side effect -- TODO confirm.
    import model.visual.sam2  # noqa: F401

    from hydra import compose
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Training config: provides contrastive-learning settings, train
    # transforms and the loss definition used further down.
    cfg = compose(config_name='configs/training/sam2_training_config.yaml')
    OmegaConf.resolve(cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(
        cfg.contrastive_learning, resolve=True,
    )

    # AuralFuser architecture hyper-parameters.
    arch_h = compose(config_name='configs/auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

    # Input resolution and the side length of the image-embedding grid
    # (image_size / 16, i.e. 64 for 1024).
    hyp_param.image_size = 1024
    hyp_param.image_embedding_size = hyp_param.image_size // 16

    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(local_rank)
    # NOTE(review): find_unused_parameters=True presumably because some
    # trainable parameters do not take part in every forward pass; verify
    # before disabling, since it adds per-iteration overhead.
    av_model = torch.nn.parallel.DistributedDataParallel(
        av_model, device_ids=[local_rank], find_unused_parameters=True,
    )

    # Only the AuralFuser's (possibly grouped) parameters are optimised.
    from utils.utils import manipulate_params
    optimiser = torch.optim.AdamW(
        manipulate_params(hyp_param, av_model.module.aural_fuser),
        betas=(0.9, 0.999),
    )

    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data.distributed import DistributedSampler

    # Training split: the visual pipeline comes from the Hydra training config.
    compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0]
    audio_aug = AudioAugmentation(mono=True)
    train_dataset = AV(
        split='train',
        augmentation={"visual": compose_api, "audio": audio_aug},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
    )

    # Evaluation splits use this visual augmentation instead.
    visual_aug = VisualAugmentation(
        hyp_param.image_mean, hyp_param.image_std,
        hyp_param.image_size, hyp_param.image_size,
        hyp_param.scale_list, ignore_index=hyp_param.ignore_index,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hyp_param.batch_size,
        sampler=DistributedSampler(train_dataset, shuffle=True),
        num_workers=hyp_param.num_workers,
        drop_last=True,
    )

    def _test_loader(split):
        """Build an evaluation DataLoader for ``split``, sharded across ranks."""
        ds = AV(split=split, augmentation={"visual": visual_aug, "audio": audio_aug},
                param=hyp_param, root_path=hyp_param.data_root_path)
        # NOTE(review): DistributedSampler pads the split so it divides evenly
        # across ranks, so a few samples may be evaluated twice -- confirm
        # whether Trainer.valid accounts for this.
        return torch.utils.data.DataLoader(
            ds, batch_size=4,
            sampler=DistributedSampler(ds, shuffle=False),
            num_workers=hyp_param.num_workers,
        )

    test_s_loader = _test_loader('test_s')
    test_u_loader = _test_loader('test_u')
    test_n_loader = _test_loader('test_n')

    criterion = instantiate(cfg.loss, _recursive_=True)['all']

    # Only rank 0 writes TensorBoard logs.
    from utils.tensorboard import Tensorboard
    tensorboard = Tensorboard(config=hyp_param) if local_rank <= 0 else None

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    from utils.foreground_s import ForegroundS
    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(0 if local_rank <= 0 else local_rank),
        "foreground_s": ForegroundS(),
    }
    trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)

    # A checkpoint is written only once a score beats these floors.
    test_s_best, test_u_best = 0.2, 0.2
    # NOTE(review): this runs epochs + 1 passes (epoch = 0 .. hyp_param.epochs).
    # Kept as-is because the Trainer may rely on the final epoch index.
    for epoch in range(hyp_param.epochs + 1):
        av_model.train()
        # Re-applied after train() every epoch -- presumably so train() cannot
        # leave frozen SAM2 submodules in training mode; TODO confirm.
        av_model.module.freeze_sam_parameters()
        # Gives each epoch a different, rank-consistent shuffle.
        train_loader.sampler.set_epoch(epoch)
        trainer.train(epoch=epoch, dataloader=train_loader, model=av_model, optimiser=optimiser)

        torch.distributed.barrier()
        torch.cuda.empty_cache()

        av_model.eval()
        test_s, _ = trainer.valid(epoch=epoch, dataloader=test_s_loader, model=av_model, process='test_s')
        test_u, _ = trainer.valid(epoch=epoch, dataloader=test_u_loader, model=av_model, process='test_u')
        trainer.valid_null(epoch=epoch, dataloader=test_n_loader, model=av_model, process='test_n')

        # Rank 0 saves the AuralFuser weights whenever either split improves.
        if local_rank <= 0 and (test_s > test_s_best or test_u > test_u_best):
            test_s_best = max(test_s, test_s_best)
            test_u_best = max(test_u, test_u_best)
            # Make sure the target directory exists before the first save.
            os.makedirs(hyp_param.saved_dir, exist_ok=True)
            torch.save(
                av_model.module.aural_fuser.state_dict(),
                os.path.join(
                    hyp_param.saved_dir,
                    f's({float(test_s)})_u({float(test_u)}).pth',
                ),
            )

        # Keep ranks in lock-step so no rank starts the next epoch while
        # rank 0 is still saving.
        torch.distributed.barrier()
        torch.cuda.empty_cache()

    # Release NCCL communicators/resources cleanly before the worker exits.
    torch.distributed.destroy_process_group()
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ref-AVS training')
    # Accepted for launcher compatibility only: mp.spawn passes the real
    # rank to main() as its first argument.
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('-g', '--gpus', default=1, type=int)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--epochs', default=80, type=int)
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument(
        '--master_port', default=9901, type=int,
        help='TCP port for the DDP rendezvous (change it to run several jobs on one host).',
    )
    cli_args = vars(parser.parse_args())
    # Launcher-only setting: keep it out of the hyper-parameter dict.
    master_port = cli_args.pop('master_port')

    # Fail fast with a clear message instead of an obscure error in a child process.
    n_visible = torch.cuda.device_count()
    if not 1 <= cli_args['gpus'] <= n_visible:
        parser.error(
            f"--gpus must be between 1 and the number of visible CUDA devices "
            f"({n_visible}), got {cli_args['gpus']}"
        )

    from configs.config import C
    # Later keys win: every CLI option above (including its default value)
    # overrides the same-named entry in configs.config.C.
    args = EasyDict({**C, **cli_args})
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(master_port)
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
|
|