| |
| import argparse |
| import cv2 |
| import torch |
| import os, shutil, time |
| import sys |
| from multiprocessing import Process, Queue |
| from os import path as osp |
| from tqdm import tqdm |
| import copy |
| import warnings |
| import gc |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
| root_path = os.path.abspath('.') |
| sys.path.append(root_path) |
| from degradation.ESR.utils import np2tensor |
| from degradation.ESR.degradations_functionality import * |
| from degradation.ESR.diffjpeg import * |
| from degradation.degradation_esr import degradation_v1 |
| from opt import opt |
| os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] |
|
|
|
|
|
|
| def crop_process(path, crop_size, lr_dataset_path, output_index): |
| ''' crop the image here (also do usm here) |
| Args: |
| path (str): Path of the image |
| crop_size (int): Crop size |
| lr_dataset_path (str): LR dataset path folder name |
| output_index (int): The index we used to store images |
| Returns: |
| output_index (int): The next index we need to use to store images |
| ''' |
|
|
| |
| img = cv2.imread(path) |
| height, width = img.shape[0:2] |
|
|
| res_store = [] |
| crop_num = (height//crop_size)*(width//crop_size) |
|
|
| |
| shift_offset_h, shift_offset_w = 0, 0 |
|
|
|
|
| |
| choices = [i for i in range(crop_num)] |
| shift_offset_h = 0 |
| shift_offset_w = 0 |
|
|
|
|
| for choice in choices: |
| row_num = (width//crop_size) |
| x, y = crop_size * (choice // row_num), crop_size * (choice % row_num) |
| |
| res_store.append((x, y)) |
|
|
| |
|
|
| for (h, w) in res_store: |
| cropped_img = img[h+shift_offset_h : h+crop_size+shift_offset_h, w+shift_offset_w : w+crop_size+shift_offset_w, ...] |
| cropped_img = np.ascontiguousarray(cropped_img) |
| cv2.imwrite(osp.join(lr_dataset_path, f'img_{output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) |
|
|
| output_index += 1 |
|
|
| return output_index |
|
|
|
|
|
|
| def single_process(queue, opt, process_id): |
| ''' Multi Process instance |
| Args: |
| queue (multiprocessing.Queue): The input queue |
| opt (dict): The setting we need to use |
| process_id (int): The id we used to store temporary file |
| ''' |
|
|
| |
| obj_img = degradation_v1() |
|
|
| while True: |
| items = queue.get() |
| if items == None: |
| break |
| input_path, store_path = items |
|
|
| |
| obj_img.reset_kernels(opt) |
| |
| |
| img_bgr = cv2.imread(input_path) |
|
|
| out = np2tensor(img_bgr) |
|
|
| |
| obj_img.degradate_process(out, opt, store_path, process_id, verbose = False) |
|
|
|
|
|
|
| @torch.no_grad() |
| def generate_low_res_esr(org_opt, verbose=False): |
| ''' Generate LR dataset from HR ones by ESR degradation |
| Args: |
| org_opt (dict): The setting we will use |
| verbose (bool): Whether we print out some information |
| ''' |
|
|
| |
| input_folder = org_opt['input_folder'] |
| save_folder = org_opt['save_folder'] |
| if osp.exists(save_folder): |
| shutil.rmtree(save_folder) |
| if osp.exists("tmp"): |
| shutil.rmtree("tmp") |
| os.makedirs(save_folder) |
| os.makedirs("tmp") |
| if os.path.exists("datasets/degradation_log.txt"): |
| os.remove("datasets/degradation_log.txt") |
|
|
|
|
| |
| input_img_lists, output_img_lists = [], [] |
| for file in sorted(os.listdir(input_folder)): |
| input_img_lists.append(osp.join(input_folder, file)) |
| output_img_lists.append(osp.join("tmp", file)) |
| assert(len(input_img_lists) == len(output_img_lists)) |
|
|
|
|
| |
| parallel_num = opt['parallel_num'] |
| queue = Queue() |
|
|
|
|
| |
| for idx in range(len(input_img_lists)): |
| |
| queue.put((input_img_lists[idx], output_img_lists[idx])) |
|
|
|
|
| |
| Processes = [] |
| for process_id in range(parallel_num): |
| p1 = Process(target=single_process, args =(queue, opt, process_id, )) |
| p1.start() |
| Processes.append(p1) |
| for _ in range(parallel_num): |
| queue.put(None) |
| |
|
|
| |
| for idx in tqdm(range(0, len(output_img_lists)), desc ="Degradation"): |
| while True: |
| if os.path.exists(output_img_lists[idx]): |
| break |
| time.sleep(0.1) |
|
|
| |
| for process in Processes: |
| process.join() |
|
|
|
|
|
|
| |
| output_index = 1 |
| for img_name in sorted(os.listdir("tmp")): |
| path = os.path.join("tmp", img_name) |
| output_index = crop_process(path, opt['hr_size']//opt['scale'], opt['save_folder'], output_index) |
| |
|
|
| |
| def main(args): |
| opt['input_folder'] = args.input |
| opt['save_folder'] = args.output |
|
|
| generate_low_res_esr(opt) |
|
|
|
|
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--input', type=str, default = opt["full_patch_source"], help='Input folder') |
| parser.add_argument('--output', type=str, default = opt["lr_dataset_path"], help='Output folder') |
| args = parser.parse_args() |
|
|
| main(args) |