import os
from argparse import ArgumentParser, Namespace

import torch
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
# Absolute directory that contains this script; evaluation results are
# written underneath it (see the "results" directory built in run()).
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
| from datasets import standardize_dataset_name |
| from models import get_model |
| from utils import get_config, get_dataloader, setup, cleanup |
| from evaluate import evaluate |
|
|
|
|
| parser = ArgumentParser(description="Test a trained model on a dataset.") |
| |
| parser.add_argument("--weight_path", type=str, required=True, help="The name of the weight to use.") |
| parser.add_argument("--output_filename", type=str, default=None, help="The name of the result file.") |
|
|
| |
| parser.add_argument("--dataset", type=str, required=True, help="The dataset to evaluate on.") |
| parser.add_argument("--split", type=str, default="val", choices=["val", "test"], help="The split to evaluate on.") |
| parser.add_argument("--input_size", type=int, default=224, help="The size of the input image.") |
| parser.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.") |
| parser.add_argument("--max_input_size", type=int, default=4096, help="The maximum size of the input image in evaluation. Images larger than this will be processed using sliding window by force to avoid OOM.") |
| parser.add_argument("--max_num_windows", type=int, default=8, help="The maximum number of windows to be simultaneously processed.") |
| parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.") |
| parser.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.") |
| parser.add_argument("--amp", action="store_true", help="Use automatic mixed precision for evaluation.") |
| parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.") |
| parser.add_argument("--num_workers", type=int, default=8, help="The number of workers for the data loader.") |
| parser.add_argument("--local_rank", type=int, default=-1, help="The local rank for distributed training.") |
|
|
|
|
def run(local_rank: int, nprocs: int, args: Namespace) -> None:
    """Evaluate a trained model in one worker process and save the scores.

    Launched once per visible GPU through ``torch.multiprocessing.spawn`` when
    more than one GPU is available; otherwise called directly as
    ``run(0, 1, args)``.

    Args:
        local_rank: Index of this process; also used as its CUDA device index.
        nprocs: Total number of processes. When ``nprocs > 1`` the model is
            converted to SyncBatchNorm and wrapped in DistributedDataParallel.
        args: Parsed command-line arguments. ``args.output_filename`` is filled
            in (the namespace is mutated) when it was not given.

    Side effects:
        Rank 0 prints every score and writes them, one ``key: value`` per line,
        to ``<current_dir>/results/<dataset>/<split>/<output_filename>.txt``.
        ``cleanup`` is always called once ``setup`` has returned, even if
        evaluation raises.
    """
    print(f"Rank {local_rank} process among {nprocs} processes.")
    setup(local_rank, nprocs)
    ddp = nprocs > 1
    try:
        print(f"Initialized successfully. Evaluating with {nprocs} GPUs.")
        device = f"cuda:{local_rank}" if local_rank != -1 else "cuda:0"
        print(f"Using device: {device}.")

        # Return value is unused; presumably called so the resolved config is
        # printed/validated (mute=False) -- confirm against utils.get_config.
        _ = get_config(vars(args).copy(), mute=False)

        model = get_model(model_info_path=args.weight_path).to(device)
        # Read custom attributes before wrapping: DistributedDataParallel does
        # not forward them, so `model.model_name` on the wrapper would raise
        # AttributeError in multi-GPU runs.
        model_name = model.model_name
        if ddp:
            model = DDP(
                nn.SyncBatchNorm.convert_sync_batchnorm(model),
                device_ids=[local_rank],
                output_device=local_rank,
            )
        model.eval()

        if args.output_filename is None:
            # basename/splitext use the platform's path rules and keep dots
            # inside the stem (e.g. "epoch_1.5.pth" -> "epoch_1.5").
            weight_stem = os.path.splitext(os.path.basename(args.weight_path))[0]
            args.output_filename = f"{model_name}_{weight_stem}"

        dataloader = get_dataloader(args, split=args.split)
        scores = evaluate(
            model=model,
            data_loader=dataloader,
            sliding_window=args.sliding_window,
            max_input_size=args.max_input_size,
            window_size=args.input_size,
            stride=args.stride,
            max_num_windows=args.max_num_windows,
            amp=args.amp,
            local_rank=local_rank,
            nprocs=nprocs,
        )

        # Only rank 0 reports, so multi-process runs print and write once.
        if local_rank == 0:
            for k, v in scores.items():
                print(f"{k}: {v}")

            result_dir = os.path.join(current_dir, "results", args.dataset, args.split)
            os.makedirs(result_dir, exist_ok=True)
            result_path = os.path.join(result_dir, f"{args.output_filename}.txt")
            with open(result_path, "w", encoding="utf-8") as f:
                f.writelines(f"{k}: {v}\n" for k, v in scores.items())
    finally:
        cleanup(ddp)
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()
    args.dataset = standardize_dataset_name(args.dataset)

    # Only the "val" split is available for these datasets. parser.error is
    # used instead of assert so the check survives `python -O`.
    if args.dataset in ["sha", "shb", "qnrf", "nwpu"] and args.split != "val":
        parser.error(f"Split {args.split} is not available for dataset {args.dataset}.")

    # Sliding-window stride defaults to the window size (args.input_size).
    args.stride = args.stride or args.input_size
    if not os.path.exists(args.weight_path):
        parser.error(f"Weight path {args.weight_path} does not exist.")
    # NOTE(review): presumably consumed by get_dataloader(); confirm.
    args.in_memory_dataset = False

    # One worker process per visible GPU. run() always places the model on a
    # CUDA device, so fail early with a clear message when none is visible.
    args.nprocs = torch.cuda.device_count()
    if args.nprocs == 0:
        raise SystemExit("No CUDA device is available; evaluation requires at least one GPU.")
    print(f"Using {args.nprocs} GPUs.")
    if args.nprocs > 1:
        mp.spawn(run, nprocs=args.nprocs, args=(args.nprocs, args))
    else:
        run(0, 1, args)
|
|