| import os.path
|
| import logging
|
| import torch
|
| import argparse
|
| import json
|
| import glob
|
|
|
| from pprint import pprint
|
| from fvcore.nn import FlopCountAnalysis
|
| from utils.model_summary import get_model_activation, get_model_flops
|
| from utils import utils_logger
|
| from utils import utils_image as util
|
|
|
|
|
def select_model(args, device):
    """Instantiate the model selected by ``args.model_id`` and prepare it for inference.

    Args:
        args: parsed CLI namespace; only ``args.model_id`` is read here.
        device: ``torch.device`` the model is moved to (and checkpoints are
            mapped onto while loading).

    Returns:
        Tuple ``(model, name, data_range, tile)`` — ``name`` is the run
        identifier used for logging/saving, ``data_range`` the pixel scale
        the model expects (1.0), ``tile`` the tile size for tiled inference
        (``None`` means whole-image forward).

    Raises:
        NotImplementedError: for an unknown ``model_id``.
    """
    model_id = args.model_id
    if model_id == 0:
        # Baseline EFDN provided by the challenge organizers.
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        # map_location keeps loading working on CPU-only machines when the
        # checkpoint was saved from CUDA tensors.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF
        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        state_dict = torch.load(model_path, map_location=device)
        # NOTE(review): strict=False silently ignores missing/unexpected keys —
        # confirm this is intentional for the DSCF checkpoint.
        model.load_state_dict(state_dict, strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    model.eval()
    tile = None
    # Freeze all parameters: this script only benchmarks inference.
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    return model, name, data_range, tile
|
|
|
|
|
def select_dataset(data_dir, mode):
    """Return ``(lr_path, hr_path)`` pairs for the requested split.

    HR images are globbed from ``DIV2K_LSDIR_{mode}_HR`` under ``data_dir``;
    each LR counterpart is derived by swapping ``_HR`` -> ``_LR`` in the path
    and appending an ``x4`` suffix before the extension.

    Raises:
        NotImplementedError: when ``mode`` is neither ``"test"`` nor ``"valid"``.
    """
    hr_globs = {
        "test": "DIV2K_LSDIR_test_HR/*.png",
        "valid": "DIV2K_LSDIR_valid_HR/*.png",
    }
    if mode not in hr_globs:
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")

    hr_paths = sorted(glob.glob(os.path.join(data_dir, hr_globs[mode])))
    return [
        (hr.replace("_HR", "_LR").replace(".png", "x4.png"), hr)
        for hr in hr_paths
    ]
|
|
|
|
|
def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    """Super-resolve ``img_lq`` with ``model``, optionally tile-by-tile.

    With ``tile=None`` the whole image goes through the model in one shot.
    Otherwise overlapping ``tile``x``tile`` patches are processed and the
    results are blended by averaging in the overlap regions.
    """
    if tile is None:
        return model(img_lq)

    b, c, h, w = img_lq.size()
    tile = min(tile, h, w)  # never exceed the image itself
    sf = scale
    stride = tile - tile_overlap

    # Patch origins: a regular grid plus one final patch flush with the border.
    h_starts = list(range(0, h - tile, stride)) + [h - tile]
    w_starts = list(range(0, w - tile, stride)) + [w - tile]

    # Accumulators for the stitched output and the per-pixel blend weights.
    acc = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
    weight = torch.zeros_like(acc)

    for y in h_starts:
        for x in w_starts:
            patch = img_lq[..., y:y + tile, x:x + tile]
            sr_patch = model(patch)
            acc[..., y * sf:(y + tile) * sf, x * sf:(x + tile) * sf].add_(sr_patch)
            weight[..., y * sf:(y + tile) * sf, x * sf:(x + tile) * sf].add_(torch.ones_like(sr_patch))

    # Average where patches overlapped.
    return acc.div_(weight)
|
|
|
def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    """Benchmark ``model`` on one dataset split and return per-image + aggregate metrics.

    Times each image with CUDA events, computes PSNR (and SSIM when
    ``args.ssim``) against the HR ground truth, and logs per-image and
    average results.

    NOTE(review): CUDA events and ``torch.cuda.*`` calls are used
    unconditionally here — this path appears to require a CUDA device;
    confirm CPU-only execution is out of scope.

    Returns:
        dict with ``{mode}_runtime`` / ``{mode}_psnr`` (per-image lists),
        ``{mode}_memory`` (peak MiB), ``{mode}_ave_runtime``,
        ``{mode}_ave_psnr`` and optionally ``{mode}_ssim`` / ``{mode}_ave_ssim``.
    """
    sf = 4          # super-resolution scale factor
    border = sf     # pixels cropped from the border for PSNR/SSIM
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []

    data_path = select_dataset(args.data_dir, mode)
    # NOTE(review): save_path is created but no SR image is ever written below —
    # saving appears to be disabled; confirm this is intentional.
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    # CUDA events measure GPU-side elapsed time in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, (img_lr, img_hr) in enumerate(data_path):

        # (1) Load the LR image and move it to the device.
        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        # (2) Timed inference (synchronize before reading the elapsed time).
        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))  # milliseconds
        img_sr = util.tensor2uint(img_sr, data_range)

        # (3) Load the HR ground truth, cropped to a multiple of the scale.
        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        img_hr = util.modcrop(img_hr, sf)

        # (4) Per-image metrics.
        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

    # Aggregates: peak GPU memory (MiB) and simple means over the split.
    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"])
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])

    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results
|
|
|
|
|
def main(args):
    """Benchmark the selected model and persist the results.

    Runs inference on the validation split (and the test split when
    ``--include_test``), measures model complexity, appends everything to
    ``results.json`` in the current working directory, and rewrites a
    human-readable ``results.txt`` summary table covering every model seen
    so far.
    """
    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # --------------------------------
    # basic device settings
    # --------------------------------
    if torch.cuda.is_available():
        # Guard the CUDA calls so the script can at least start on
        # CPU-only hosts (matching the device fallback below).
        torch.cuda.current_device()
        torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --------------------------------
    # load (or create) the cumulative results store
    # --------------------------------
    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    # --------------------------------
    # load model
    # --------------------------------
    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # --------------------------------
    # inference on the validation (and optional test) split
    # --------------------------------
    valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
    results[model_name] = valid_results

    if args.include_test:
        test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
        results[model_name].update(test_results)

    # --------------------------------
    # model complexity statistics on a 3x256x256 input
    # --------------------------------
    input_dim = (3, 256, 256)
    activations, num_conv = get_model_activation(model, input_dim)
    activations = activations / 10 ** 6  # -> millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
    logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

    input_fake = torch.rand(1, 3, 256, 256).to(device)
    flops = FlopCountAnalysis(model, input_fake).total()
    flops = flops / 10 ** 9  # -> GFLOPs
    logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

    num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    num_parameters = num_parameters / 10 ** 6  # -> millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
    results[model_name].update({"activations": activations, "num_conv": num_conv, "flops": flops, "num_parameters": num_parameters})

    with open(json_dir, "w") as f:
        json.dump(results, f)

    # --------------------------------
    # tab-separated summary table across all recorded models
    # --------------------------------
    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"

        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)
    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, echo them, run the benchmark.
    cli = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    cli.add_argument("--data_dir", default="../", type=str)
    cli.add_argument("--save_dir", default="../results", type=str)
    cli.add_argument("--model_id", default=0, type=int)
    cli.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    cli.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    parsed = cli.parse_args()
    pprint(parsed)

    main(parsed)
|
|
|