File size: 7,801 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | """Distributed inference on the test set; runs the same three `process` modes as training validation."""
import os
import pathlib
import torch
import numpy
import random
import argparse
from easydict import EasyDict
# Importing configs.config creates saved_dir as a side effect; when that location is not
# writable the import must not abort, so Path.mkdir is patched to tolerate PermissionError.
_real_mkdir = pathlib.Path.mkdir


def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
    """Drop-in replacement for ``pathlib.Path.mkdir`` that ignores ``PermissionError``.

    Any other exception from the real ``mkdir`` (e.g. ``FileExistsError``) propagates
    unchanged. The patch below is process-wide and is never undone in this file.
    """
    try:
        outcome = _real_mkdir(self, mode, parents=parents, exist_ok=exist_ok)
    except PermissionError:
        outcome = None
    return outcome


pathlib.Path.mkdir = _safe_mkdir
def seed_it(seed):
    """Seed the Python, NumPy and torch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator below.
    """
    random.seed(seed)
    # The variable Python honours is PYTHONHASHSEED ("PYTHONSEED" is ignored). It is read
    # at interpreter start-up, so it only affects child interpreters launched afterwards
    # (e.g. spawned worker processes), not the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    numpy.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True
    torch.manual_seed(seed)
class _DummyTensorboard:
    """No-op logger handed to ``Trainer`` so ``Trainer.valid`` runs without wandb logging.

    Only the two upload hooks below exist; each accepts its arguments and discards them.
    """

    def upload_wandb_info(self, info_dict):
        """Discard ``info_dict``."""

    def upload_wandb_image(self, *args, **kwargs):
        """Discard every positional and keyword argument."""
def _load_model_configs(hyp_param):
    """Fill ``hyp_param.aural_fuser`` and ``hyp_param.contrastive_learning`` from Hydra configs."""
    import model.visual.sam2  # noqa: F401 — registers Hydra `configs`
    from hydra import compose
    from omegaconf import OmegaConf

    arch_h = compose(config_name='auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

    train_cfg = compose(config_name='training/sam2_training_config.yaml')
    OmegaConf.resolve(train_cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(
        train_cfg.contrastive_learning, resolve=True
    )


def _load_checkpoint(av_model, ckpt_path):
    """Strictly load the trained weights stored at ``ckpt_path`` into ``av_model``.

    Two layouts are supported:
      1) full-model checkpoint (keys like ``v_model.*``, ``aural_fuser.*``);
      2) train-only checkpoint for aural_fuser (keys without the ``aural_fuser.`` prefix).

    Raises:
        TypeError: if the loaded object is not a dict.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    ckpt_sd = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(ckpt_sd, dict):
        raise TypeError("Checkpoint must be a state_dict dictionary.")
    if any(k.startswith(("v_model.", "aural_fuser.")) for k in ckpt_sd):
        av_model.load_state_dict(ckpt_sd, strict=True)
    else:
        av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True)


def _build_model(hyp_param):
    """Create AVmodel on this rank's GPU, load the checkpoint and wrap it in DDP (eval mode)."""
    from model.mymodel import AVmodel

    # .cuda() targets the current device, which main() has already set to local_rank.
    av_model = AVmodel(hyp_param).cuda()
    _load_checkpoint(av_model, hyp_param.inference_ckpt)
    # BatchNorm -> SyncBatchNorm must be converted *before* wrapping with DDP; converting
    # the already-wrapped model swaps submodules out from under DDP.
    av_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(av_model)
    av_model = torch.nn.parallel.DistributedDataParallel(
        av_model, device_ids=[hyp_param.local_rank], find_unused_parameters=False
    )
    av_model.eval()
    return av_model


def _build_test_dataloader(hyp_param):
    """Build an unshuffled DataLoader over the test split, sharded across ranks."""
    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data import DataLoader, Subset
    from torch.utils.data.distributed import DistributedSampler

    visual_augmentation = VisualAugmentation(
        hyp_param.image_mean, hyp_param.image_std,
        hyp_param.image_size, hyp_param.image_size,
        hyp_param.scale_list, ignore_index=hyp_param.ignore_index,
    )
    audio_augmentation = AudioAugmentation(mono=True)
    dataset = AV(
        split='test',
        augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
        param=hyp_param,
        root_path=hyp_param.data_root_path,
    )
    # Debug cap: keeps the first max_batches * batch_size samples *before* sharding,
    # so each rank only sees its share of them.
    max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0
    if max_batches > 0:
        n_samples = min(max_batches * hyp_param.batch_size, len(dataset))
        dataset = Subset(dataset, range(n_samples))
    # NOTE(review): DistributedSampler defaults to drop_last=False and may repeat samples
    # when len(dataset) is not divisible by the world size — confirm Trainer.valid copes.
    sampler = DistributedSampler(dataset, shuffle=False)
    return DataLoader(dataset, batch_size=hyp_param.batch_size, sampler=sampler,
                      num_workers=hyp_param.num_workers)


def _evaluate(hyp_param, av_model, test_dataloader):
    """Run ``Trainer.valid`` once per process mode; return ``[(label, f_iou, f_fscore), ...]``."""
    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore

    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(hyp_param.local_rank),
    }
    trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics)
    # Same three modes as main.py validation: default first mask / iou_select / iou_occ_select
    runs = [
        ("", "default (logits[:,0])"),
        ("iou_select", "iou_select"),
        ("iou_occ_select", "iou_occ_select"),
    ]
    results = []
    for process, label in runs:
        fiou, ffscore = trainer.valid(epoch=0, dataloader=test_dataloader,
                                      model=av_model, process=process)
        results.append((label, fiou, ffscore))
        torch.cuda.empty_cache()
    return results


def _print_results(results):
    """Print one line of metrics per process mode."""
    print("\n========== inference (same three process flags as training valid) ==========")
    for label, fiou, ffscore in results:
        print(" {:32s} f_iou={} f_f-score={}".format(label, fiou, ffscore))
    print("=======================================================\n")


def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process entry point launched by ``torch.multiprocessing.spawn``.

    Args:
        local_rank: index of this process; used both as the global rank and as the CUDA
            device index, and ``world_size`` is ``hyp_param.gpus``, so only single-node
            runs are supported.
        ngpus_per_node: number of processes on this node (not used here).
        hyp_param: EasyDict configuration; mutated in place (``local_rank``,
            ``aural_fuser``, ``contrastive_learning``).
    """
    hyp_param.local_rank = local_rank
    # Bind this process to its GPU before creating the NCCL group or the model; otherwise
    # every rank starts on cuda:0 and the model ends up on a different device than
    # DDP's device_ids=[local_rank].
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus * 1,
    )
    try:
        seed_it(local_rank + hyp_param.seed)
        _load_model_configs(hyp_param)
        av_model = _build_model(hyp_param)
        test_dataloader = _build_test_dataloader(hyp_param)
        results = _evaluate(hyp_param, av_model, test_dataloader)
        if hyp_param.local_rank <= 0:
            _print_results(results)
    finally:
        # Release the NCCL communicator even when evaluation fails.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Inference: full test set + three process modes')
    # NOTE(review): --nodes is parsed but never used; world_size comes from --gpus only.
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument("--local_rank", type=int, default=-1,
                        help='multi-process training for DDP')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('--batch_size', default=1, type=int,
                        help='Batch size (match training if needed)')
    parser.add_argument('--epochs', default=80, type=int,
                        help="unused")
    parser.add_argument('--lr', default=1e-5, type=float,
                        help="unused")
    parser.add_argument('--online', action="store_true",
                        help='unused')
    parser.add_argument(
        '--inference_ckpt', type=str, default=None,
        help='Trained AuralSAM2 checkpoint (.pth state_dict: full model or aural_fuser-only). '
             'SAM2 backbone is loaded from backbone_weight in configs (same path as training: repo_root/ckpts/sam_ckpts/). '
             'Default if unset: avs.code/training_details/.../hiera_l.pth',
    )
    parser.add_argument('--inference_max_batches', type=int, default=0,
                        help='0 = full test; >0 = first N batches only (debug)')
    args = parser.parse_args()

    # Command-line values override the static config C.
    from configs.config import C
    args = EasyDict({**C, **vars(args)})

    _repo = pathlib.Path(__file__).resolve().parent
    # Repo root: .../AuralSAM2 (parent of avs.code)
    _workspace = _repo.parent.parent
    args.data_root_path = str(_workspace / 'AVSBench')
    args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt')
    args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth')
    args.saved_dir = '/tmp/v2_infer_ckpt'
    pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True)
    if args.inference_ckpt is None:
        args.inference_ckpt = str(
            _repo.parent / 'training_details' / 'v2' / 'hiera_l' / 'hiera_l.pth'
        )
    # Fail fast in the parent instead of letting every spawned rank crash inside torch.load.
    if not pathlib.Path(args.inference_ckpt).is_file():
        parser.error('--inference_ckpt not found: {}'.format(args.inference_ckpt))

    # Respect an existing rendezvous address/port (e.g. to run two jobs on one host);
    # otherwise use the local defaults.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '9901')
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))
|