| """
|
| SAR-DDPM Inference on real SAR images.
|
| """
|
|
|
| import argparse
|
| import torch
|
| import os
|
| import cv2
|
| import numpy as np
|
|
|
| import torch.nn.functional as F
|
|
|
| from guided_diffusion import dist_util, logger
|
| from guided_diffusion.image_datasets import load_data
|
| from guided_diffusion.resample import create_named_schedule_sampler
|
| from guided_diffusion.script_util import (
|
| sr_model_and_diffusion_defaults,
|
| sr_create_model_and_diffusion,
|
| args_to_dict,
|
| add_dict_to_argparser,
|
| )
|
| from guided_diffusion.train_util import TrainLoop
|
| from torch.utils.data import DataLoader
|
| from torch.optim import AdamW
|
|
|
| from valdata import ValData, ValDataNew, ValDataNewReal
|
| from skimage.metrics import peak_signal_noise_ratio as psnr
|
| from skimage.metrics import structural_similarity as ssim
|
|
|
|
|
|
|
| val_dir = 'path_to_validation_data/'
|
| base_path = 'path_to_save_results/'
|
| resume_checkpoint_clean = './weights/sar_ddpm.pt'
|
|
|
|
|
|
|
|
|
| def main():
|
| args = create_argparser().parse_args()
|
|
|
| print(args)
|
|
|
|
|
| model_clean, diffusion = sr_create_model_and_diffusion(
|
| **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
|
| )
|
|
|
|
|
| print(torch.device('cuda'))
|
|
|
| schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
|
|
|
|
|
| val_data = DataLoader(ValDataNewReal(dataset_path=val_dir), batch_size=1, shuffle=False, num_workers=1)
|
|
|
| device0 = torch.device("cuda:0")
|
|
|
| model_clean.load_state_dict(torch.load(resume_checkpoint_clean, map_location="cuda:0"))
|
|
|
|
|
| model_clean.to(device0)
|
|
|
|
|
|
|
|
|
| params = list(model_clean.parameters())
|
|
|
| print('model clean device:')
|
| print(next(model_clean.parameters()).device)
|
|
|
|
|
|
|
| with torch.no_grad():
|
| number = 0
|
|
|
|
|
| for batch_id1, data_var in enumerate(val_data):
|
| number = number+1
|
| clean_batch, model_kwargs1 = data_var
|
|
|
| single_img = model_kwargs1['SR'].to(dist_util.dev())
|
|
|
| count = 0
|
| [t1,t2,max_r,max_c] = single_img.size()
|
|
|
| N =9
|
|
|
| val_inputv = single_img.clone()
|
|
|
| for row in range(0,max_r,100):
|
| for col in range(0,max_c,100):
|
|
|
|
|
| val_inputv[:,:,:row,:col] = single_img[:,:,max_r-row:,max_c-col:]
|
| val_inputv[:,:,row:,col:] = single_img[:,:,:max_r-row,:max_c-col]
|
| val_inputv[:,:,row:,:col] = single_img[:,:,:max_r-row,max_c-col:]
|
| val_inputv[:,:,:row,col:] = single_img[:,:,max_r-row:,:max_c-col]
|
|
|
| model_kwargs = {}
|
| for k, v in model_kwargs1.items():
|
| if('Index' in k):
|
| img_name=v
|
| elif('SR' in k):
|
| model_kwargs[k] = val_inputv.to(dist_util.dev())
|
| else:
|
| model_kwargs[k]= v.to(dist_util.dev())
|
|
|
|
|
|
|
| sample = diffusion.p_sample_loop(
|
| model_clean,
|
| (clean_batch.shape[0], 3, 256,256),
|
| clip_denoised=True,
|
| model_kwargs=model_kwargs,
|
| )
|
|
|
|
|
|
|
| if count==0:
|
| sample_new = (1.0/N)*sample
|
| else :
|
| sample_new[:,:,max_r-row:,max_c-col:] = sample_new[:,:,max_r-row:,max_c-col:] + (1.0/N)*sample[:,:,:row,:col]
|
| sample_new[:,:,:max_r-row,:max_c-col] = sample_new[:,:,:max_r-row,:max_c-col] + (1.0/N)*sample[:,:,row:,col:]
|
| sample_new[:,:,:max_r-row,max_c-col:] = sample_new[:,:,:max_r-row,max_c-col:] + (1.0/N)*sample[:,:,row:,:col]
|
| sample_new[:,:,max_r-row:,:max_c-col] = sample_new[:,:,max_r-row:,:max_c-col] + (1.0/N)*sample[:,:,:row,col:]
|
|
|
| count += 1
|
|
|
| sample_new = ((sample_new + 1) * 127.5)
|
| sample_new = sample_new.clamp(0, 255).to(torch.uint8)
|
| sample_new = sample_new.permute(0, 2, 3, 1)
|
| sample_new = sample_new.contiguous().cpu().numpy()
|
| sample_new = sample_new[0][:,:,::-1]
|
|
|
| sample_new = cv2.cvtColor(sample_new, cv2.COLOR_BGR2GRAY)
|
| print(img_name[0])
|
| cv2.imwrite(base_path+'pred_'+img_name[0],sample_new)
|
|
|
|
|
|
|
|
|
|
|
|
|
| def create_argparser():
|
| defaults = dict(
|
| data_dir= val_dir,
|
| schedule_sampler="uniform",
|
| lr=1e-4,
|
| weight_decay=0.0,
|
| lr_anneal_steps=0,
|
| batch_size=2,
|
| microbatch=1,
|
| ema_rate="0.9999",
|
| log_interval=100,
|
| save_interval=200,
|
| use_fp16=False,
|
| fp16_scale_growth=1e-3,
|
| )
|
| defaults.update(sr_model_and_diffusion_defaults())
|
| parser = argparse.ArgumentParser()
|
| add_dict_to_argparser(parser, defaults)
|
| return parser
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|