| from argparse import ArgumentParser |
| import os |
| import json |
| import sys |
| from tqdm import tqdm |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| import torchvision.transforms as transforms |
|
|
| sys.path.append(".") |
| sys.path.append("..") |
|
|
| from criteria.lpips.lpips import LPIPS |
| from datasets.gt_res_dataset import GTResDataset |
|
|
|
|
def parse_args():
    """Parse the command-line options of the metric-evaluation script.

    Note that the built-in ``-h/--help`` flag is disabled (``add_help=False``).

    Returns:
        argparse.Namespace with attributes ``mode`` ('lpips' or 'l2'),
        ``data_path``, ``gt_path``, ``workers``, ``batch_size`` and
        ``is_cars``.
    """
    cli = ArgumentParser(add_help=False)
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--mode', dict(type=str, default='lpips', choices=['lpips', 'l2'])),
        ('--data_path', dict(type=str, default='results')),
        ('--gt_path', dict(type=str, default='gt_images')),
        ('--workers', dict(type=int, default=4)),
        ('--batch_size', dict(type=int, default=4)),
        ('--is_cars', dict(action='store_true')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
def _build_loss(mode):
    """Create the per-pair loss module selected by ``mode`` and move it to CUDA.

    Args:
        mode: ``'lpips'`` for LPIPS with an AlexNet backbone (``net_type='alex'``),
            ``'l2'`` for ``torch.nn.MSELoss``.

    Returns:
        The loss module, after ``.cuda()`` has been called on it.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode == 'lpips':
        loss_func = LPIPS(net_type='alex')
    elif mode == 'l2':
        loss_func = torch.nn.MSELoss()
    else:
        raise ValueError('Not a valid mode!')
    loss_func.cuda()
    return loss_func


def run(args):
    """Score result images against their ground truth and save the statistics.

    Images under ``args.data_path`` are paired with ground-truth images from
    ``args.gt_path`` by ``GTResDataset``. Each pair is scored with the loss
    selected by ``args.mode``. The mean and standard deviation of all scores
    are printed and written to
    ``<parent of data_path>/inference_metrics/stat_<mode>.txt``. The per-image
    scores are written to ``scores_<mode>.json`` in the same directory, keyed
    by the basename of the first path in each ``dataset.pairs`` entry
    (presumably the result image; confirm against ``GTResDataset``).

    Args:
        args: Namespace as produced by ``parse_args`` (``mode``, ``data_path``,
            ``gt_path``, ``workers``, ``batch_size``, ``is_cars``).

    Raises:
        ValueError: if ``args.mode`` is neither ``'lpips'`` nor ``'l2'``.
    """
    # transforms.Resize takes (height, width); --is_cars evaluates at 192x256.
    resize_dims = (192, 256) if args.is_cars else (256, 256)
    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    transform = transforms.Compose([transforms.Resize(resize_dims),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    print('Loading dataset')
    dataset = GTResDataset(root_path=args.data_path,
                           gt_dir=args.gt_path,
                           transform=transform)

    # shuffle=False keeps loader order aligned with dataset.pairs, which is
    # indexed below to recover each image's file name. drop_last must be
    # False: dropping the final short batch would silently skip images.
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=int(args.workers),
                            drop_last=False)

    loss_func = _build_loss(args.mode)

    global_i = 0
    scores_dict = {}
    all_scores = []
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for result_batch, gt_batch in tqdm(dataloader):
            # Move each batch to the GPU once instead of once per image.
            result_batch = result_batch.cuda()
            gt_batch = gt_batch.cuda()
            # Use the real batch length; the last batch may be shorter than
            # args.batch_size.
            for i in range(result_batch.size(0)):
                # One pair per call so every image gets its own scalar score.
                loss = float(loss_func(result_batch[i:i + 1], gt_batch[i:i + 1]))
                all_scores.append(loss)
                im_path = dataset.pairs[global_i][0]
                scores_dict[os.path.basename(im_path)] = loss
                global_i += 1

    # Statistics cover every scored pair. Two paths sharing a basename would
    # collapse to a single entry in scores_dict.
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
    print('Finished with ', args.data_path)
    print(result_str)

    # Results go to the parent directory of data_path. normpath strips a
    # trailing separator, so 'out/results/' and 'out/results' both map to
    # 'out/inference_metrics'.
    out_path = os.path.join(os.path.dirname(os.path.normpath(args.data_path)),
                            'inference_metrics')
    os.makedirs(out_path, exist_ok=True)

    with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
        json.dump(scores_dict, f)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read the CLI options, then score and save metrics.
    # `args` stays bound at module level, as before.
    args = parse_args()
    run(args=args)
|
|