File size: 5,986 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | """DDP training: frozen SAM2 + text, trainable AuralFuser (Ref-AVS)."""
import os
import argparse
import random
import numpy
import torch
from easydict import EasyDict
def seed_it(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU and all
    CUDA devices), and switches cuDNN to deterministic kernels with
    autotuning disabled so repeated runs are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    # PYTHONHASHSEED (not "PYTHONSEED") is the variable the interpreter reads.
    # Setting it here cannot change hash randomisation of the already-running
    # interpreter, but child interpreters started from this process inherit it.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible CUDA device (covers the current device too).
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process DDP entry point launched by ``torch.multiprocessing.spawn``.

    Builds the model (wrapped in DistributedDataParallel), the Ref-AVS train
    and test loaders, the loss, metrics and trainer. It then alternates one
    training epoch with evaluation on the ``test_s``, ``test_u`` and
    ``test_n`` splits. Rank 0 saves the ``aural_fuser`` state dict whenever
    the ``test_s`` or ``test_u`` score beats its best value so far.

    Args:
        local_rank: Process index supplied by ``spawn``. It is used as the
            global rank and as the CUDA device index (single-node setup).
        ngpus_per_node: Not used in this function; kept so the ``spawn``
            argument tuple stays unchanged.
        hyp_param: ``EasyDict`` of hyper-parameters. It is mutated in place
            with values derived from the Hydra configs.
    """
    hyp_param.local_rank = local_rank
    # Single-node setup: the spawn index doubles as the global rank.
    torch.distributed.init_process_group(
        backend='nccl', init_method='env://',
        rank=local_rank, world_size=hyp_param.gpus,
    )
    # Offset the seed by rank so processes do not draw identical randomness.
    seed_it(local_rank + hyp_param.seed)
    torch.cuda.set_device(local_rank)

    # Project/Hydra imports are deferred until inside the spawned process.
    import model.visual.sam2  # noqa: F401 — registers Hydra config store
    from hydra import compose
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # --- Configuration ---------------------------------------------------
    cfg = compose(config_name='configs/training/sam2_training_config.yaml')
    OmegaConf.resolve(cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True)
    arch_h = compose(config_name='configs/auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)
    hyp_param.image_size = 1024
    # image_size / 16 — presumably the image encoder's downsampling factor; confirm.
    hyp_param.image_embedding_size = int(hyp_param.image_size / 16)

    # --- Model and optimiser ---------------------------------------------
    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(local_rank)
    av_model = torch.nn.parallel.DistributedDataParallel(
        av_model, device_ids=[local_rank], find_unused_parameters=True,
    )
    from utils.utils import manipulate_params
    # Only parameters of the aural_fuser sub-module are given to the optimiser.
    optimiser = torch.optim.AdamW(manipulate_params(hyp_param, av_model.module.aural_fuser), betas=(0.9, 0.999))

    # --- Data -------------------------------------------------------------
    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data.distributed import DistributedSampler
    # Training visual transforms come from the Hydra config; the first entry is used.
    compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0]
    audio_aug = AudioAugmentation(mono=True)
    train_dataset = AV(
        split='train',
        augmentation={"visual": compose_api, "audio": audio_aug},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
    )
    # Evaluation splits use the project's VisualAugmentation instead.
    visual_aug = VisualAugmentation(
        hyp_param.image_mean, hyp_param.image_std,
        hyp_param.image_size, hyp_param.image_size,
        hyp_param.scale_list, ignore_index=hyp_param.ignore_index,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hyp_param.batch_size,
        sampler=DistributedSampler(train_dataset, shuffle=True),
        num_workers=hyp_param.num_workers,
        drop_last=True,
    )

    def _test_loader(split):
        """Build an unshuffled, rank-sharded evaluation loader for *split*."""
        ds = AV(split=split, augmentation={"visual": visual_aug, "audio": audio_aug},
                param=hyp_param, root_path=hyp_param.data_root_path)
        return torch.utils.data.DataLoader(
            ds, batch_size=4,
            sampler=DistributedSampler(ds, shuffle=False),
            num_workers=hyp_param.num_workers,
        )

    test_s_loader = _test_loader('test_s')
    test_u_loader = _test_loader('test_u')
    test_n_loader = _test_loader('test_n')

    # --- Loss, logging, metrics, trainer ---------------------------------
    criterion = instantiate(cfg.loss, _recursive_=True)['all']
    from utils.tensorboard import Tensorboard
    # Only rank 0 creates a TensorBoard writer.
    tensorboard = Tensorboard(config=hyp_param) if local_rank <= 0 else None
    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    from utils.foreground_s import ForegroundS
    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(0 if local_rank <= 0 else local_rank),
        "foreground_s": ForegroundS(),
    }
    trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)

    # --- Train / evaluate loop -------------------------------------------
    # A checkpoint is written only once a score exceeds these starting values.
    test_s_best, test_u_best = 0.2, 0.2
    # NOTE(review): iterates epochs 0..hyp_param.epochs inclusive, i.e.
    # epochs + 1 passes — confirm this is intended.
    for epoch in range(hyp_param.epochs + 1):
        av_model.train()
        # Called after train(); presumably re-freezes the SAM2 parts — confirm in AVmodel.
        av_model.module.freeze_sam_parameters()
        # Give the DistributedSampler a new shuffle seed for this epoch.
        train_loader.sampler.set_epoch(epoch)
        trainer.train(epoch=epoch, dataloader=train_loader, model=av_model, optimiser=optimiser)
        torch.distributed.barrier()
        torch.cuda.empty_cache()

        av_model.eval()
        test_s, _ = trainer.valid(epoch=epoch, dataloader=test_s_loader, model=av_model, process='test_s')
        test_u, _ = trainer.valid(epoch=epoch, dataloader=test_u_loader, model=av_model, process='test_u')
        trainer.valid_null(epoch=epoch, dataloader=test_n_loader, model=av_model, process='test_n')
        if local_rank <= 0 and (test_s > test_s_best or test_u > test_u_best):
            test_s_best = max(test_s, test_s_best)
            test_u_best = max(test_u, test_u_best)
            # Make sure the output directory exists before the first save.
            os.makedirs(hyp_param.saved_dir, exist_ok=True)
            torch.save(
                av_model.module.aural_fuser.state_dict(),
                os.path.join(
                    hyp_param.saved_dir,
                    f's({float(test_s)})_u({float(test_u)}).pth',
                ),
            )
        torch.distributed.barrier()
        torch.cuda.empty_cache()

    # Release NCCL resources explicitly; recent PyTorch versions warn when a
    # process group is still alive at interpreter shutdown.
    torch.distributed.destroy_process_group()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ref-AVS training')
    # Accepted for launcher compatibility; main() overwrites it with the spawn index.
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('-g', '--gpus', default=1, type=int)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--epochs', default=80, type=int)
    parser.add_argument('--lr', default=5e-4, type=float)
    args = parser.parse_args()

    from configs.config import C
    # Command-line values take precedence over the static config C. Because
    # vars(args) includes argparse defaults, these keys always override C.
    args = EasyDict({**C, **vars(args)})

    # Single-node rendezvous for init_method='env://'. Values already present
    # in the environment are respected, so concurrent runs on one host can
    # choose different ports.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '9901')
    # spawn calls main(i, args.gpus, args) for each process index i.
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
|