TTrain404 commited on
Commit
39aef76
·
verified ·
1 Parent(s): 89c24a6

Upload 24 files

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 malshaV
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,36 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAR-DDPM
2
+
3
+ Code for the paper [SAR despeckling using a Denoising Diffusion Probabilistic Model](https://arxiv.org/pdf/2206.04514.pdf), accepted at IEEE Geoscience and Remote Sensing Letters
4
+
5
+
6
+ ## To train the SAR-DDPM model:
7
+
8
+ - Download the weights of the 64x64 -> 256x256 upsampler from [here](https://github.com/openai/guided-diffusion).
9
+
10
+ - Create a folder ./weights and place the downloaded weights in the folder.
11
+
12
+ - Specify the paths to your training data and validation data in ./scripts/sarddpm_train.py (line 23 and line 25)
13
+
14
+ - Use the following command to run the code (change the GPU number according to GPU availability):
15
+
16
+ ```bash
17
+ MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 64 --learn_sigma True --noise_schedule linear --num_channels 192 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
18
+ export PYTHONPATH=$PYTHONPATH:$(pwd)
19
+ CUDA_VISIBLE_DEVICES=0 python scripts/sarddpm_train.py $MODEL_FLAGS
20
+ ```
21
+
22
+
23
+ ### Acknowledgement:
24
+
25
+ This code is based on DDPM implementation in [guided-diffusion](https://github.com/openai/guided-diffusion)
26
+
27
+
28
+ ### Citation:
29
+
30
+ ```
31
+ @ARTICLE{perera2022sar,
32
+ author={Perera, Malsha V. and Nair, Nithin Gopalakrishnan and Bandara, Wele Gedara Chaminda and Patel, Vishal M.},
33
+ journal={IEEE Geoscience and Remote Sensing Letters},
34
+ title={SAR Despeckling using a Denoising Diffusion Probabilistic Model},
35
+ year={2023}}
36
+ ```
core/logger.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import logging
4
+ from collections import OrderedDict
5
+ import json
6
+ from datetime import datetime
7
+
8
+
9
def mkdirs(paths):
    """Create a directory (or each directory in an iterable); existing dirs are fine."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        os.makedirs(target, exist_ok=True)
15
+
16
+
17
def get_timestamp():
    """Return the current local time formatted as 'YYMMDD_HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d_%H%M%S')
19
+
20
+
21
def parse(args):
    """Build the experiment options dict from a commented-JSON config plus CLI args.

    :param args: argparse-style namespace with .phase, .config, .gpu_ids,
                 .enable_wandb, .debug and (optionally) the W&B flags.
    :return: an OrderedDict of options with paths, phase, gpu/distributed
             settings and W&B flags resolved.
    """
    phase = args.phase
    opt_path = args.config
    gpu_ids = args.gpu_ids
    enable_wandb = args.enable_wandb

    # Strip '//' line comments so the config file can be annotated JSON.
    json_str = ''
    with open(opt_path, 'r') as f:
        for line in f:
            line = line.split('//')[0] + '\n'
            json_str += line
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    # Set up a timestamped experiment directory tree.
    if args.debug:
        opt['name'] = 'debug_{}'.format(opt['name'])
    experiments_root = os.path.join(
        'experiments', '{}_{}'.format(opt['name'], get_timestamp()))
    opt['path']['experiments_root'] = experiments_root
    for key, path in opt['path'].items():
        # Resume paths stay absolute; the experiments_root entry is left as-is.
        if 'resume' not in key and 'experiments' not in key:
            opt['path'][key] = os.path.join(experiments_root, path)
            mkdirs(opt['path'][key])

    opt['phase'] = phase

    # Export CUDA_VISIBLE_DEVICES for the whole process.
    if gpu_ids is not None:
        opt['gpu_ids'] = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
        gpu_list = gpu_ids
    else:
        gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    # NOTE(review): len() counts characters of the string, so any multi-GPU
    # spec (e.g. "0,1", len 3) enables distributed mode — confirm intended.
    opt['distributed'] = len(gpu_list) > 1

    # Shrink frequencies/steps for quick debug runs.
    if 'debug' in opt['name']:
        opt['train']['val_freq'] = 2
        opt['train']['print_freq'] = 2
        opt['train']['save_checkpoint_freq'] = 3
        opt['datasets']['train']['batch_size'] = 2
        opt['model']['beta_schedule']['train']['n_timestep'] = 10
        opt['model']['beta_schedule']['val']['n_timestep'] = 10
        opt['datasets']['train']['data_len'] = 6
        opt['datasets']['val']['data_len'] = 3

    # Limit validation length while training.
    if phase == 'train':
        opt['datasets']['val']['data_len'] = 3

    # W&B flags are optional on the CLI. Fix: the original used bare
    # `except: pass`, which silenced *every* error; only tolerate the
    # attribute being absent.
    if hasattr(args, 'log_wandb_ckpt'):
        opt['log_wandb_ckpt'] = args.log_wandb_ckpt
    if hasattr(args, 'log_eval'):
        opt['log_eval'] = args.log_eval
    if hasattr(args, 'log_infer'):
        opt['log_infer'] = args.log_infer
    opt['enable_wandb'] = enable_wandb

    return opt
95
+
96
+
97
class NoneDict(dict):
    """A dict that returns None (instead of raising KeyError) for absent keys."""

    def __missing__(self, key):
        # dict.__getitem__ calls this only when `key` is not present.
        return None
100
+
101
+
102
+ # convert to NoneDict, which return None for missing key.
103
def dict_to_nonedict(opt):
    """Recursively convert dicts (including dicts inside lists) to NoneDict."""
    if isinstance(opt, dict):
        converted = {key: dict_to_nonedict(value) for key, value in opt.items()}
        return NoneDict(**converted)
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    return opt
113
+
114
+
115
def dict2str(opt, indent_l=1):
    """Render a (possibly nested) options dict as an indented multi-line string."""
    pad = ' ' * (indent_l * 2)
    lines = []
    for key, value in opt.items():
        if isinstance(value, dict):
            # Nested dicts are bracketed and indented one level deeper.
            lines.append(pad + key + ':[\n')
            lines.append(dict2str(value, indent_l + 1))
            lines.append(pad + ']\n')
        else:
            lines.append(pad + key + ': ' + str(value) + '\n')
    return ''.join(lines)
126
+
127
+
128
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
    """Configure a named logger that writes to `<root>/<phase>.log`.

    :param logger_name: name passed to logging.getLogger.
    :param root: directory that will contain the log file.
    :param phase: basename of the log file ('train', 'val', ...).
    :param level: logging level for the logger.
    :param screen: if True, also echo records to the console.
    """
    log = logging.getLogger(logger_name)
    if log.handlers:
        # Fix: repeated setup calls used to stack handlers, so every record
        # was emitted once per call. Treat an already-configured logger as done.
        return
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
        datefmt='%y-%m-%d %H:%M:%S')
    log_file = os.path.join(root, '{}.log'.format(phase))
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setFormatter(formatter)
    log.setLevel(level)
    log.addHandler(file_handler)
    if screen:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        log.addHandler(stream_handler)
core/metrics.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import cv2
5
+ from torchvision.utils import make_grid
6
+
7
+
8
def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
    """Convert a torch Tensor to a numpy image.

    Input: 4D (B,C,H,W), 3D (C,H,W) or 2D (H,W) tensor, RGB channel order;
    values are clamped to `min_max` and rescaled to [0, 1].
    Output: HWC (or HW) ndarray, [0, 255] np.uint8 by default.
    """
    low, high = min_max
    tensor = tensor.squeeze().float().cpu().clamp_(low, high)
    tensor = (tensor - low) / (high - low)  # to range [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # Tile the batch into a roughly square grid before converting.
        grid = make_grid(tensor, nrow=int(math.sqrt(len(tensor))), normalize=False)
        img_np = np.transpose(grid.numpy(), (1, 2, 0))  # CHW -> HWC
    elif n_dim == 3:
        img_np = np.transpose(tensor.numpy(), (1, 2, 0))  # CHW -> HWC
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Unlike MATLAB, numpy's uint8 cast truncates, so round explicitly.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
35
+
36
+
37
def save_img(img, img_path, mode='RGB'):
    """Write an RGB image array to disk (cv2 writes BGR, so swap channels first)."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_path, bgr)
40
+
41
+
42
def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio between two [0, 255] images (inf if identical)."""
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
50
+
51
+
52
def ssim(img1, img2):
    """Single-channel SSIM over an 11x11 Gaussian window (MATLAB-compatible)."""
    C1 = (0.01 * 255) ** 2  # stabilizing constants from the SSIM formula
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # Crop 5-pixel borders so only 'valid' filter responses are compared.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
73
+
74
+
75
def calculate_ssim(img1, img2):
    """Calculate SSIM, matching MATLAB's implementation.

    :param img1, img2: (H, W) or (H, W, 1/3) arrays with values in [0, 255].
    :return: mean SSIM (averaged over channels for 3-channel input).
    :raises ValueError: on shape mismatch or unsupported dimensionality.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            # Fix: the loop previously passed the *full* 3-channel images on
            # every iteration (the index `i` was unused), so the result was a
            # single multi-channel value repeated 3 times rather than the
            # mean of per-channel SSIMs.
            ssims = [ssim(img1[..., i], img2[..., i]) for i in range(3)]
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
core/wandb_logger.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
class WandbLogger:
    """
    Log using `Weights and Biases`.

    Thin wrapper around the `wandb` module: reuses the active run if one
    exists, otherwise starts one under project 'diff_derain' with files
    stored in ./experiments.
    """
    def __init__(self):
        # Import lazily so the rest of the project works without wandb installed.
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb."
                "Run `pip install wandb` to install it."
            )

        self._wandb = wandb

        # Initialize a W&B run
        if self._wandb.run is None:
            self._wandb.init(
                project='diff_derain',
                dir='./experiments'
            )

        self.config = self._wandb.config

        # if self.config.get('log_eval', None):
        #     self.eval_table = self._wandb.Table(columns=['fake_image',
        #                                                  'sr_image',
        #                                                  'hr_image',
        #                                                  'psnr',
        #                                                  'ssim'])
        # else:
        # NOTE(review): table creation above is commented out, so both tables
        # stay None and log_eval_data() would raise AttributeError if called —
        # confirm the tables are intentionally disabled.
        self.eval_table = None

        # if self.config.get('log_infer', None):
        #     self.infer_table = self._wandb.Table(columns=['fake_image',
        #                                                   'sr_image',
        #                                                   'hr_image'])
        # else:
        self.infer_table = None

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.

        metrics: dictionary of metrics to be logged
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log image array onto W&B.

        key_name: name of the key
        image_array: numpy array of image.
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log list of image array onto W&B

        key_name: name of the key
        list_images: list of numpy image arrays
        """
        self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Log the model checkpoint as W&B artifacts

        current_epoch: the current epoch
        current_step: the current batch step
        """
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        # Paths must match the naming scheme used when checkpoints are saved:
        # I<step>_E<epoch>_{gen,opt}.pth under config.path['checkpoint'].
        gen_path = os.path.join(
            self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch))
        opt_path = os.path.join(
            self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch))

        model_artifact.add_file(gen_path)
        model_artifact.add_file(opt_path)
        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Add data row-wise to the initialized table.

        With psnr and ssim the row goes to eval_table, otherwise to
        infer_table (both require the tables to have been created in __init__).
        """
        if psnr is not None and ssim is not None:
            self.eval_table.add_data(
                self._wandb.Image(fake_img),
                self._wandb.Image(sr_img),
                self._wandb.Image(hr_img),
                psnr,
                ssim
            )
        else:
            self.infer_table.add_data(
                self._wandb.Image(fake_img),
                self._wandb.Image(sr_img),
                self._wandb.Image(hr_img)
            )

    def log_eval_table(self, commit=False):
        """
        Log the table

        Logs eval_table when present, else infer_table; no-op when both are None.
        """
        if self.eval_table:
            self._wandb.log({'eval_data': self.eval_table}, commit=commit)
        elif self.infer_table:
            self._wandb.log({'infer_data': self.infer_table}, commit=commit)
guided_diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Codebase for "Improved Denoising Diffusion Probabilistic Models".
3
+ """
guided_diffusion/dist_util.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for distributed training.
3
+ """
4
+
5
+ import io
6
+ import os
7
+ import socket
8
+
9
+ import blobfile as bf
10
+ from mpi4py import MPI
11
+ import torch as th
12
+ import torch.distributed as dist
13
+
14
+ # Change this to reflect your cluster layout.
15
+ # The GPU for a given rank is (rank % GPUS_PER_NODE).
16
+ GPUS_PER_NODE = 8
17
+
18
+ SETUP_RETRY_COUNT = 3
19
+
20
+
21
def setup_dist():
    """
    Setup a distributed process group.

    Uses MPI to agree on a rendezvous address/port across ranks, then
    initializes torch.distributed with NCCL (GPU) or gloo (CPU).
    Idempotent: returns immediately if a group already exists.
    """
    if dist.is_initialized():
        return
    # Pin this rank to a single GPU on its node: rank % GPUS_PER_NODE.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
    # os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    # print(os.environ["CUDA_VISIBLE_DEVICES"])

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    # Rank 0's hostname and free port become the rendezvous point for all ranks.
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")
45
+
46
+
47
def dev():
    """
    Get the device to use for torch.distributed: CUDA when available, else CPU.
    """
    device_name = "cuda" if th.cuda.is_available() else "cpu"
    return th.device(device_name)
54
+
55
+
56
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Rank 0 reads the file and broadcasts its raw bytes in <= 1 GiB chunks;
    every other rank reassembles the same buffer, and all ranks deserialize
    it locally.

    :param path: blobfile-compatible path to the checkpoint.
    :param kwargs: forwarded to th.load (e.g. map_location).
    :return: the deserialized object.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        # First broadcast the chunk count, then each chunk in order; the
        # receive side below mirrors this sequence exactly.
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)
77
+
78
+
79
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    with th.no_grad():
        for param in params:
            dist.broadcast(param, 0)
86
+
87
+
88
+ def _find_free_port():
89
+ try:
90
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
91
+ s.bind(("", 0))
92
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
93
+ return s.getsockname()[1]
94
+ finally:
95
+ s.close()
guided_diffusion/fp16_util.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ from . import logger
11
+
12
+ INITIAL_LOG_LOSS_SCALE = 20.0
13
+
14
+
15
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Only conv layers are converted; any other module is left untouched.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.half()
    if l.bias is not None:
        l.bias.data = l.bias.data.half()
23
+
24
+
25
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.float()
    if l.bias is not None:
        l.bias.data = l.bias.data.float()
33
+
34
+
35
def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters: one flat float32 master tensor per group, viewed as `shape`.
    """
    master_params = []
    for group, shape in param_groups_and_shapes:
        detached = [param.detach().float() for (_, param) in group]
        flat = _flatten_dense_tensors(detached).view(shape)
        master = nn.Parameter(flat)
        master.requires_grad = True
        master_params.append(master)
    return master_params
50
+
51
+
52
def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params(), flattening each group into one grad tensor.
    """
    pairs = zip(master_params, param_groups_and_shapes)
    for master_param, (param_group, shape) in pairs:
        grads = [param_grad_or_zeros(param) for (_, param) in param_group]
        master_param.grad = _flatten_dense_tensors(grads).view(shape)
63
+
64
+
65
def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        unflat = unflatten_master_params(param_group, master_param.view(-1))
        for (_, param), unflat_master_param in zip(param_group, unflat):
            param.detach().copy_(unflat_master_param)
76
+
77
+
78
def unflatten_master_params(param_group, master_param):
    """Split a flat master tensor back into pieces shaped like each group param."""
    templates = [param for (_, param) in param_group]
    return _unflatten_dense_tensors(master_param, templates)
80
+
81
+
82
def get_param_groups_and_shapes(named_model_params):
    """Partition named params into (scalars/vectors, matrices) groups.

    Scalars and vectors flatten into a 1-D master tensor (shape -1);
    higher-rank params flatten into shape (1, -1).
    """
    named_model_params = list(named_model_params)
    low_rank = [(n, p) for (n, p) in named_model_params if p.ndim <= 1]
    high_rank = [(n, p) for (n, p) in named_model_params if p.ndim > 1]
    return [(low_rank, (-1)), (high_rank, (1, -1))]
93
+
94
+
95
+ # def master_params_to_state_dict(
96
+ # model, param_groups_and_shapes, master_params, use_fp16
97
+ # ):
98
+ # if use_fp16:
99
+ # state_dict = model.state_dict()
100
+ # for master_param, (param_group, _) in zip(
101
+ # master_params, param_groups_and_shapes
102
+ # ):
103
+ # for (name, _), unflat_master_param in zip(
104
+ # param_group, unflatten_master_params(param_group, master_param.view(-1))
105
+ # ):
106
+ # assert name in state_dict
107
+ # state_dict[name] = unflat_master_param
108
+ # else:
109
+ # state_dict = model.state_dict()
110
+ # for i, (name, _value) in enumerate(model.named_parameters()):
111
+ # assert name in state_dict
112
+ # state_dict[name] = master_params[i]
113
+ # return state_dict
114
+
115
def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    """Build a state dict whose parameter values come from the master params.

    NOTE(review): names missing from the state dict are silently skipped here
    (the commented-out original asserted instead) — confirm that is intended.
    """
    state_dict = model.state_dict()
    if use_fp16:
        for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
            unflat = unflatten_master_params(param_group, master_param.view(-1))
            for (name, _), unflat_master_param in zip(param_group, unflat):
                if name in state_dict:
                    state_dict[name] = unflat_master_param
    else:
        for i, (name, _value) in enumerate(model.named_parameters()):
            if name in state_dict:
                state_dict[name] = master_params[i]
    return state_dict
134
+
135
def state_dict_to_master_params(model, state_dict, use_fp16):
    """Invert master_params_to_state_dict: rebuild master params from a state dict."""
    named = [(name, state_dict[name]) for name, _ in model.named_parameters()]
    if use_fp16:
        groups = get_param_groups_and_shapes(named)
        return make_master_params(groups)
    return [tensor for _, tensor in named]
145
+
146
+
147
def zero_master_grads(master_params):
    """Drop (rather than zero in place) the gradient of every master param."""
    for master_param in master_params:
        master_param.grad = None
150
+
151
+
152
def zero_grad(model_params):
    """Zero existing gradients in place, detaching first.

    Taken from
    https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
    """
    for param in model_params:
        grad = param.grad
        if grad is not None:
            grad.detach_()
            grad.zero_()
158
+
159
+
160
def param_grad_or_zeros(param):
    """Return the param's grad data, or a matching zero tensor when grad is absent."""
    grad = param.grad
    return grad.data.detach() if grad is not None else th.zeros_like(param)
165
+
166
+
167
class MixedPrecisionTrainer:
    """
    Mixed-precision training helper with dynamic loss scaling.

    With use_fp16=False the master params are the model params and this is a
    thin pass-through. With use_fp16=True the model weights are converted to
    fp16 and flat float32 "master" copies are kept per group; the loss is
    scaled by 2**lg_loss_scale before backward, the scale shrinks by 1 on
    overflow and grows by fp16_scale_growth after each successful step.
    """
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        # model_params: the live (possibly fp16) parameters of the model.
        self.model_params = list(self.model.parameters())
        # master_params: float32 copies used by the optimizer (== model_params
        # when not using fp16).
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        # Zeroes the *model* grads; master grads are dropped in the fp16 path.
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            # Scale the loss so fp16 gradients stay in representable range.
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        # Returns True when a step was taken, False when it was skipped.
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            # Overflow: halve the effective loss scale and skip this step.
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        # NOTE(review): only the first master group's grads are unscaled here;
        # the second (matrix) group would reach opt.step() still multiplied by
        # 2**lg_loss_scale — confirm against upstream guided-diffusion, which
        # unscales every master param.
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        # Global L2 norms over all master params; grad norm is divided by
        # grad_scale so it reflects the unscaled gradients.
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
253
+
254
def check_overflow(value):
    """True if value is +inf, -inf, or NaN (NaN compares unequal to itself)."""
    inf = float("inf")
    return value == inf or value == -inf or value != value
guided_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ import enum
9
+ import math
10
+
11
+ import numpy as np
12
+ import torch as th
13
+
14
+ from .nn import mean_flat
15
+ from .losses import normal_kl, discretized_gaussian_log_likelihood
16
+ import cv2
17
+
18
+
19
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps. Beta schedules may be added, but
    should not be removed or changed once they are committed to maintain
    backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, rescaled so that any step count
        # keeps the endpoints of the original 1000-step schedule.
        scale = 1000 / num_diffusion_timesteps
        return np.linspace(
            scale * 0.0001, scale * 0.02, num_diffusion_timesteps, dtype=np.float64
        )
    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
45
+
46
+
47
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    steps = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta.
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
65
+
66
+
67
class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.

    One of three parameterizations of the denoising network's output.
    Member order must not change: auto() values are positional.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon
75
+
76
+
77
class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    Member order must not change: auto() values are positional.
    """

    LEARNED = enum.auto()  # variance is produced directly by the model
    FIXED_SMALL = enum.auto()  # fixed variance (the smaller of the two choices)
    FIXED_LARGE = enum.auto()  # fixed variance (the larger of the two choices)
    LEARNED_RANGE = enum.auto()  # model output interpolates FIXED_SMALL..FIXED_LARGE
89
+
90
+
91
class LossType(enum.Enum):
    """Training objectives available for the diffusion model."""

    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        """True for the purely variational-bound objectives."""
        return self in (LossType.KL, LossType.RESCALED_KL)
101
+
102
+
103
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported from https://github.com/hojonathanho/diffusion and adapted for
    SAR-DDPM: the denoising network is conditioned on the speckled image by
    channel concatenation (``model_kwargs["SR"]``) — see ``p_mean_variance``
    and ``training_losses``.

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # Use float64 for accuracy when pre-computing the schedule.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, x_start=None, model_kwargs=None, device=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        SAR-DDPM change vs. upstream guided-diffusion: the network input is
        the channel concatenation of the current sample ``x`` and the
        conditioning image ``x_start`` (the speckled SAR image).

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param x_start: the conditioning image concatenated onto ``x``
            channel-wise before the model call. Must not be None here.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: unused; kept for interface compatibility with callers
            that pass it.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)

        # Condition the network on the speckled image by concatenation.
        model_inp = th.cat([x, x_start], 1)
        model_output = model(model_inp, self._scale_timesteps(t), **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            # Optional user hook first, then clipping into [-1, 1].
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover x_0 from a predicted epsilon at timestep t."""
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """Recover x_0 from a predicted x_{t-1} at timestep t."""
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the implied epsilon from a predicted x_0 at timestep t."""
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        """Optionally rescale integer timesteps to the [0, 1000) float range."""
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: a dict of extra keyword arguments for the model.
            Must contain the conditioning image under key "SR" (the speckled
            SAR image used for concatenation conditioning).
        :param device: unused; kept for interface compatibility.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        # The speckled image conditions every reverse step (KeyError if the
        # caller forgot to supply model_kwargs["SR"]).
        x_disto_start = model_kwargs["SR"]

        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            x_start=x_disto_start,
            model_kwargs=model_kwargs,
            device=device,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    device=device,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample(); model_kwargs must contain the conditioning
        image under key "SR".
        """
        x_disto_start = model_kwargs["SR"]

        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            x_start=x_disto_start,
            model_kwargs=model_kwargs,
            device=device,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12 of the DDIM paper (Song et al., 2020).
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.

        NOTE(review): this calls p_mean_variance without an ``x_start``
        conditioning image, so with the SAR-DDPM concatenation change it
        would fail if invoked — it appears unused in this fork; confirm
        before relying on it.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            if i == 0:
                # NOTE(review): the final step intentionally(?) runs OUTSIDE
                # th.no_grad(), presumably so gradients can flow through the
                # last denoising step during fine-tuning — confirm before
                # changing.
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    device=device,
                    eta=eta,
                )
                yield out
                img = out["sample"]
            else:
                with th.no_grad():
                    out = self.ddim_sample(
                        model,
                        img,
                        t,
                        clip_denoised=clip_denoised,
                        denoised_fn=denoised_fn,
                        cond_fn=cond_fn,
                        model_kwargs=model_kwargs,
                        device=device,
                        eta=eta,
                    )
                    yield out
                    img = out["sample"]

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, x_start=x_start, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of clean inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: extra keyword arguments for the model; must
            contain the speckled conditioning image under key "SR".
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        # Speckled SAR image used as the conditioning input.
        x_disto_start = model_kwargs["SR"]

        x_t = self.q_sample(x_start, t, noise=noise)
        # Condition the network on the speckled image by concatenation,
        # mirroring p_mean_variance.
        model_inp = th.cat([x_t, x_disto_start], 1)
        model_output = model(model_inp, self._scale_timesteps(t), **model_kwargs)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape

            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }
1008
+
1009
+
1010
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1011
+ """
1012
+ Extract values from a 1-D numpy array for a batch of indices.
1013
+
1014
+ :param arr: the 1-D numpy array.
1015
+ :param timesteps: a tensor of indices into the array to extract.
1016
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1017
+ dimension equal to the length of timesteps.
1018
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1019
+ """
1020
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1021
+ while len(res.shape) < len(broadcast_shape):
1022
+ res = res[..., None]
1023
+ return res.expand(broadcast_shape)
guided_diffusion/image_datasets.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch as th
4
+ from PIL import Image
5
+ import blobfile as bf
6
+ from mpi4py import MPI
7
+ import numpy as np
8
+ from torch.utils.data import DataLoader, Dataset
9
+ import cv2
10
+ import imgaug.augmenters as iaa
11
+ from basicsr.data import degradations as degradations
12
+ import cv2
13
+ import math
14
+ import random
15
+ seed = np.random.RandomState(112311)
16
+ def load_data(
17
+ *,
18
+ data_dir,
19
+ gt_dir,
20
+ batch_size,
21
+ image_size,
22
+ class_cond=False,
23
+ deterministic=False,
24
+ random_crop=False,
25
+ random_flip=True,
26
+ ):
27
+ """
28
+ For a dataset, create a generator over (images, kwargs) pairs.
29
+
30
+ Each images is an NCHW float tensor, and the kwargs dict contains zero or
31
+ more keys, each of which map to a batched Tensor of their own.
32
+ The kwargs dict can be used for class labels, in which case the key is "y"
33
+ and the values are integer tensors of class labels.
34
+
35
+ :param data_dir: a dataset directory.
36
+ :param batch_size: the batch size of each returned pair.
37
+ :param image_size: the size to which images are resized.
38
+ :param class_cond: if True, include a "y" key in returned dicts for class
39
+ label. If classes are not available and this is true, an
40
+ exception will be raised.
41
+ :param deterministic: if True, yield results in a deterministic order.
42
+ :param random_crop: if True, randomly crop the images for augmentation.
43
+ :param random_flip: if True, randomly flip the images for augmentation.
44
+ """
45
+ if not data_dir:
46
+ raise ValueError("unspecified data directory")
47
+ all_files = _list_image_files_recursively(data_dir)
48
+ classes = None
49
+ if class_cond:
50
+ # Assume classes are the first part of the filename,
51
+ # before an underscore.
52
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
53
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
54
+ classes = [sorted_classes[x] for x in class_names]
55
+ dataset = ImageDataset(
56
+ image_size,
57
+ all_files,
58
+ gt_dir,
59
+ classes=classes,
60
+ shard=MPI.COMM_WORLD.Get_rank(),
61
+ num_shards=MPI.COMM_WORLD.Get_size(),
62
+ random_crop=random_crop,
63
+ random_flip=random_flip,
64
+ )
65
+ if deterministic:
66
+ loader = DataLoader(
67
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
68
+ )
69
+ else:
70
+ loader = DataLoader(
71
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
72
+ )
73
+ while True:
74
+ yield from loader
75
+
76
+
77
+ def _list_image_files_recursively(data_dir):
78
+ results = []
79
+ for entry in sorted(bf.listdir(data_dir)):
80
+ full_path = bf.join(data_dir, entry)
81
+ ext = entry.split(".")[-1]
82
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
83
+ results.append(full_path)
84
+ elif bf.isdir(full_path):
85
+ results.extend(_list_image_files_recursively(full_path))
86
+ return results
87
+
88
+
89
+ class RandomCrop(object):
90
+
91
+ def __init__(self, crop_size=[256,256]):
92
+ """Set the height and weight before and after cropping"""
93
+ self.crop_size_h = crop_size[0]
94
+ self.crop_size_w = crop_size[1]
95
+
96
+ def __call__(self, inputs, target):
97
+ input_size_h, input_size_w, _ = inputs.shape
98
+ try:
99
+ x_start = random.randint(0, input_size_w - self.crop_size_w)
100
+ y_start = random.randint(0, input_size_h - self.crop_size_h)
101
+ inputs = inputs[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w]
102
+ target = target[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w]
103
+ except:
104
+ inputs=cv2.resize(inputs,(256,256))
105
+ target=cv2.resize(target,(256,256))
106
+
107
+ return inputs,target
108
+
109
+ class ImageDataset(Dataset):
110
+ def __init__(
111
+ self,
112
+ resolution,
113
+ image_paths,
114
+ gt_paths,
115
+ classes=None,
116
+ shard=0,
117
+ num_shards=1,
118
+ random_crop=False,
119
+ random_flip=True,
120
+ ):
121
+ super().__init__()
122
+ self.resolution = resolution
123
+ self.local_images = image_paths[shard:][::num_shards]
124
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
125
+ self.random_crop = True #random_crop
126
+ self.random_flip = random_flip
127
+ self.gt_paths=gt_paths
128
+ # train_list=train_list[:10000]
129
+
130
+ self.deformation = iaa.ElasticTransformation(alpha=[0, 50.], sigma=[4., 5.])
131
+
132
+ def __len__(self):
133
+ return len(self.local_images)
134
+
135
+ def __getitem__(self, idx):
136
+ path = self.local_images[idx]
137
+
138
+
139
+ pil_image = cv2.imread(path) ## Clean image RGB
140
+
141
+ pil_image = cv2.cvtColor(pil_image, cv2.COLOR_BGR2GRAY)
142
+ pil_image = np.repeat(pil_image[:,:,np.newaxis],3, axis=2)
143
+
144
+
145
+
146
+ im1 = ((np.float32(pil_image)+1.0)/256.0)**2
147
+ gamma_noise = seed.gamma(size=im1.shape, shape=1.0, scale=1.0).astype(im1.dtype)
148
+ syn_sar = np.sqrt(im1 * gamma_noise)
149
+ pil_image1 = syn_sar * 256-1 ## Noisy image
150
+
151
+
152
+
153
+
154
+ arr1=np.array(pil_image)
155
+ arr2=np.array(pil_image1)
156
+
157
+
158
+
159
+ arr1 = cv2.resize(arr1, (256,256), interpolation=cv2.INTER_LINEAR)
160
+ arr2= cv2.resize(arr2, (256,256), interpolation=cv2.INTER_LINEAR)
161
+
162
+
163
+
164
+
165
+ arr1 = arr1.astype(np.float32) / 127.5 - 1
166
+ arr2 = arr2.astype(np.float32) / 127.5 - 1
167
+
168
+
169
+
170
+ out_dict = {}
171
+
172
+
173
+
174
+ arr2 = np.transpose(arr2, [2, 0, 1])
175
+ arr1 = np.transpose(arr1, [2, 0, 1])
176
+
177
+ out_dict["SR"]=arr2
178
+ out_dict["HR"]=arr1
179
+
180
+
181
+
182
+ return arr1, out_dict
183
+
184
+
185
+
186
+
187
+ def center_crop_arr(pil_image, pil_image1, image_size):
188
+ # We are not on a new enough PIL to support the `reducing_gap`
189
+ # argument, which uses BOX downsampling at powers of two first.
190
+ # Thus, we do it by hand to improve downsample quality.
191
+ while min(*pil_image.size) >= 2 * image_size:
192
+ pil_image = pil_image.resize(
193
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
194
+ )
195
+
196
+ scale = image_size / min(*pil_image.size)
197
+ pil_image = pil_image.resize(
198
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
199
+ )
200
+ while min(*pil_image1.size) >= 2 * image_size:
201
+ pil_image1 = pil_image1.resize(
202
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
203
+ )
204
+
205
+ scale = image_size / min(*pil_image1.size)
206
+ pil_image1 = pil_image1.resize(
207
+ tuple(round(x * scale) for x in pil_image1.size), resample=Image.BICUBIC
208
+ )
209
+
210
+ arr = np.array(pil_image)
211
+ arr1 = np.array(pil_image1)
212
+
213
+ crop_y = (arr.shape[0] - image_size) // 2
214
+ crop_x = (arr.shape[1] - image_size) // 2
215
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size],arr1[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
216
+
217
+
218
+ def random_crop_arr(pil_image, pil_image1, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
219
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
220
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
221
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
222
+
223
+ # We are not on a new enough PIL to support the `reducing_gap`
224
+ # argument, which uses BOX downsampling at powers of two first.
225
+ # Thus, we do it by hand to improve downsample quality.
226
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
227
+ pil_image = pil_image.resize(
228
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
229
+ )
230
+
231
+ scale = smaller_dim_size / min(*pil_image.size)
232
+ pil_image = pil_image.resize(
233
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
234
+ )
235
+ while min(*pil_image1.size) >= 2 * smaller_dim_size:
236
+ pil_image = pil_image.resize(
237
+ tuple(x // 2 for x in pil_image1.size), resample=Image.BOX
238
+ )
239
+
240
+ scale = smaller_dim_size / min(*pil_image1.size)
241
+ pil_image1 = pil_image1.resize(
242
+ tuple(round(x * scale) for x in pil_image1.size), resample=Image.BICUBIC
243
+ )
244
+ arr = np.array(pil_image)
245
+ arr1 = np.array(pil_image1)
246
+
247
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
248
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
249
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size],arr1[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
guided_diffusion/logger.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+
18
+ DEBUG = 10
19
+ INFO = 20
20
+ WARN = 30
21
+ ERROR = 40
22
+
23
+ DISABLED = 50
24
+
25
+
26
+ class KVWriter(object):
27
+ def writekvs(self, kvs):
28
+ raise NotImplementedError
29
+
30
+
31
+ class SeqWriter(object):
32
+ def writeseq(self, seq):
33
+ raise NotImplementedError
34
+
35
+
36
+ class HumanOutputFormat(KVWriter, SeqWriter):
37
+ def __init__(self, filename_or_file):
38
+ if isinstance(filename_or_file, str):
39
+ self.file = open(filename_or_file, "wt")
40
+ self.own_file = True
41
+ else:
42
+ assert hasattr(filename_or_file, "read"), (
43
+ "expected file or str, got %s" % filename_or_file
44
+ )
45
+ self.file = filename_or_file
46
+ self.own_file = False
47
+
48
+ def writekvs(self, kvs):
49
+ # Create strings for printing
50
+ key2str = {}
51
+ for (key, val) in sorted(kvs.items()):
52
+ if hasattr(val, "__float__"):
53
+ valstr = "%-8.3g" % val
54
+ else:
55
+ valstr = str(val)
56
+ key2str[self._truncate(key)] = self._truncate(valstr)
57
+
58
+ # Find max widths
59
+ if len(key2str) == 0:
60
+ print("WARNING: tried to write empty key-value dict")
61
+ return
62
+ else:
63
+ keywidth = max(map(len, key2str.keys()))
64
+ valwidth = max(map(len, key2str.values()))
65
+
66
+ # Write out the data
67
+ dashes = "-" * (keywidth + valwidth + 7)
68
+ lines = [dashes]
69
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70
+ lines.append(
71
+ "| %s%s | %s%s |"
72
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73
+ )
74
+ lines.append(dashes)
75
+ self.file.write("\n".join(lines) + "\n")
76
+
77
+ # Flush the output to the file
78
+ self.file.flush()
79
+
80
+ def _truncate(self, s):
81
+ maxlen = 30
82
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83
+
84
+ def writeseq(self, seq):
85
+ seq = list(seq)
86
+ for (i, elem) in enumerate(seq):
87
+ self.file.write(elem)
88
+ if i < len(seq) - 1: # add space unless this is the last one
89
+ self.file.write(" ")
90
+ self.file.write("\n")
91
+ self.file.flush()
92
+
93
+ def close(self):
94
+ if self.own_file:
95
+ self.file.close()
96
+
97
+
98
+ class JSONOutputFormat(KVWriter):
99
+ def __init__(self, filename):
100
+ self.file = open(filename, "wt")
101
+
102
+ def writekvs(self, kvs):
103
+ for k, v in sorted(kvs.items()):
104
+ if hasattr(v, "dtype"):
105
+ kvs[k] = float(v)
106
+ self.file.write(json.dumps(kvs) + "\n")
107
+ self.file.flush()
108
+
109
+ def close(self):
110
+ self.file.close()
111
+
112
+
113
+ class CSVOutputFormat(KVWriter):
114
+ def __init__(self, filename):
115
+ self.file = open(filename, "w+t")
116
+ self.keys = []
117
+ self.sep = ","
118
+
119
+ def writekvs(self, kvs):
120
+ # Add our current row to the history
121
+ extra_keys = list(kvs.keys() - self.keys)
122
+ extra_keys.sort()
123
+ if extra_keys:
124
+ self.keys.extend(extra_keys)
125
+ self.file.seek(0)
126
+ lines = self.file.readlines()
127
+ self.file.seek(0)
128
+ for (i, k) in enumerate(self.keys):
129
+ if i > 0:
130
+ self.file.write(",")
131
+ self.file.write(k)
132
+ self.file.write("\n")
133
+ for line in lines[1:]:
134
+ self.file.write(line[:-1])
135
+ self.file.write(self.sep * len(extra_keys))
136
+ self.file.write("\n")
137
+ for (i, k) in enumerate(self.keys):
138
+ if i > 0:
139
+ self.file.write(",")
140
+ v = kvs.get(k)
141
+ if v is not None:
142
+ self.file.write(str(v))
143
+ self.file.write("\n")
144
+ self.file.flush()
145
+
146
+ def close(self):
147
+ self.file.close()
148
+
149
+
150
+ class TensorBoardOutputFormat(KVWriter):
151
+ """
152
+ Dumps key/value pairs into TensorBoard's numeric format.
153
+ """
154
+
155
+ def __init__(self, dir):
156
+ os.makedirs(dir, exist_ok=True)
157
+ self.dir = dir
158
+ self.step = 1
159
+ prefix = "events"
160
+ path = osp.join(osp.abspath(dir), prefix)
161
+ import tensorflow as tf
162
+ from tensorflow.python import pywrap_tensorflow
163
+ from tensorflow.core.util import event_pb2
164
+ from tensorflow.python.util import compat
165
+
166
+ self.tf = tf
167
+ self.event_pb2 = event_pb2
168
+ self.pywrap_tensorflow = pywrap_tensorflow
169
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170
+
171
+ def writekvs(self, kvs):
172
+ def summary_val(k, v):
173
+ kwargs = {"tag": k, "simple_value": float(v)}
174
+ return self.tf.Summary.Value(**kwargs)
175
+
176
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178
+ event.step = (
179
+ self.step
180
+ ) # is there any reason why you'd want to specify the step?
181
+ self.writer.WriteEvent(event)
182
+ self.writer.Flush()
183
+ self.step += 1
184
+
185
+ def close(self):
186
+ if self.writer:
187
+ self.writer.Close()
188
+ self.writer = None
189
+
190
+
191
+ def make_output_format(format, ev_dir, log_suffix=""):
192
+ os.makedirs(ev_dir, exist_ok=True)
193
+ if format == "stdout":
194
+ return HumanOutputFormat(sys.stdout)
195
+ elif format == "log":
196
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197
+ elif format == "json":
198
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199
+ elif format == "csv":
200
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201
+ elif format == "tensorboard":
202
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203
+ else:
204
+ raise ValueError("Unknown format specified: %s" % (format,))
205
+
206
+
207
+ # ================================================================
208
+ # API
209
+ # ================================================================
210
+
211
+
212
+ def logkv(key, val):
213
+ """
214
+ Log a value of some diagnostic
215
+ Call this once for each diagnostic quantity, each iteration
216
+ If called many times, last value will be used.
217
+ """
218
+ get_current().logkv(key, val)
219
+
220
+
221
+ def logkv_mean(key, val):
222
+ """
223
+ The same as logkv(), but if called many times, values averaged.
224
+ """
225
+ get_current().logkv_mean(key, val)
226
+
227
+
228
+ def logkvs(d):
229
+ """
230
+ Log a dictionary of key-value pairs
231
+ """
232
+ for (k, v) in d.items():
233
+ logkv(k, v)
234
+
235
+
236
+ def dumpkvs():
237
+ """
238
+ Write all of the diagnostics from the current iteration
239
+ """
240
+ return get_current().dumpkvs()
241
+
242
+
243
+ def getkvs():
244
+ return get_current().name2val
245
+
246
+
247
+ def log(*args, level=INFO):
248
+ """
249
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250
+ """
251
+ get_current().log(*args, level=level)
252
+
253
+
254
+ def debug(*args):
255
+ log(*args, level=DEBUG)
256
+
257
+
258
+ def info(*args):
259
+ log(*args, level=INFO)
260
+
261
+
262
+ def warn(*args):
263
+ log(*args, level=WARN)
264
+
265
+
266
+ def error(*args):
267
+ log(*args, level=ERROR)
268
+
269
+
270
+ def set_level(level):
271
+ """
272
+ Set logging threshold on current logger.
273
+ """
274
+ get_current().set_level(level)
275
+
276
+
277
+ def set_comm(comm):
278
+ get_current().set_comm(comm)
279
+
280
+
281
+ def get_dir():
282
+ """
283
+ Get directory that log files are being written to.
284
+ will be None if there is no output directory (i.e., if you didn't call start)
285
+ """
286
+ return get_current().get_dir()
287
+
288
+
289
+ record_tabular = logkv
290
+ dump_tabular = dumpkvs
291
+
292
+
293
+ @contextmanager
294
+ def profile_kv(scopename):
295
+ logkey = "wait_" + scopename
296
+ tstart = time.time()
297
+ try:
298
+ yield
299
+ finally:
300
+ get_current().name2val[logkey] += time.time() - tstart
301
+
302
+
303
+ def profile(n):
304
+ """
305
+ Usage:
306
+ @profile("my_func")
307
+ def my_func(): code
308
+ """
309
+
310
+ def decorator_with_name(func):
311
+ def func_wrapper(*args, **kwargs):
312
+ with profile_kv(n):
313
+ return func(*args, **kwargs)
314
+
315
+ return func_wrapper
316
+
317
+ return decorator_with_name
318
+
319
+
320
+ # ================================================================
321
+ # Backend
322
+ # ================================================================
323
+
324
+
325
+ def get_current():
326
+ if Logger.CURRENT is None:
327
+ _configure_default_logger()
328
+
329
+ return Logger.CURRENT
330
+
331
+
332
+ class Logger(object):
333
+ DEFAULT = None # A logger with no output files. (See right below class definition)
334
+ # So that you can still log to the terminal without setting up any output files
335
+ CURRENT = None # Current logger being used by the free functions above
336
+
337
+ def __init__(self, dir, output_formats, comm=None):
338
+ self.name2val = defaultdict(float) # values this iteration
339
+ self.name2cnt = defaultdict(int)
340
+ self.level = INFO
341
+ self.dir = dir
342
+ self.output_formats = output_formats
343
+ self.comm = comm
344
+
345
+ # Logging API, forwarded
346
+ # ----------------------------------------
347
+ def logkv(self, key, val):
348
+ self.name2val[key] = val
349
+
350
+ def logkv_mean(self, key, val):
351
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
352
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353
+ self.name2cnt[key] = cnt + 1
354
+
355
+ def dumpkvs(self):
356
+ if self.comm is None:
357
+ d = self.name2val
358
+ else:
359
+ d = mpi_weighted_mean(
360
+ self.comm,
361
+ {
362
+ name: (val, self.name2cnt.get(name, 1))
363
+ for (name, val) in self.name2val.items()
364
+ },
365
+ )
366
+ if self.comm.rank != 0:
367
+ d["dummy"] = 1 # so we don't get a warning about empty dict
368
+ out = d.copy() # Return the dict for unit testing purposes
369
+ for fmt in self.output_formats:
370
+ if isinstance(fmt, KVWriter):
371
+ fmt.writekvs(d)
372
+ self.name2val.clear()
373
+ self.name2cnt.clear()
374
+ return out
375
+
376
+ def log(self, *args, level=INFO):
377
+ if self.level <= level:
378
+ self._do_log(args)
379
+
380
+ # Configuration
381
+ # ----------------------------------------
382
+ def set_level(self, level):
383
+ self.level = level
384
+
385
+ def set_comm(self, comm):
386
+ self.comm = comm
387
+
388
+ def get_dir(self):
389
+ return self.dir
390
+
391
+ def close(self):
392
+ for fmt in self.output_formats:
393
+ fmt.close()
394
+
395
+ # Misc
396
+ # ----------------------------------------
397
+ def _do_log(self, args):
398
+ for fmt in self.output_formats:
399
+ if isinstance(fmt, SeqWriter):
400
+ fmt.writeseq(map(str, args))
401
+
402
+
403
+ def get_rank_without_mpi_import():
404
+ # check environment variables here instead of importing mpi4py
405
+ # to avoid calling MPI_Init() when this module is imported
406
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407
+ if varname in os.environ:
408
+ return int(os.environ[varname])
409
+ return 0
410
+
411
+
412
+ def mpi_weighted_mean(comm, local_name2valcount):
413
+ """
414
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415
+ Perform a weighted average over dicts that are each on a different node
416
+ Input: local_name2valcount: dict mapping key -> (value, count)
417
+ Returns: key -> mean
418
+ """
419
+ all_name2valcount = comm.gather(local_name2valcount)
420
+ if comm.rank == 0:
421
+ name2sum = defaultdict(float)
422
+ name2count = defaultdict(float)
423
+ for n2vc in all_name2valcount:
424
+ for (name, (val, count)) in n2vc.items():
425
+ try:
426
+ val = float(val)
427
+ except ValueError:
428
+ if comm.rank == 0:
429
+ warnings.warn(
430
+ "WARNING: tried to compute mean on non-float {}={}".format(
431
+ name, val
432
+ )
433
+ )
434
+ else:
435
+ name2sum[name] += val * count
436
+ name2count[name] += count
437
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
438
+ else:
439
+ return {}
440
+
441
+
442
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443
+ """
444
+ If comm is provided, average all numerical stats across that comm
445
+ """
446
+ if dir is None:
447
+ dir = os.getenv("OPENAI_LOGDIR")
448
+ if dir is None:
449
+ dir = osp.join(
450
+ tempfile.gettempdir(),
451
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452
+ )
453
+ assert isinstance(dir, str)
454
+ dir = os.path.expanduser(dir)
455
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
456
+
457
+ rank = get_rank_without_mpi_import()
458
+ if rank > 0:
459
+ log_suffix = log_suffix + "-rank%03i" % rank
460
+
461
+ if format_strs is None:
462
+ if rank == 0:
463
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464
+ else:
465
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466
+ format_strs = filter(None, format_strs)
467
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468
+
469
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470
+ if output_formats:
471
+ log("Logging to %s" % dir)
472
+
473
+
474
+ def _configure_default_logger():
475
+ configure()
476
+ Logger.DEFAULT = Logger.CURRENT
477
+
478
+
479
+ def reset():
480
+ if Logger.CURRENT is not Logger.DEFAULT:
481
+ Logger.CURRENT.close()
482
+ Logger.CURRENT = Logger.DEFAULT
483
+ log("Reset logger")
484
+
485
+
486
+ @contextmanager
487
+ def scoped_configure(dir=None, format_strs=None, comm=None):
488
+ prevlogger = Logger.CURRENT
489
+ configure(dir=dir, format_strs=format_strs, comm=comm)
490
+ try:
491
+ yield
492
+ finally:
493
+ Logger.CURRENT.close()
494
+ Logger.CURRENT = prevlogger
495
+
guided_diffusion/losses.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
guided_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
guided_diffusion/resample.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler ("uniform" or "loss-second-moment").
    :param diffusion: the diffusion object to sample for.
    :raises NotImplementedError: if `name` is not a known sampler.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    if name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    raise NotImplementedError(f"unknown schedule sampler: {name}")
21
+
22
+
23
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged. Subclasses may override sample() to change
    how the resampled terms are reweighted, allowing for actual changes in
    the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        raw = self.weights()
        # Normalize to a probability distribution over timesteps.
        probs = raw / np.sum(raw)
        drawn = np.random.choice(len(probs), size=(batch_size,), p=probs)
        timesteps = th.from_numpy(drawn).long().to(device)
        # Importance weights: 1 / (N * p(t)) keeps the objective unbiased.
        inv_density = 1 / (len(probs) * probs[drawn])
        loss_weights = th.from_numpy(inv_density).float().to(device)
        return timesteps, loss_weights
59
+
60
+
61
class UniformSampler(ScheduleSampler):
    """Samples timesteps uniformly: every diffusion step gets weight 1."""

    def __init__(self, diffusion):
        self.diffusion = diffusion
        # A constant weight of 1.0 for each of the diffusion's timesteps.
        self._weights = np.full(diffusion.num_timesteps, 1.0)

    def weights(self):
        return self._weights
68
+
69
+
70
class LossAwareSampler(ScheduleSampler):
    # Base class for samplers whose weights depend on observed training
    # losses, kept in sync across distributed workers.

    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        # First gather every rank's batch size, since batches may differ
        # in length across ranks.
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        # Trim each gathered batch back to its rank's true size before use.
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
122
+
123
+
124
class LossSecondMomentResampler(LossAwareSampler):
    """
    Importance-samples timesteps proportionally to the RMS of their recent
    losses, keeping a rolling history of the last `history_per_term` losses
    observed for each timestep.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        """
        :param diffusion: the diffusion process whose timesteps are sampled.
        :param history_per_term: number of recent losses kept per timestep.
        :param uniform_prob: probability mass reserved for uniform sampling,
                             so every timestep keeps a nonzero weight.
        """
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # BUG FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # use the builtin `int` (the same dtype NumPy mapped the alias to).
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        # Until every timestep has a full loss history, sample uniformly.
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        # Weight each timestep by the root-mean-square of its recent losses.
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        # Mix in a small uniform component so no timestep is starved.
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        # True once every timestep has accumulated a full loss history.
        return (self._loss_counts == self.history_per_term).all()
guided_diffusion/respace.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch as th
3
+
4
+ from .gaussian_diffusion import GaussianDiffusion
5
+
6
+
7
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Search for an integer stride producing exactly `desired_count`
            # steps.
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # BUG FIX: the message previously interpolated `num_timesteps`,
            # reporting the wrong number; the requested count is what cannot
            # be achieved.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections each absorb one leftover timestep.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
61
+
62
+
63
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        # Maps each spaced-process index to its original-process timestep.
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        # Re-derive betas for the retained timesteps so the spaced process
        # has the same cumulative alpha products as the original at those
        # steps.
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        # Wrap the model so it sees original-process timesteps.
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        # Idempotent: an already-wrapped model is returned unchanged.
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
114
+
115
+
116
+ class _WrappedModel:
117
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118
+ self.model = model
119
+ self.timestep_map = timestep_map
120
+ self.rescale_timesteps = rescale_timesteps
121
+ self.original_num_steps = original_num_steps
122
+
123
+ def __call__(self, x, ts, **kwargs):
124
+
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ if self.rescale_timesteps:
128
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ # print('new_ts')
130
+ # print(new_ts.device)
131
+ return self.model(x, new_ts, **kwargs)
guided_diffusion/script_util.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import inspect
3
+
4
+ from . import gaussian_diffusion as gd
5
+ from .respace import SpacedDiffusion, space_timesteps
6
+ from .unet import SuperResModel, UNetModel, EncoderUNetModel
7
+
8
+ NUM_CLASSES = 1000
9
+
10
+
11
def diffusion_defaults():
    """
    Defaults for image and classifier training.
    """
    return {
        "learn_sigma": False,
        "diffusion_steps": 1000,
        "noise_schedule": "linear",
        "timestep_respacing": "ddim100",
        "use_kl": False,
        "predict_xstart": False,
        "rescale_timesteps": True,
        "rescale_learned_sigmas": False,
    }
25
+
26
+
27
def classifier_defaults():
    """
    Defaults for classifier models.
    """
    return {
        "image_size": 64,
        "classifier_use_fp16": False,
        "classifier_width": 128,
        "classifier_depth": 2,
        "classifier_attention_resolutions": "32,16,8",  # 16
        "classifier_use_scale_shift_norm": True,  # False
        "classifier_resblock_updown": True,  # False
        "classifier_pool": "attention",
    }
41
+
42
+
43
def model_and_diffusion_defaults():
    """
    Defaults for image training: UNet options merged with diffusion options.
    """
    defaults = {
        "image_size": 64,
        "num_channels": 128,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "num_head_channels": -1,
        "attention_resolutions": "16,8",
        "channel_mult": "",
        "dropout": 0.0,
        "class_cond": False,
        "use_checkpoint": True,
        "use_scale_shift_norm": True,
        "resblock_updown": False,
        "use_fp16": False,
        "use_new_attention_order": False,
    }
    defaults.update(diffusion_defaults())
    return defaults
66
+
67
+
68
def classifier_and_diffusion_defaults():
    """Classifier model options merged with diffusion options."""
    defaults = classifier_defaults()
    defaults.update(diffusion_defaults())
    return defaults
72
+
73
+
74
def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
):
    """
    Build a (UNet model, Gaussian diffusion) pair from flat training options.

    Thin wrapper around create_model / create_gaussian_diffusion; see those
    functions for the meaning of each option.

    :return: a tuple (model, diffusion).
    """
    unet = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult=channel_mult,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        use_new_attention_order=use_new_attention_order,
    )
    gaussian = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return unet, gaussian
128
+
129
+
130
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=True,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """
    Construct a UNetModel for image diffusion.

    :param channel_mult: comma-separated multipliers, or "" to pick a preset
                         based on `image_size`.
    :param attention_resolutions: comma-separated feature-map resolutions at
                                  which attention layers are inserted.
    :raises ValueError: if channel_mult is "" and image_size has no preset.
    """
    if channel_mult == "":
        # Presets matching the upstream guided-diffusion architectures.
        size_to_mult = {
            512: (0.5, 1, 1, 2, 2, 4, 4),
            256: (1, 1, 2, 2, 4, 4),
            128: (1, 1, 2, 3, 4),
            64: (1, 2, 3, 4),
        }
        if image_size not in size_to_mult:
            raise ValueError(f"unsupported image size: {image_size}")
        channel_mult = size_to_mult[image_size]
    else:
        channel_mult = tuple(int(m) for m in channel_mult.split(","))

    # Convert resolutions (e.g. "16,8") to downsample factors of image_size.
    attention_ds = [image_size // int(res) for res in attention_resolutions.split(",")]

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(6 if learn_sigma else 3),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )
185
+
186
+
187
def create_classifier_and_diffusion(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    learn_sigma,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
):
    """
    Build a (classifier, Gaussian diffusion) pair from flat options.

    Thin wrapper around create_classifier / create_gaussian_diffusion.

    :return: a tuple (classifier, diffusion).
    """
    encoder = create_classifier(
        image_size,
        classifier_use_fp16,
        classifier_width,
        classifier_depth,
        classifier_attention_resolutions,
        classifier_use_scale_shift_norm,
        classifier_resblock_updown,
        classifier_pool,
    )
    gaussian = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return encoder, gaussian
226
+
227
+
228
def create_classifier(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
):
    """
    Construct the EncoderUNetModel used as a (noisy-image) classifier.

    Channel multipliers are fixed per image size; attention resolutions are a
    comma-separated list of feature-map sizes.

    :raises ValueError: if `image_size` has no preset channel multiplier.
    """
    size_to_mult = {
        512: (0.5, 1, 1, 2, 2, 4, 4),
        256: (1, 1, 2, 2, 4, 4),
        128: (1, 1, 2, 3, 4),
        64: (1, 2, 3, 4),
    }
    if image_size not in size_to_mult:
        raise ValueError(f"unsupported image size: {image_size}")
    channel_mult = size_to_mult[image_size]

    # Convert feature-map resolutions to downsample factors of image_size.
    attention_ds = [
        image_size // int(res)
        for res in classifier_attention_resolutions.split(",")
    ]

    return EncoderUNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=1000,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )
267
+
268
+
269
def sr_model_and_diffusion_defaults():
    """
    Defaults for super-resolution training: the image-training defaults plus
    large/small sizes, filtered down to sr_create_model_and_diffusion's
    signature.
    """
    res = model_and_diffusion_defaults()
    res["large_size"] = 256
    res["small_size"] = 256
    accepted = set(inspect.getfullargspec(sr_create_model_and_diffusion)[0])
    # Drop any option the SR constructor does not accept.
    for key in list(res):
        if key not in accepted:
            del res[key]
    return res
278
+
279
+
280
def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
):
    """
    Build a (super-resolution model, Gaussian diffusion) pair from flat
    options. Thin wrapper around sr_create_model / create_gaussian_diffusion.

    :return: a tuple (model, diffusion).
    """
    sr_model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
    gaussian = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return sr_model, gaussian
332
+
333
+
334
def sr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
    resblock_updown,
    use_fp16,
):
    """
    Construct the SuperResModel (a UNet conditioned on a low-res input).

    `small_size` is accepted for interface compatibility but unused.

    :raises ValueError: if `large_size` has no preset channel multiplier.
    """
    _ = small_size  # hack to prevent unused variable

    size_to_mult = {
        512: (1, 1, 2, 2, 4, 4),
        256: (1, 1, 2, 2, 4, 4),
        64: (1, 2, 3, 4),
    }
    if large_size not in size_to_mult:
        raise ValueError(f"unsupported large size: {large_size}")
    channel_mult = size_to_mult[large_size]

    # Convert feature-map resolutions to downsample factors of large_size.
    attention_ds = [large_size // int(res) for res in attention_resolutions.split(",")]

    return SuperResModel(
        image_size=large_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(6 if learn_sigma else 3),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
384
+
385
+
386
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """
    Build a SpacedDiffusion process from flat diffusion options.

    :param steps: number of steps in the original (unspaced) process.
    :param timestep_respacing: respacing spec for space_timesteps; "" means
                               keep all `steps` timesteps.
    :return: a SpacedDiffusion instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps)

    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    mean_type = (
        gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
    )

    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    if not timestep_respacing:
        timestep_respacing = [steps]

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )
425
+
426
+
427
def add_dict_to_argparser(parser, default_dict):
    """
    Register each (key, default) pair in `default_dict` as a `--key` flag.

    `None` defaults parse as strings, bool defaults parse via str2bool, and
    every other default uses its own type as the argparse converter.
    """
    for key, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{key}", default=default, type=converter)
435
+
436
+
437
+ def args_to_dict(args, keys):
438
+ return {k: getattr(args, k) for k in keys}
439
+
440
+
441
+ def str2bool(v):
442
+ """
443
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
444
+ """
445
+ if isinstance(v, bool):
446
+ return v
447
+ if v.lower() in ("yes", "true", "t", "y", "1"):
448
+ return True
449
+ elif v.lower() in ("no", "false", "f", "n", "0"):
450
+ return False
451
+ else:
452
+ raise argparse.ArgumentTypeError("boolean value expected")
guided_diffusion/train_util.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import os
4
+
5
+ import blobfile as bf
6
+ import torch as th
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9
+ from torch.optim import AdamW
10
+ import cv2
11
+ from . import dist_util, logger
12
+ from .fp16_util import MixedPrecisionTrainer
13
+ from .nn import update_ema
14
+ from .resample import LossAwareSampler, UniformSampler
15
+ import numpy as np
16
+ import skimage
17
+ from skimage.metrics import peak_signal_noise_ratio as psnr
18
+ import math
19
+ # For ImageNet experiments, this was a good default value.
20
+ # We found that the lg_loss_scale quickly climbed to
21
+ # 20-21 within the first ~1K steps of training.
22
+ INITIAL_LOG_LOSS_SCALE = 20.0
23
+
24
+ import core.metrics as Metrics
25
+ # from core.wandb_logger import WandbLogger
26
+ import wandb
27
+
28
+ class TrainLoop:
29
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        val_dat,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        args,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        """
        Set up training state: mixed-precision trainer, AdamW optimizer,
        EMA parameter copies, optional checkpoint resume, and (when CUDA is
        available) DDP wrapping of the model.

        :param model: the model to train.
        :param diffusion: the diffusion process providing training losses.
        :param data: iterator yielding (batch, cond) training pairs.
        :param val_dat: iterable of validation (clean_batch, kwargs) pairs.
        :param batch_size: per-rank batch size.
        :param microbatch: microbatch size; values <= 0 mean use full batch.
        :param lr: AdamW learning rate.
        :param ema_rate: a float, or a comma-separated string of floats,
                         giving one or more EMA decay rates.
        :param log_interval: steps between log dumps.
        :param save_interval: steps between validation/checkpoint attempts.
        :param resume_checkpoint: optional checkpoint path to resume from.
        :param args: experiment arguments (kept for logging/config).
        :param use_fp16: whether to train with mixed precision.
        :param fp16_scale_growth: growth rate of the fp16 loss scale.
        :param schedule_sampler: timestep sampler; defaults to uniform.
        :param weight_decay: AdamW weight decay.
        :param lr_anneal_steps: total steps for LR annealing (0 disables).
        """
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.val_data=val_dat
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # Normalize ema_rate to a list of floats (accepts "0.999,0.9999").
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.args = args
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        # Effective batch size across all distributed ranks.
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        # Load (and broadcast) model weights before building the trainer so
        # master params reflect the resumed state.
        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            # Fresh run: EMA copies start equal to the master params.
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            print('cuda available')
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=True,
            )
        else:
            # CPU fallback: no gradient synchronization is possible.
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model
120
+
121
    def _load_and_sync_parameters(self):
        """
        Load model weights from a resume checkpoint (if one is found) on
        rank 0, then broadcast the parameters to all other ranks.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                dict_load = dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())
                # strict=False tolerates missing/extra keys, e.g. when
                # fine-tuning from the upstream guided-diffusion weights.
                self.model.load_state_dict(dict_load, strict=False)

        # Broadcast rank 0's parameters so every rank starts identical.
        dist_util.sync_params(self.model.parameters())
132
+
133
    def _load_ema_parameters(self, rate):
        """
        Load the EMA parameters for decay `rate` from a checkpoint, falling
        back to a fresh copy of the current master params when none exists.

        :param rate: the EMA decay rate identifying the checkpoint file.
        :return: a list of master-param-shaped EMA parameters.
        """
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        # Broadcast rank 0's EMA params so all ranks agree.
        dist_util.sync_params(ema_params)
        return ema_params
148
+
149
    def _load_optimizer_state(self):
        """
        Restore AdamW optimizer state from the `optNNNNNN.pt` file saved
        next to the main checkpoint, if it exists.
        """
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)
160
+
161
    def run_loop(self):
        """
        Main training loop: run optimization steps until `lr_anneal_steps`
        (if set) is reached. Every `save_interval` steps, sample from the
        model over the validation set, compute grayscale PSNR against the
        clean references, and save a checkpoint whenever PSNR improves.
        """
        val_idx=0
        best_psnr = 0
        # wandb.init(project = 'diffusion_small', config=self.args)

        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            # wandb_logger = WandbLogger()

            batch, cond = next(self.data)
            self.run_step(batch, cond)

            if (self.step+1) % self.save_interval == 0:

                number=0
                all_images=[]
                # NOTE(review): `number` is re-initialized below (redundant)
                # and `all_images` is never used in this loop.
                number=0
                print('validation')

                with th.no_grad():
                    val_idx=val_idx+1
                    psnr_val = 0
                    for batch_id1, data_var in enumerate(self.val_data):
                        clean_batch, model_kwargs1 = data_var
                        model_kwargs={}
                        # Keep the image index/name out of the model kwargs;
                        # move every other conditioning tensor to the device.
                        for k, v in model_kwargs1.items():
                            if('Index' in k):
                                img_name=v
                            else:
                                model_kwargs[k]= v.to(dist_util.dev())

                        # Sample a 256x256 RGB despeckled image per input.
                        sample = self.diffusion.p_sample_loop(
                            self.model,
                            (clean_batch.shape[0], 3, 256,256),
                            clip_denoised=True,
                            model_kwargs=model_kwargs,
                        )

                        # Map samples from [-1, 1] to uint8 HWC layout.
                        sample = ((sample + 1) * 127.5)
                        sample = sample.clamp(0, 255).to(th.uint8)
                        sample = sample.permute(0, 2, 3, 1)
                        sample = sample.contiguous().cpu().numpy()

                        number=number+1

                        # 'HR' is the clean reference image in [-1, 1].
                        clean_image = ((model_kwargs['HR']+1)* 127.5).clamp(0, 255).to(th.uint8)
                        clean_image= clean_image.permute(0, 2, 3, 1)
                        clean_image= clean_image.contiguous().cpu().numpy()

                        # RGB -> BGR for OpenCV, then grayscale, then PSNR.
                        # NOTE(review): only the first image of each batch is
                        # scored — presumably validation uses batch size 1.
                        clean_image = clean_image[0][:,:,::-1]
                        sample = sample[0][:,:,::-1]
                        clean_image = cv2.cvtColor(clean_image, cv2.COLOR_BGR2GRAY)
                        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY)

                        psnr_im = psnr(clean_image,sample)
                        # print(img_name[0])
                        # print(psnr_im)

                        psnr_val = psnr_val + psnr_im

                    # Average PSNR over the validation set.
                    psnr_val = psnr_val/number

                    print('psnr =')
                    print(psnr_val)
                    # wandb.log({"psnr": psnr_val})

                    # Keep only the checkpoint with the best validation PSNR.
                    if best_psnr < psnr_val:
                        best_psnr = psnr_val
                        self.save_val()

            # Run for a finite amount of time in integration tests.
            # if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
            #     return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        # if (self.step - 1) % self.save_interval != 0:
        #     self.save()
258
+
259
+ def run_step(self, batch, cond):
260
+ self.forward_backward(batch, cond)
261
+ took_step = self.mp_trainer.optimize(self.opt)
262
+ if took_step:
263
+ self._update_ema()
264
+ self._anneal_lr()
265
+ self.log_step()
266
+
267
    def forward_backward(self, batch, cond):
        """Accumulate gradients over the batch in microbatches.

        :param batch: an [N x C x ...] tensor of training images.
        :param cond: a dict of conditioning tensors, each with leading dim N.
        """
        self.mp_trainer.zero_grad()
        num_im = 0        # number of microbatches processed
        loss_wandb = 0    # running sum of log10(loss), kept for (disabled) wandb logging
        for i in range(0, batch.shape[0], self.microbatch):
            num_im = num_im + 1

            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            # Under DDP, skip gradient sync for every microbatch except the
            # last, so gradients are all-reduced only once per batch.
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            # Loss-aware samplers re-weight future timestep sampling based
            # on the losses just observed.
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            loss_wandb = th.log10(loss) + loss_wandb

            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)
        # Mean log-loss over microbatches; currently unused because the
        # wandb call below is commented out.
        loss_wandb_f = loss_wandb/num_im
        # wandb.log({"loss": loss_wandb_f})
313
+ def _update_ema(self):
314
+ for rate, params in zip(self.ema_rate, self.ema_params):
315
+ update_ema(params, self.mp_trainer.master_params, rate=rate)
316
+
317
+ def _anneal_lr(self):
318
+ if not self.lr_anneal_steps:
319
+ return
320
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
321
+ lr = self.lr * (1 - frac_done)
322
+ for param_group in self.opt.param_groups:
323
+ param_group["lr"] = lr
324
+
325
+ def log_step(self):
326
+ logger.logkv("step", self.step + self.resume_step)
327
+ logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
328
+
329
+ def save(self):
330
+ def save_checkpoint(rate, params):
331
+ state_dict = self.mp_trainer.master_params_to_state_dict(params)
332
+ if dist.get_rank() == 0:
333
+ logger.log(f"saving model {rate}...")
334
+ if not rate:
335
+ filename = f"model{(self.step+self.resume_step):06d}.pt"
336
+ else:
337
+ filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
338
+ with bf.BlobFile(bf.join("./weights", filename), "wb") as f:
339
+ th.save(state_dict, f)
340
+
341
+ save_checkpoint(0, self.mp_trainer.master_params)
342
+ for rate, params in zip(self.ema_rate, self.ema_params):
343
+ save_checkpoint(rate, params)
344
+
345
+ if dist.get_rank() == 0:
346
+ with bf.BlobFile(
347
+ bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
348
+ "wb",
349
+ ) as f:
350
+ th.save(self.opt.state_dict(), f)
351
+
352
+ dist.barrier()
353
+
354
+ def save_val(self):
355
+ def save_checkpoint_val(rate, params):
356
+ state_dict = self.mp_trainer.master_params_to_state_dict(params)
357
+ if dist.get_rank() == 0:
358
+ logger.log(f"saving model {rate}...")
359
+ if not rate:
360
+ filename = f"model{(self.step+self.resume_step):06d}.pt"
361
+ else:
362
+ filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
363
+ with bf.BlobFile(bf.join("./weights", filename), "wb") as f:
364
+ th.save(state_dict, f)
365
+
366
+ save_checkpoint_val(0, self.mp_trainer.master_params)
367
+ for rate, params in zip(self.ema_rate, self.ema_params):
368
+ save_checkpoint_val(rate, params)
369
+
370
+ if dist.get_rank() == 0:
371
+ with bf.BlobFile(
372
+ bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
373
+ "wb",
374
+ ) as f:
375
+ th.save(self.opt.state_dict(), f)
376
+
377
+ dist.barrier()
378
+
379
+
380
def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps. Returns 0 when no step count can be read.
    """
    pieces = filename.split("model")
    if len(pieces) < 2:
        return 0
    step_text = pieces[-1].split(".")[0]
    try:
        return int(step_text)
    except ValueError:
        return 0
393
+
394
+
395
def get_blob_logdir():
    """Directory where optimizer checkpoints are written.

    Defaults to the logger's output dir; change this to point at a
    blobstore or an external drive.
    """
    return logger.get_dir()
399
+
400
+
401
def find_resume_checkpoint():
    """Hook for auto-discovering the latest checkpoint.

    Always returns None here; override on infrastructure that can list
    checkpoints on blob storage, etc. Callers fall back to the explicit
    ``resume_checkpoint`` argument when this returns None.
    """
    return None
405
+
406
+
407
def find_ema_checkpoint(main_checkpoint, step, rate):
    """Locate the EMA checkpoint for (step, rate) next to ``main_checkpoint``.

    Returns the path if the ``ema_{rate}_{step}.pt`` file exists in the same
    directory, otherwise None. A None ``main_checkpoint`` yields None.
    """
    if main_checkpoint is None:
        return None
    candidate = bf.join(bf.dirname(main_checkpoint), f"ema_{rate}_{(step):06d}.pt")
    return candidate if bf.exists(candidate) else None
415
+
416
+
417
def log_loss_dict(diffusion, ts, losses):
    """Log the mean of every loss term, plus per-quartile means.

    Each sample's loss is also bucketed by how far its timestep sits in the
    diffusion schedule (four quartiles) and logged as ``{key}_q{quartile}``.
    """
    timesteps = ts.cpu().numpy()  # hoisted: identical for every loss key
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        for sub_t, sub_loss in zip(timesteps, values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
guided_diffusion/unet.py ADDED
@@ -0,0 +1,1908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
11
+ from .nn import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+
21
+
22
+ # from models.submodules import *
23
+ import torchvision.models
24
+
25
class VGG19(nn.Module):
    """Truncated pretrained VGG-19 feature extractor.

    Keeps only the first layers of torchvision's pretrained VGG-19 (up to
    and including the index in ``feature_list``) and returns the
    activations at those indices.

    NOTE(review): the original inline note mentioned conv1_2 / conv2_2 /
    conv3_3 features, but only layer index 7 is actually collected here.
    """

    def __init__(self):
        super(VGG19, self).__init__()
        # Indices (into vgg19.features) whose activations are returned.
        self.feature_list = [7]
        vgg19 = torchvision.models.vgg19(pretrained=True)

        # Truncate the feature stack right after the last requested layer.
        self.model = th.nn.Sequential(*list(vgg19.features.children())[:self.feature_list[-1]+1])

    def forward(self, x , emb):
        # `emb` is accepted for interface compatibility with timestep-aware
        # modules but is not used here.
        features = []
        for i, layer in enumerate(list(self.model)):
            x = layer(x)
            if i in self.feature_list:
                features.append(x)
        return features
47
+
48
+
49
+
50
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools a 2-D feature map to a single vector with attention: the mean over
    all spatial positions is prepended as an extra token, and that token's
    attended output is returned.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One positional embedding per spatial location, plus one for the
        # prepended mean token.
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # Return only the pooled (mean-token) position.
        return x[:, :, 0]
80
+
81
+
82
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Serves as the marker base class that TimestepEmbedSequential checks with
    isinstance() to decide whether to pass `emb` through.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
92
+
93
+
94
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that forwards the timestep embedding to every
    child that subclasses TimestepBlock; other children receive `x` only.
    """

    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, TimestepBlock) else layer(x)
        return x
107
+
108
+
109
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv (defaults to
                         ``channels``).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D inputs: keep depth fixed, double only the spatial dims.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
139
+
140
+
141
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the conv path (must equal
        ``channels`` on the pooling path).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals are only downsampled in the inner two dimensions.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
169
+
170
+
171
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> SiLU -> conv; split apart in _forward when up/downsampling.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            # Both the hidden path and the skip path are resampled.
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; doubled width when the embedding
        # supplies a (scale, shift) pair for FiLM-style conditioning.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Final conv is zero-initialized so the block starts as identity.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # Apply norm+SiLU first, resample, then conv — so the
            # resampling happens between activation and convolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style: modulate the normalized features with (scale, shift).
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
285
+
286
+
287
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection: the block starts as an identity map.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): checkpointing is hard-wired to True here, ignoring
        # self.use_checkpoint — this matches upstream guided-diffusion.
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        # Flatten all spatial dims into a single token axis for 1-D attention.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        # Residual connection, restored to the original spatial shape.
        return (x + h).reshape(b, c, *spatial)
334
+
335
+
336
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    batch, channels, *spatial = y[0].shape
    tokens = int(np.prod(spatial))
    # Attention does two equally sized matmuls: QK^T to form the weights,
    # then weights @ V to combine the values.
    model.total_ops += th.DoubleTensor([2 * batch * tokens * tokens * channels])
354
+
355
+
356
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # Heads are split first, then each head's strip is cut into q, k, v.
        q, k, v = qkv.reshape(batch * self.n_heads, head_dim * 3, tokens).split(head_dim, dim=1)
        # Scale q and k symmetrically by ch^(-1/4) each: more stable in
        # fp16 than dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
387
+
388
+
389
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order:
    q/k/v are separated first, then reshaped into heads (the "new" order).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Symmetric ch^(-1/4) scaling of q and k, applied before the matmul.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        # Softmax in float32 for numerical stability, cast back afterwards.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
422
+
423
+
424
+ class UNetModel(nn.Module):
425
+ """
426
+ The full UNet model with attention and timestep embedding.
427
+ :param in_channels: channels in the input Tensor.
428
+ :param model_channels: base channel count for the model.
429
+ :param out_channels: channels in the output Tensor.
430
+ :param num_res_blocks: number of residual blocks per downsample.
431
+ :param attention_resolutions: a collection of downsample rates at which
432
+ attention will take place. May be a set, list, or tuple.
433
+ For example, if this contains 4, then at 4x downsampling, attention
434
+ will be used.
435
+ :param dropout: the dropout probability.
436
+ :param channel_mult: channel multiplier for each level of the UNet.
437
+ :param conv_resample: if True, use learned convolutions for upsampling and
438
+ downsampling.
439
+ :param dims: determines if the signal is 1D, 2D, or 3D.
440
+ :param num_classes: if specified (as an int), then this model will be
441
+ class-conditional with `num_classes` classes.
442
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
443
+ :param num_heads: the number of attention heads in each attention layer.
444
+ :param num_heads_channels: if specified, ignore num_heads and instead use
445
+ a fixed channel width per attention head.
446
+ :param num_heads_upsample: works with num_heads to set a different number
447
+ of heads for upsampling. Deprecated.
448
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
449
+ :param resblock_updown: use residual blocks for up/downsampling.
450
+ :param use_new_attention_order: use a different attention pattern for potentially
451
+ increased efficiency.
452
+ """
453
+
454
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        # NOTE(review): the passed-in `in_channels` is deliberately overridden.
        # SAR-DDPM feeds the noisy image concatenated with the (3-channel)
        # condition image, so the network always sees 6 input channels.
        in_channels=6
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        # Timestep embedding MLP: model_channels -> 4*model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        ch = input_ch = int(channel_mult[0] * model_channels)
        # self.input_transform_1 = conv_nd(2, 6, 3, 3, padding=1)
        # (kept commented out upstream; see convert_to_fp16, which still
        # references it)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1  # current downsampling rate
        # `blah` counts downsampling stages; leftover from an experiment that
        # widened `ch1` with VGG feature channels (now disabled, ch1 == ch).
        blah=0
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                # Attach attention only at the requested downsample rates.
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                blah=blah+1
                ch1=ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch1,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch
        # Bottleneck: ResBlock -> Attention -> ResBlock at the lowest resolution.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Decoder: mirrors the encoder, consuming the stored skip channels.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()  # matching skip-connection width
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                # Upsample at the end of each level except the outermost.
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Auxiliary VGG-19 feature extractor; its use in forward() is
        # currently commented out upstream but the module is still built.
        self.vgg=VGG19()
        # Fuses VGG features (128 ch) with a 192-ch feature map
        # (320 = 192 + 128); only conv_convert2 of the original three
        # converters survives. NOTE(review): hard-coded widths — presumably
        # assumes model_channels == 192; confirm before changing configs.
        self.conv_convert2 = ResBlock(
            320,
            time_embed_dim,
            dropout,
            out_channels=192,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )
691
+ def convert_to_fp16(self):
692
+ """
693
+ Convert the torso of the model to float16.
694
+ """
695
+ self.vgg.apply(convert_module_to_f16)
696
+ self.input_blocks.apply(convert_module_to_f16)
697
+ self.middle_block.apply(convert_module_to_f16)
698
+ self.output_blocks.apply(convert_module_to_f16)
699
+ # self.conv_convert1.apply(convert_module_to_f16)
700
+ self.conv_convert2.apply(convert_module_to_f16)
701
+ # self.conv_convert3.apply(convert_module_to_f16)
702
+ self.input_transform_1.apply(convert_module_to_f16)
703
+
704
+
705
+ def convert_to_fp32(self):
706
+ """
707
+ Convert the torso of the model to float32.
708
+ """
709
+ self.vgg.apply(convert_module_to_f32)
710
+
711
+ self.input_blocks.apply(convert_module_to_f32)
712
+ self.middle_block.apply(convert_module_to_f32)
713
+ self.output_blocks.apply(convert_module_to_f32)
714
+
715
+ def forward(self, x, timesteps, low_res ,high_res, y=None,**kwargs):
716
+ """
717
+ Apply the model to an input batch.
718
+
719
+ :param x: an [N x C x ...] Tensor of inputs.
720
+ :param timesteps: a 1-D batch of timesteps.
721
+ :param y: an [N] Tensor of labels, if class-conditional.
722
+ :return: an [N x C x ...] Tensor of outputs.
723
+ """
724
+
725
+ hs = []
726
+ # x1 = th.cat([x,high_res],1).type(self.dtype)
727
+ # x1 = self.input_transform_1(x)
728
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
729
+ # input1=low_res
730
+ high_res=x[:,3:]
731
+ # vgg_feats = self.vgg(high_res.type(self.dtype), emb)
732
+ # vgg_feats1 = self.vgg(high_res.type(self.dtype), emb)
733
+
734
+ # print(x.shape)
735
+ # print(emb.shape)
736
+ # vgg_feats=vgg_feats.type(self.dtype)
737
+ # print(vgg_feats[0].shape)
738
+ # print(emb.shape)
739
+ h = x.type(self.dtype)
740
+
741
+ for i , module in enumerate(self.input_blocks):
742
+ # print(i,module,h.shape)
743
+
744
+ # if(i==3):
745
+ # # print()
746
+ # h= th.cat([h,vgg_feats[0]],1)
747
+ # h = self.conv_convert1(h,emb)
748
+ # if(i==6):
749
+ # h= th.cat([h,vgg_feats[0]],1)
750
+ # h = self.conv_convert2(h,emb)
751
+
752
+ # if(i==9):
753
+ # h = th.cat([h,vgg_feats[2]],1)
754
+ # h = self.conv_convert3(h,emb)
755
+ # print(h.shape)
756
+ # print(h.shape,emb.shape)
757
+ h = module(h, emb)
758
+
759
+ hs.append(h)
760
+ # print(h.shape)
761
+ h = self.middle_block(h, emb)
762
+ # stop
763
+ for module in self.output_blocks:
764
+ h = th.cat([h, hs.pop()], dim=1)
765
+ h = module(h, emb)
766
+ h = h.type(x.dtype)
767
+ out=self.out(h)
768
+ return out
769
+
770
+
771
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects the low-resolution conditioning image either as the ``SR``
    kwarg (as supplied by the training loop) or as the ``low_res``
    argument.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # Channels are doubled because the conditioning image is
        # concatenated with the noisy input along the channel axis.
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Run the underlying UNetModel, forwarding the conditioning image.

        :param x: the noisy input batch (conditioning channels included).
        :param timesteps: a 1-D batch of timesteps.
        :param low_res: fallback conditioning image, used only when the
            ``SR`` kwarg is absent.
        :raises: nothing extra — the old code did ``kwargs['SR']`` and
            raised KeyError whenever ``SR`` was missing; we now fall back
            to the explicit ``low_res`` argument instead.
        """
        low_res = kwargs.get("SR", low_res)
        # The parent forward takes (x, timesteps, low_res, high_res, ...);
        # the same tensor is passed for both, matching the original call.
        return super().forward(x, timesteps, low_res, low_res, **kwargs)
791
+
792
+
793
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    Encoder-only variant: it keeps the downsampling path and the middle
    block of the UNet, then pools the features into a single [N x K]
    output vector.

    For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()

        # num_heads_upsample defaults to num_heads when left at -1.
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # Compute dtype of the torso; fp16 halves activation memory.
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # Timestep-embedding MLP: model_channels -> 4 * model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Stem: one conv lifting the input to the base channel count.
        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1  # current downsampling factor relative to the input
        # Downsampling path: num_res_blocks ResBlocks per resolution level,
        # attention wherever ds is listed in attention_resolutions, and a
        # downsampling block between consecutive levels.
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Bottleneck: ResBlock -> attention -> ResBlock at the lowest
        # resolution.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        # Output head, selected by the pooling strategy.
        self.pool = pool
        if pool == "adaptive":
            # Global average pool to 1x1 then a 1x1 conv to out_channels.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            # CLIP-style attention pooling; requires a fixed head width.
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            # MLP over concatenated per-stage spatial means (see forward).
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                # Collect the spatial mean of every stage's activation.
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)

        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            # Concatenate per-stage means into one feature vector.
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
1006
+
1007
+
1008
+
1009
+
1010
+ # from abc import abstractmethod
1011
+
1012
+ # import math
1013
+
1014
+ # import numpy as np
1015
+ # import torch as th
1016
+ # import torch.nn as nn
1017
+ # import torch.nn.functional as F
1018
+
1019
+ # from .fp16_util import convert_module_to_f16, convert_module_to_f32
1020
+ # from .nn import (
1021
+ # checkpoint,
1022
+ # conv_nd,
1023
+ # linear,
1024
+ # avg_pool_nd,
1025
+ # zero_module,
1026
+ # normalization,
1027
+ # timestep_embedding,
1028
+ # )
1029
+
1030
+
1031
+ # class AttentionPool2d(nn.Module):
1032
+ # """
1033
+ # Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
1034
+ # """
1035
+
1036
+ # def __init__(
1037
+ # self,
1038
+ # spacial_dim: int,
1039
+ # embed_dim: int,
1040
+ # num_heads_channels: int,
1041
+ # output_dim: int = None,
1042
+ # ):
1043
+ # super().__init__()
1044
+ # self.positional_embedding = nn.Parameter(
1045
+ # th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
1046
+ # )
1047
+ # self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
1048
+ # self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
1049
+ # self.num_heads = embed_dim // num_heads_channels
1050
+ # self.attention = QKVAttention(self.num_heads)
1051
+
1052
+ # def forward(self, x):
1053
+ # b, c, *_spatial = x.shape
1054
+ # x = x.reshape(b, c, -1) # NC(HW)
1055
+ # x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
1056
+ # x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
1057
+ # x = self.qkv_proj(x)
1058
+ # x = self.attention(x)
1059
+ # x = self.c_proj(x)
1060
+ # return x[:, :, 0]
1061
+
1062
+
1063
+ # class TimestepBlock(nn.Module):
1064
+ # """
1065
+ # Any module where forward() takes timestep embeddings as a second argument.
1066
+ # """
1067
+
1068
+ # @abstractmethod
1069
+ # def forward(self, x, emb):
1070
+ # """
1071
+ # Apply the module to `x` given `emb` timestep embeddings.
1072
+ # """
1073
+
1074
+
1075
+ # class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
1076
+ # """
1077
+ # A sequential module that passes timestep embeddings to the children that
1078
+ # support it as an extra input.
1079
+ # """
1080
+
1081
+ # def forward(self, x, emb):
1082
+ # for layer in self:
1083
+ # if isinstance(layer, TimestepBlock):
1084
+ # x = layer(x, emb)
1085
+ # else:
1086
+ # x = layer(x)
1087
+ # return x
1088
+
1089
+
1090
+ # class Upsample(nn.Module):
1091
+ # """
1092
+ # An upsampling layer with an optional convolution.
1093
+
1094
+ # :param channels: channels in the inputs and outputs.
1095
+ # :param use_conv: a bool determining if a convolution is applied.
1096
+ # :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
1097
+ # upsampling occurs in the inner-two dimensions.
1098
+ # """
1099
+
1100
+ # def __init__(self, channels, use_conv, dims=2, out_channels=None):
1101
+ # super().__init__()
1102
+ # self.channels = channels
1103
+ # self.out_channels = out_channels or channels
1104
+ # self.use_conv = use_conv
1105
+ # self.dims = dims
1106
+ # if use_conv:
1107
+ # self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
1108
+
1109
+ # def forward(self, x):
1110
+ # assert x.shape[1] == self.channels
1111
+ # if self.dims == 3:
1112
+ # x = F.interpolate(
1113
+ # x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
1114
+ # )
1115
+ # else:
1116
+ # x = F.interpolate(x, scale_factor=2, mode="nearest")
1117
+ # if self.use_conv:
1118
+ # x = self.conv(x)
1119
+ # return x
1120
+
1121
+
1122
+ # class Downsample(nn.Module):
1123
+ # """
1124
+ # A downsampling layer with an optional convolution.
1125
+
1126
+ # :param channels: channels in the inputs and outputs.
1127
+ # :param use_conv: a bool determining if a convolution is applied.
1128
+ # :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
1129
+ # downsampling occurs in the inner-two dimensions.
1130
+ # """
1131
+
1132
+ # def __init__(self, channels, use_conv, dims=2, out_channels=None):
1133
+ # super().__init__()
1134
+ # self.channels = channels
1135
+ # self.out_channels = out_channels or channels
1136
+ # self.use_conv = use_conv
1137
+ # self.dims = dims
1138
+ # stride = 2 if dims != 3 else (1, 2, 2)
1139
+ # if use_conv:
1140
+ # self.op = conv_nd(
1141
+ # dims, self.channels, self.out_channels, 3, stride=stride, padding=1
1142
+ # )
1143
+ # else:
1144
+ # assert self.channels == self.out_channels
1145
+ # self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
1146
+
1147
+ # def forward(self, x):
1148
+ # assert x.shape[1] == self.channels
1149
+ # return self.op(x)
1150
+
1151
+
1152
+ # class ResBlock(TimestepBlock):
1153
+ # """
1154
+ # A residual block that can optionally change the number of channels.
1155
+
1156
+ # :param channels: the number of input channels.
1157
+ # :param emb_channels: the number of timestep embedding channels.
1158
+ # :param dropout: the rate of dropout.
1159
+ # :param out_channels: if specified, the number of out channels.
1160
+ # :param use_conv: if True and out_channels is specified, use a spatial
1161
+ # convolution instead of a smaller 1x1 convolution to change the
1162
+ # channels in the skip connection.
1163
+ # :param dims: determines if the signal is 1D, 2D, or 3D.
1164
+ # :param use_checkpoint: if True, use gradient checkpointing on this module.
1165
+ # :param up: if True, use this block for upsampling.
1166
+ # :param down: if True, use this block for downsampling.
1167
+ # """
1168
+
1169
+ # def __init__(
1170
+ # self,
1171
+ # channels,
1172
+ # emb_channels,
1173
+ # dropout,
1174
+ # out_channels=None,
1175
+ # use_conv=False,
1176
+ # use_scale_shift_norm=False,
1177
+ # dims=2,
1178
+ # use_checkpoint=False,
1179
+ # up=False,
1180
+ # down=False,
1181
+ # ):
1182
+ # super().__init__()
1183
+ # self.channels = channels
1184
+ # self.emb_channels = emb_channels
1185
+ # self.dropout = dropout
1186
+ # self.out_channels = out_channels or channels
1187
+ # self.use_conv = use_conv
1188
+ # self.use_checkpoint = use_checkpoint
1189
+ # self.use_scale_shift_norm = use_scale_shift_norm
1190
+
1191
+ # self.in_layers = nn.Sequential(
1192
+ # normalization(channels),
1193
+ # nn.SiLU(),
1194
+ # conv_nd(dims, channels, self.out_channels, 3, padding=1),
1195
+ # )
1196
+
1197
+ # self.updown = up or down
1198
+
1199
+ # if up:
1200
+ # self.h_upd = Upsample(channels, False, dims)
1201
+ # self.x_upd = Upsample(channels, False, dims)
1202
+ # elif down:
1203
+ # self.h_upd = Downsample(channels, False, dims)
1204
+ # self.x_upd = Downsample(channels, False, dims)
1205
+ # else:
1206
+ # self.h_upd = self.x_upd = nn.Identity()
1207
+
1208
+ # self.emb_layers = nn.Sequential(
1209
+ # nn.SiLU(),
1210
+ # linear(
1211
+ # emb_channels,
1212
+ # 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
1213
+ # ),
1214
+ # )
1215
+ # self.out_layers = nn.Sequential(
1216
+ # normalization(self.out_channels),
1217
+ # nn.SiLU(),
1218
+ # nn.Dropout(p=dropout),
1219
+ # zero_module(
1220
+ # conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
1221
+ # ),
1222
+ # )
1223
+
1224
+ # if self.out_channels == channels:
1225
+ # self.skip_connection = nn.Identity()
1226
+ # elif use_conv:
1227
+ # self.skip_connection = conv_nd(
1228
+ # dims, channels, self.out_channels, 3, padding=1
1229
+ # )
1230
+ # else:
1231
+ # self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
1232
+
1233
+ # def forward(self, x, emb):
1234
+ # """
1235
+ # Apply the block to a Tensor, conditioned on a timestep embedding.
1236
+
1237
+ # :param x: an [N x C x ...] Tensor of features.
1238
+ # :param emb: an [N x emb_channels] Tensor of timestep embeddings.
1239
+ # :return: an [N x C x ...] Tensor of outputs.
1240
+ # """
1241
+ # return checkpoint(
1242
+ # self._forward, (x, emb), self.parameters(), self.use_checkpoint
1243
+ # )
1244
+
1245
+ # def _forward(self, x, emb):
1246
+ # if self.updown:
1247
+ # in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
1248
+ # h = in_rest(x)
1249
+ # h = self.h_upd(h)
1250
+ # x = self.x_upd(x)
1251
+ # h = in_conv(h)
1252
+ # else:
1253
+ # h = self.in_layers(x)
1254
+ # emb_out = self.emb_layers(emb).type(h.dtype)
1255
+ # while len(emb_out.shape) < len(h.shape):
1256
+ # emb_out = emb_out[..., None]
1257
+ # if self.use_scale_shift_norm:
1258
+ # out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
1259
+ # scale, shift = th.chunk(emb_out, 2, dim=1)
1260
+ # h = out_norm(h) * (1 + scale) + shift
1261
+ # h = out_rest(h)
1262
+ # else:
1263
+ # h = h + emb_out
1264
+ # h = self.out_layers(h)
1265
+ # return self.skip_connection(x) + h
1266
+
1267
+
1268
+ # class AttentionBlock(nn.Module):
1269
+ # """
1270
+ # An attention block that allows spatial positions to attend to each other.
1271
+
1272
+ # Originally ported from here, but adapted to the N-d case.
1273
+ # https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
1274
+ # """
1275
+
1276
+ # def __init__(
1277
+ # self,
1278
+ # channels,
1279
+ # num_heads=1,
1280
+ # num_head_channels=-1,
1281
+ # use_checkpoint=False,
1282
+ # use_new_attention_order=False,
1283
+ # ):
1284
+ # super().__init__()
1285
+ # self.channels = channels
1286
+ # if num_head_channels == -1:
1287
+ # self.num_heads = num_heads
1288
+ # else:
1289
+ # assert (
1290
+ # channels % num_head_channels == 0
1291
+ # ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
1292
+ # self.num_heads = channels // num_head_channels
1293
+ # self.use_checkpoint = use_checkpoint
1294
+ # self.norm = normalization(channels)
1295
+ # self.qkv = conv_nd(1, channels, channels * 3, 1)
1296
+ # if use_new_attention_order:
1297
+ # # split qkv before split heads
1298
+ # self.attention = QKVAttention(self.num_heads)
1299
+ # else:
1300
+ # # split heads before split qkv
1301
+ # self.attention = QKVAttentionLegacy(self.num_heads)
1302
+
1303
+ # self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
1304
+
1305
+ # def forward(self, x):
1306
+ # return checkpoint(self._forward, (x,), self.parameters(), True)
1307
+
1308
+ # def _forward(self, x):
1309
+ # b, c, *spatial = x.shape
1310
+ # x = x.reshape(b, c, -1)
1311
+ # qkv = self.qkv(self.norm(x))
1312
+ # h = self.attention(qkv)
1313
+ # h = self.proj_out(h)
1314
+ # return (x + h).reshape(b, c, *spatial)
1315
+
1316
+
1317
+ # def count_flops_attn(model, _x, y):
1318
+ # """
1319
+ # A counter for the `thop` package to count the operations in an
1320
+ # attention operation.
1321
+ # Meant to be used like:
1322
+ # macs, params = thop.profile(
1323
+ # model,
1324
+ # inputs=(inputs, timestamps),
1325
+ # custom_ops={QKVAttention: QKVAttention.count_flops},
1326
+ # )
1327
+ # """
1328
+ # b, c, *spatial = y[0].shape
1329
+ # num_spatial = int(np.prod(spatial))
1330
+ # # We perform two matmuls with the same number of ops.
1331
+ # # The first computes the weight matrix, the second computes
1332
+ # # the combination of the value vectors.
1333
+ # matmul_ops = 2 * b * (num_spatial ** 2) * c
1334
+ # model.total_ops += th.DoubleTensor([matmul_ops])
1335
+
1336
+
1337
+ # class QKVAttentionLegacy(nn.Module):
1338
+ # """
1339
+ # A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
1340
+ # """
1341
+
1342
+ # def __init__(self, n_heads):
1343
+ # super().__init__()
1344
+ # self.n_heads = n_heads
1345
+
1346
+ # def forward(self, qkv):
1347
+ # """
1348
+ # Apply QKV attention.
1349
+
1350
+ # :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
1351
+ # :return: an [N x (H * C) x T] tensor after attention.
1352
+ # """
1353
+ # bs, width, length = qkv.shape
1354
+ # assert width % (3 * self.n_heads) == 0
1355
+ # ch = width // (3 * self.n_heads)
1356
+ # q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
1357
+ # scale = 1 / math.sqrt(math.sqrt(ch))
1358
+ # weight = th.einsum(
1359
+ # "bct,bcs->bts", q * scale, k * scale
1360
+ # ) # More stable with f16 than dividing afterwards
1361
+ # weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
1362
+ # a = th.einsum("bts,bcs->bct", weight, v)
1363
+ # return a.reshape(bs, -1, length)
1364
+
1365
+ # @staticmethod
1366
+ # def count_flops(model, _x, y):
1367
+ # return count_flops_attn(model, _x, y)
1368
+
1369
+
1370
+ # class QKVAttention(nn.Module):
1371
+ # """
1372
+ # A module which performs QKV attention and splits in a different order.
1373
+ # """
1374
+
1375
+ # def __init__(self, n_heads):
1376
+ # super().__init__()
1377
+ # self.n_heads = n_heads
1378
+
1379
+ # def forward(self, qkv):
1380
+ # """
1381
+ # Apply QKV attention.
1382
+
1383
+ # :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
1384
+ # :return: an [N x (H * C) x T] tensor after attention.
1385
+ # """
1386
+ # bs, width, length = qkv.shape
1387
+ # assert width % (3 * self.n_heads) == 0
1388
+ # ch = width // (3 * self.n_heads)
1389
+ # q, k, v = qkv.chunk(3, dim=1)
1390
+ # scale = 1 / math.sqrt(math.sqrt(ch))
1391
+ # weight = th.einsum(
1392
+ # "bct,bcs->bts",
1393
+ # (q * scale).view(bs * self.n_heads, ch, length),
1394
+ # (k * scale).view(bs * self.n_heads, ch, length),
1395
+ # ) # More stable with f16 than dividing afterwards
1396
+ # weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
1397
+ # a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
1398
+ # return a.reshape(bs, -1, length)
1399
+
1400
+ # @staticmethod
1401
+ # def count_flops(model, _x, y):
1402
+ # return count_flops_attn(model, _x, y)
1403
+
1404
+
1405
+ # class UNetModel(nn.Module):
1406
+ # """
1407
+ # The full UNet model with attention and timestep embedding.
1408
+
1409
+ # :param in_channels: channels in the input Tensor.
1410
+ # :param model_channels: base channel count for the model.
1411
+ # :param out_channels: channels in the output Tensor.
1412
+ # :param num_res_blocks: number of residual blocks per downsample.
1413
+ # :param attention_resolutions: a collection of downsample rates at which
1414
+ # attention will take place. May be a set, list, or tuple.
1415
+ # For example, if this contains 4, then at 4x downsampling, attention
1416
+ # will be used.
1417
+ # :param dropout: the dropout probability.
1418
+ # :param channel_mult: channel multiplier for each level of the UNet.
1419
+ # :param conv_resample: if True, use learned convolutions for upsampling and
1420
+ # downsampling.
1421
+ # :param dims: determines if the signal is 1D, 2D, or 3D.
1422
+ # :param num_classes: if specified (as an int), then this model will be
1423
+ # class-conditional with `num_classes` classes.
1424
+ # :param use_checkpoint: use gradient checkpointing to reduce memory usage.
1425
+ # :param num_heads: the number of attention heads in each attention layer.
1426
+ # :param num_heads_channels: if specified, ignore num_heads and instead use
1427
+ # a fixed channel width per attention head.
1428
+ # :param num_heads_upsample: works with num_heads to set a different number
1429
+ # of heads for upsampling. Deprecated.
1430
+ # :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
1431
+ # :param resblock_updown: use residual blocks for up/downsampling.
1432
+ # :param use_new_attention_order: use a different attention pattern for potentially
1433
+ # increased efficiency.
1434
+ # """
1435
+
1436
+ # def __init__(
1437
+ # self,
1438
+ # image_size,
1439
+ # in_channels,
1440
+ # model_channels,
1441
+ # out_channels,
1442
+ # num_res_blocks,
1443
+ # attention_resolutions,
1444
+ # dropout=0,
1445
+ # channel_mult=(1, 2, 4, 8),
1446
+ # conv_resample=True,
1447
+ # dims=2,
1448
+ # num_classes=None,
1449
+ # use_checkpoint=False,
1450
+ # use_fp16=False,
1451
+ # num_heads=1,
1452
+ # num_head_channels=-1,
1453
+ # num_heads_upsample=-1,
1454
+ # use_scale_shift_norm=False,
1455
+ # resblock_updown=False,
1456
+ # use_new_attention_order=False,
1457
+ # ):
1458
+ # super().__init__()
1459
+
1460
+ # if num_heads_upsample == -1:
1461
+ # num_heads_upsample = num_heads
1462
+
1463
+ # self.image_size = image_size
1464
+ # self.in_channels = in_channels
1465
+ # self.model_channels = model_channels
1466
+ # self.out_channels = out_channels
1467
+ # self.num_res_blocks = num_res_blocks
1468
+ # self.attention_resolutions = attention_resolutions
1469
+ # self.dropout = dropout
1470
+ # self.channel_mult = channel_mult
1471
+ # self.conv_resample = conv_resample
1472
+ # self.num_classes = num_classes
1473
+ # self.use_checkpoint = use_checkpoint
1474
+ # self.dtype = th.float16 if use_fp16 else th.float32
1475
+ # self.num_heads = num_heads
1476
+ # self.num_head_channels = num_head_channels
1477
+ # self.num_heads_upsample = num_heads_upsample
1478
+
1479
+ # time_embed_dim = model_channels * 4
1480
+ # self.time_embed = nn.Sequential(
1481
+ # linear(model_channels, time_embed_dim),
1482
+ # nn.SiLU(),
1483
+ # linear(time_embed_dim, time_embed_dim),
1484
+ # )
1485
+
1486
+ # if self.num_classes is not None:
1487
+ # self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1488
+
1489
+ # ch = input_ch = int(channel_mult[0] * model_channels)
1490
+ # self.input_blocks = nn.ModuleList(
1491
+ # [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
1492
+ # )
1493
+ # self._feature_size = ch
1494
+ # input_block_chans = [ch]
1495
+ # ds = 1
1496
+ # for level, mult in enumerate(channel_mult):
1497
+ # for _ in range(num_res_blocks):
1498
+ # layers = [
1499
+ # ResBlock(
1500
+ # ch,
1501
+ # time_embed_dim,
1502
+ # dropout,
1503
+ # out_channels=int(mult * model_channels),
1504
+ # dims=dims,
1505
+ # use_checkpoint=use_checkpoint,
1506
+ # use_scale_shift_norm=use_scale_shift_norm,
1507
+ # )
1508
+ # ]
1509
+ # ch = int(mult * model_channels)
1510
+ # if ds in attention_resolutions:
1511
+ # layers.append(
1512
+ # AttentionBlock(
1513
+ # ch,
1514
+ # use_checkpoint=use_checkpoint,
1515
+ # num_heads=num_heads,
1516
+ # num_head_channels=num_head_channels,
1517
+ # use_new_attention_order=use_new_attention_order,
1518
+ # )
1519
+ # )
1520
+ # self.input_blocks.append(TimestepEmbedSequential(*layers))
1521
+ # self._feature_size += ch
1522
+ # input_block_chans.append(ch)
1523
+ # if level != len(channel_mult) - 1:
1524
+ # out_ch = ch
1525
+ # self.input_blocks.append(
1526
+ # TimestepEmbedSequential(
1527
+ # ResBlock(
1528
+ # ch,
1529
+ # time_embed_dim,
1530
+ # dropout,
1531
+ # out_channels=out_ch,
1532
+ # dims=dims,
1533
+ # use_checkpoint=use_checkpoint,
1534
+ # use_scale_shift_norm=use_scale_shift_norm,
1535
+ # down=True,
1536
+ # )
1537
+ # if resblock_updown
1538
+ # else Downsample(
1539
+ # ch, conv_resample, dims=dims, out_channels=out_ch
1540
+ # )
1541
+ # )
1542
+ # )
1543
+ # ch = out_ch
1544
+ # input_block_chans.append(ch)
1545
+ # ds *= 2
1546
+ # self._feature_size += ch
1547
+
1548
+ # self.middle_block = TimestepEmbedSequential(
1549
+ # ResBlock(
1550
+ # ch,
1551
+ # time_embed_dim,
1552
+ # dropout,
1553
+ # dims=dims,
1554
+ # use_checkpoint=use_checkpoint,
1555
+ # use_scale_shift_norm=use_scale_shift_norm,
1556
+ # ),
1557
+ # AttentionBlock(
1558
+ # ch,
1559
+ # use_checkpoint=use_checkpoint,
1560
+ # num_heads=num_heads,
1561
+ # num_head_channels=num_head_channels,
1562
+ # use_new_attention_order=use_new_attention_order,
1563
+ # ),
1564
+ # ResBlock(
1565
+ # ch,
1566
+ # time_embed_dim,
1567
+ # dropout,
1568
+ # dims=dims,
1569
+ # use_checkpoint=use_checkpoint,
1570
+ # use_scale_shift_norm=use_scale_shift_norm,
1571
+ # ),
1572
+ # )
1573
+ # self._feature_size += ch
1574
+
1575
+ # self.output_blocks = nn.ModuleList([])
1576
+ # for level, mult in list(enumerate(channel_mult))[::-1]:
1577
+ # for i in range(num_res_blocks + 1):
1578
+ # ich = input_block_chans.pop()
1579
+ # layers = [
1580
+ # ResBlock(
1581
+ # ch + ich,
1582
+ # time_embed_dim,
1583
+ # dropout,
1584
+ # out_channels=int(model_channels * mult),
1585
+ # dims=dims,
1586
+ # use_checkpoint=use_checkpoint,
1587
+ # use_scale_shift_norm=use_scale_shift_norm,
1588
+ # )
1589
+ # ]
1590
+ # ch = int(model_channels * mult)
1591
+ # if ds in attention_resolutions:
1592
+ # layers.append(
1593
+ # AttentionBlock(
1594
+ # ch,
1595
+ # use_checkpoint=use_checkpoint,
1596
+ # num_heads=num_heads_upsample,
1597
+ # num_head_channels=num_head_channels,
1598
+ # use_new_attention_order=use_new_attention_order,
1599
+ # )
1600
+ # )
1601
+ # if level and i == num_res_blocks:
1602
+ # out_ch = ch
1603
+ # layers.append(
1604
+ # ResBlock(
1605
+ # ch,
1606
+ # time_embed_dim,
1607
+ # dropout,
1608
+ # out_channels=out_ch,
1609
+ # dims=dims,
1610
+ # use_checkpoint=use_checkpoint,
1611
+ # use_scale_shift_norm=use_scale_shift_norm,
1612
+ # up=True,
1613
+ # )
1614
+ # if resblock_updown
1615
+ # else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1616
+ # )
1617
+ # ds //= 2
1618
+ # self.output_blocks.append(TimestepEmbedSequential(*layers))
1619
+ # self._feature_size += ch
1620
+
1621
+ # self.out = nn.Sequential(
1622
+ # normalization(ch),
1623
+ # nn.SiLU(),
1624
+ # zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
1625
+ # )
1626
+
1627
+ # def convert_to_fp16(self):
1628
+ # """
1629
+ # Convert the torso of the model to float16.
1630
+ # """
1631
+ # self.input_blocks.apply(convert_module_to_f16)
1632
+ # self.middle_block.apply(convert_module_to_f16)
1633
+ # self.output_blocks.apply(convert_module_to_f16)
1634
+
1635
+ # def convert_to_fp32(self):
1636
+ # """
1637
+ # Convert the torso of the model to float32.
1638
+ # """
1639
+ # self.input_blocks.apply(convert_module_to_f32)
1640
+ # self.middle_block.apply(convert_module_to_f32)
1641
+ # self.output_blocks.apply(convert_module_to_f32)
1642
+
1643
+ # def forward(self, x, timesteps,y=None, **kwargs):
1644
+ # """
1645
+ # Apply the model to an input batch.
1646
+
1647
+ # :param x: an [N x C x ...] Tensor of inputs.
1648
+ # :param timesteps: a 1-D batch of timesteps.
1649
+ # :param y: an [N] Tensor of labels, if class-conditional.
1650
+ # :return: an [N x C x ...] Tensor of outputs.
1651
+ # """
1652
+ # # y=None
1653
+ # # assert (y is not None) == (
1654
+ # # self.num_classes is not None
1655
+ # # ), "must specify y if and only if the model is class-conditional"
1656
+
1657
+ # hs = []
1658
+ # emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1659
+
1660
+ # # if self.num_classes is not None:
1661
+ # # assert y.shape == (x.shape[0],)
1662
+ # # emb = emb + self.label_emb(y)
1663
+
1664
+ # h = x.type(self.dtype)
1665
+ # for module in self.input_blocks:
1666
+ # h = module(h, emb)
1667
+ # hs.append(h)
1668
+ # h = self.middle_block(h, emb)
1669
+ # for module in self.output_blocks:
1670
+ # h = th.cat([h, hs.pop()], dim=1)
1671
+ # h = module(h, emb)
1672
+ # h = h.type(x.dtype)
1673
+ # return self.out(h)
1674
+
1675
+
1676
+ # class SuperResModel(UNetModel):
1677
+ # """
1678
+ # A UNetModel that performs super-resolution.
1679
+
1680
+ # Expects an extra kwarg `low_res` to condition on a low-resolution image.
1681
+ # """
1682
+
1683
+ # def __init__(self, image_size, in_channels, *args, **kwargs):
1684
+ # super().__init__(image_size, in_channels * 2, *args, **kwargs)
1685
+
1686
+ # def forward(self, x, timesteps, low_res=None, **kwargs):
1687
+ # _, _, new_height, new_width = x.shape
1688
+ # # upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
1689
+ # # for _ in kwargs:
1690
+ # # print(_)
1691
+ # # stop
1692
+ # # upsampled = kwargs["SR"]
1693
+ # # x = th.cat([x, upsampled], dim=1)
1694
+ # return super().forward(x, timesteps, **kwargs)
1695
+
1696
+
1697
+ # class EncoderUNetModel(nn.Module):
1698
+ # """
1699
+ # The half UNet model with attention and timestep embedding.
1700
+
1701
+ # For usage, see UNet.
1702
+ # """
1703
+
1704
+ # def __init__(
1705
+ # self,
1706
+ # image_size,
1707
+ # in_channels,
1708
+ # model_channels,
1709
+ # out_channels,
1710
+ # num_res_blocks,
1711
+ # attention_resolutions,
1712
+ # dropout=0,
1713
+ # channel_mult=(1, 2, 4, 8),
1714
+ # conv_resample=True,
1715
+ # dims=2,
1716
+ # use_checkpoint=False,
1717
+ # use_fp16=False,
1718
+ # num_heads=1,
1719
+ # num_head_channels=-1,
1720
+ # num_heads_upsample=-1,
1721
+ # use_scale_shift_norm=False,
1722
+ # resblock_updown=False,
1723
+ # use_new_attention_order=False,
1724
+ # pool="adaptive",
1725
+ # ):
1726
+ # super().__init__()
1727
+
1728
+ # if num_heads_upsample == -1:
1729
+ # num_heads_upsample = num_heads
1730
+
1731
+ # self.in_channels = in_channels
1732
+ # self.model_channels = model_channels
1733
+ # self.out_channels = out_channels
1734
+ # self.num_res_blocks = num_res_blocks
1735
+ # self.attention_resolutions = attention_resolutions
1736
+ # self.dropout = dropout
1737
+ # self.channel_mult = channel_mult
1738
+ # self.conv_resample = conv_resample
1739
+ # self.use_checkpoint = use_checkpoint
1740
+ # self.dtype = th.float16 if use_fp16 else th.float32
1741
+ # self.num_heads = num_heads
1742
+ # self.num_head_channels = num_head_channels
1743
+ # self.num_heads_upsample = num_heads_upsample
1744
+
1745
+ # time_embed_dim = model_channels * 4
1746
+ # self.time_embed = nn.Sequential(
1747
+ # linear(model_channels, time_embed_dim),
1748
+ # nn.SiLU(),
1749
+ # linear(time_embed_dim, time_embed_dim),
1750
+ # )
1751
+
1752
+ # ch = int(channel_mult[0] * model_channels)
1753
+ # self.input_blocks = nn.ModuleList(
1754
+ # [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
1755
+ # )
1756
+ # self._feature_size = ch
1757
+ # input_block_chans = [ch]
1758
+ # ds = 1
1759
+ # for level, mult in enumerate(channel_mult):
1760
+ # for _ in range(num_res_blocks):
1761
+ # layers = [
1762
+ # ResBlock(
1763
+ # ch,
1764
+ # time_embed_dim,
1765
+ # dropout,
1766
+ # out_channels=int(mult * model_channels),
1767
+ # dims=dims,
1768
+ # use_checkpoint=use_checkpoint,
1769
+ # use_scale_shift_norm=use_scale_shift_norm,
1770
+ # )
1771
+ # ]
1772
+ # ch = int(mult * model_channels)
1773
+ # if ds in attention_resolutions:
1774
+ # layers.append(
1775
+ # AttentionBlock(
1776
+ # ch,
1777
+ # use_checkpoint=use_checkpoint,
1778
+ # num_heads=num_heads,
1779
+ # num_head_channels=num_head_channels,
1780
+ # use_new_attention_order=use_new_attention_order,
1781
+ # )
1782
+ # )
1783
+ # self.input_blocks.append(TimestepEmbedSequential(*layers))
1784
+ # self._feature_size += ch
1785
+ # input_block_chans.append(ch)
1786
+ # if level != len(channel_mult) - 1:
1787
+ # out_ch = ch
1788
+ # self.input_blocks.append(
1789
+ # TimestepEmbedSequential(
1790
+ # ResBlock(
1791
+ # ch,
1792
+ # time_embed_dim,
1793
+ # dropout,
1794
+ # out_channels=out_ch,
1795
+ # dims=dims,
1796
+ # use_checkpoint=use_checkpoint,
1797
+ # use_scale_shift_norm=use_scale_shift_norm,
1798
+ # down=True,
1799
+ # )
1800
+ # if resblock_updown
1801
+ # else Downsample(
1802
+ # ch, conv_resample, dims=dims, out_channels=out_ch
1803
+ # )
1804
+ # )
1805
+ # )
1806
+ # ch = out_ch
1807
+ # input_block_chans.append(ch)
1808
+ # ds *= 2
1809
+ # self._feature_size += ch
1810
+
1811
+ # self.middle_block = TimestepEmbedSequential(
1812
+ # ResBlock(
1813
+ # ch,
1814
+ # time_embed_dim,
1815
+ # dropout,
1816
+ # dims=dims,
1817
+ # use_checkpoint=use_checkpoint,
1818
+ # use_scale_shift_norm=use_scale_shift_norm,
1819
+ # ),
1820
+ # AttentionBlock(
1821
+ # ch,
1822
+ # use_checkpoint=use_checkpoint,
1823
+ # num_heads=num_heads,
1824
+ # num_head_channels=num_head_channels,
1825
+ # use_new_attention_order=use_new_attention_order,
1826
+ # ),
1827
+ # ResBlock(
1828
+ # ch,
1829
+ # time_embed_dim,
1830
+ # dropout,
1831
+ # dims=dims,
1832
+ # use_checkpoint=use_checkpoint,
1833
+ # use_scale_shift_norm=use_scale_shift_norm,
1834
+ # ),
1835
+ # )
1836
+ # self._feature_size += ch
1837
+ # self.pool = pool
1838
+ # if pool == "adaptive":
1839
+ # self.out = nn.Sequential(
1840
+ # normalization(ch),
1841
+ # nn.SiLU(),
1842
+ # nn.AdaptiveAvgPool2d((1, 1)),
1843
+ # zero_module(conv_nd(dims, ch, out_channels, 1)),
1844
+ # nn.Flatten(),
1845
+ # )
1846
+ # elif pool == "attention":
1847
+ # assert num_head_channels != -1
1848
+ # self.out = nn.Sequential(
1849
+ # normalization(ch),
1850
+ # nn.SiLU(),
1851
+ # AttentionPool2d(
1852
+ # (image_size // ds), ch, num_head_channels, out_channels
1853
+ # ),
1854
+ # )
1855
+ # elif pool == "spatial":
1856
+ # self.out = nn.Sequential(
1857
+ # nn.Linear(self._feature_size, 2048),
1858
+ # nn.ReLU(),
1859
+ # nn.Linear(2048, self.out_channels),
1860
+ # )
1861
+ # elif pool == "spatial_v2":
1862
+ # self.out = nn.Sequential(
1863
+ # nn.Linear(self._feature_size, 2048),
1864
+ # normalization(2048),
1865
+ # nn.SiLU(),
1866
+ # nn.Linear(2048, self.out_channels),
1867
+ # )
1868
+ # else:
1869
+ # raise NotImplementedError(f"Unexpected {pool} pooling")
1870
+
1871
+ # def convert_to_fp16(self):
1872
+ # """
1873
+ # Convert the torso of the model to float16.
1874
+ # """
1875
+ # self.input_blocks.apply(convert_module_to_f16)
1876
+ # self.middle_block.apply(convert_module_to_f16)
1877
+
1878
+ # def convert_to_fp32(self):
1879
+ # """
1880
+ # Convert the torso of the model to float32.
1881
+ # """
1882
+ # self.input_blocks.apply(convert_module_to_f32)
1883
+ # self.middle_block.apply(convert_module_to_f32)
1884
+
1885
+ # def forward(self, x, timesteps):
1886
+ # """
1887
+ # Apply the model to an input batch.
1888
+
1889
+ # :param x: an [N x C x ...] Tensor of inputs.
1890
+ # :param timesteps: a 1-D batch of timesteps.goo
1891
+ # :return: an [N x K] Tensor of outputs.
1892
+ # """
1893
+ # emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1894
+
1895
+ # results = []
1896
+ # h = x.type(self.dtype)
1897
+ # for module in self.input_blocks:
1898
+ # h = module(h, emb)
1899
+ # if self.pool.startswith("spatial"):
1900
+ # results.append(h.type(x.dtype).mean(dim=(2, 3)))
1901
+ # h = self.middle_block(h, emb)
1902
+ # if self.pool.startswith("spatial"):
1903
+ # results.append(h.type(x.dtype).mean(dim=(2, 3)))
1904
+ # h = th.cat(results, axis=-1)
1905
+ # return self.out(h)
1906
+ # else:
1907
+ # h = h.type(x.dtype)
1908
+ # return self.out(h)
guided_diffusion/unet2.py ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
11
+ from .nn import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+
21
+
22
+ # from models.submodules import *
23
+ import torchvision.models
24
+
25
class VGG19(nn.Module):
    """VGG-19 feature extractor for perceptual-style losses.

    Collects the pre-ReLU activations of conv1_2, conv2_2 and conv3_3
    (indices 2, 7 and 14 of ``torchvision``'s ``vgg19().features``).
    """

    def __init__(self):
        super(VGG19, self).__init__()
        # Indices within vgg19.features whose outputs are returned.
        self.feature_list = [2, 7, 14]
        vgg19 = torchvision.models.vgg19(pretrained=True)
        # Keep only the layers up to (and including) the deepest tap point.
        self.model = th.nn.Sequential(
            *list(vgg19.features.children())[: self.feature_list[-1] + 1]
        )

    def forward(self, x, emb):
        # `emb` is accepted for call-signature compatibility but unused here.
        features = []
        for idx, layer in enumerate(self.model):
            x = layer(x)
            if idx in self.feature_list:
                features.append(x)
        return features
47
+
48
+
49
+
50
class AttentionPool2d(nn.Module):
    """
    Attention-based 2-D pooling.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One extra position for the mean token prepended in forward().
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        flat = x.reshape(b, c, -1)  # N x C x (HW)
        # Prepend the spatial mean as a "global" token: N x C x (HW+1).
        flat = th.cat([flat.mean(dim=-1, keepdim=True), flat], dim=-1)
        flat = flat + self.positional_embedding[None, :, :].to(flat.dtype)
        out = self.c_proj(self.attention(self.qkv_proj(flat)))
        # Return only the pooled (mean-token) position.
        return out[:, :, 0]
80
+
81
+
82
class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() also consumes a timestep
    embedding as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x`, conditioned on the `emb` timestep embeddings.
        """
92
+
93
+
94
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that forwards the timestep embedding `emb` and
    the semantic code `zsem` to children that accept them (TimestepBlock
    instances); every other child receives only `x`.
    """

    def forward(self, x, emb, zsem):
        for child in self:
            if isinstance(child, TimestepBlock):
                x = child(x, emb, zsem)
            else:
                x = child(x)
        return x
107
+
108
+
109
class TimestepEmbedSequential1(nn.Sequential, TimestepBlock):
    """
    Like TimestepEmbedSequential, but forwards only the timestep embedding
    `emb` (no semantic code) to TimestepBlock children.
    """

    def forward(self, x, emb):
        for child in self:
            if isinstance(child, TimestepBlock):
                x = child(x, emb)
            else:
                x = child(x)
        return x
122
+
123
class Upsample(nn.Module):
    """
    A 2x nearest-neighbour upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions only.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep the leading (depth) dimension; double only the inner two.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
153
+
154
+
155
class Downsample(nn.Module):
    """
    A 2x downsampling layer using either a strided convolution or average
    pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions only.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3-D signals the leading (depth) dimension is left untouched.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
183
+
184
+
185
class ResBlock(TimestepBlock):
    """
    A residual block conditioned on a timestep embedding and, additionally,
    on a 512-dim semantic code (`zsem`) via channel-wise multiplicative
    gating.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: use FiLM-style (scale, shift) conditioning.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )
        # Projects the 512-dim semantic code to one gate value per channel.
        self.sem_layers = nn.Sequential(
            nn.SiLU(),
            linear(512, self.out_channels),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb, sem):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding and
        a semantic code.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param sem: an [N x 512] semantic code tensor (TODO confirm shape
            against caller).
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb, sem), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb, zsem):
        if self.updown:
            # Apply norm+SiLU, resample, then convolve; resample the skip too.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_conv(self.h_upd(in_rest(x)))
            x = self.x_upd(x)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        sem_out = self.sem_layers(zsem).type(h.dtype)
        # Broadcast both conditioners over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        while len(sem_out.shape) < len(h.shape):
            sem_out = sem_out[..., None]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            # Channel-wise semantic gating, applied after FiLM conditioning.
            h = h * sem_out
            h = out_rest(h)
        else:
            # NOTE(review): on this path `sem_out` is computed but never
            # applied — the semantic gate only takes effect when
            # use_scale_shift_norm is True. Confirm this is intentional.
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
312
+
313
class ResBlock1(TimestepBlock):
    """
    A residual block conditioned on a timestep embedding only — the variant
    of ResBlock without the semantic-code gate.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: use FiLM-style (scale, shift) conditioning.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # Apply norm+SiLU, resample, then convolve; resample the skip too.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_conv(self.h_upd(in_rest(x)))
            x = self.x_upd(x)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
436
+
437
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to the N-d case.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # New order splits qkv before heads; legacy splits heads before qkv.
        attn_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attn_cls(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE: checkpointing is unconditionally enabled here (`True`), not
        # gated on self.use_checkpoint — matches the original implementation.
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        flat = x.reshape(b, c, -1)
        h = self.proj_out(self.attention(self.qkv(self.norm(flat))))
        return (flat + h).reshape(b, c, *spatial)
484
+
485
+
486
def count_flops_attn(model, _x, y):
    """
    FLOP counter for the `thop` package, covering one attention operation.

    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )

    :param model: the profiled module; its `total_ops` is incremented.
    :param _x: unused input tuple (required by the thop hook signature).
    :param y: output tuple; y[0] has shape [B x C x ...spatial].
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Two matmuls of identical cost: the QK^T weight matrix, then the
    # weighted combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
504
+
505
+
506
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention with the legacy head-then-qkv split ordering (matches the
    original QKVAttention's input/output head shaping).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Split heads first, then q/k/v along the channel dimension.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by ch^-1/4 each: more stable in f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
537
+
538
+
539
class QKVAttention(nn.Module):
    """
    QKV attention that splits q/k/v before splitting heads (the "new"
    ordering).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Split q/k/v first, then fold the heads into the batch dimension.
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k by ch^-1/4 each: more stable in f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum(
            "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
        )
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
572
+
573
+
574
+ class UNetModel(nn.Module):
575
+ """
576
+ The full UNet model with attention and timestep embedding.
577
+ :param in_channels: channels in the input Tensor.
578
+ :param model_channels: base channel count for the model.
579
+ :param out_channels: channels in the output Tensor.
580
+ :param num_res_blocks: number of residual blocks per downsample.
581
+ :param attention_resolutions: a collection of downsample rates at which
582
+ attention will take place. May be a set, list, or tuple.
583
+ For example, if this contains 4, then at 4x downsampling, attention
584
+ will be used.
585
+ :param dropout: the dropout probability.
586
+ :param channel_mult: channel multiplier for each level of the UNet.
587
+ :param conv_resample: if True, use learned convolutions for upsampling and
588
+ downsampling.
589
+ :param dims: determines if the signal is 1D, 2D, or 3D.
590
+ :param num_classes: if specified (as an int), then this model will be
591
+ class-conditional with `num_classes` classes.
592
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
593
+ :param num_heads: the number of attention heads in each attention layer.
594
+ :param num_heads_channels: if specified, ignore num_heads and instead use
595
+ a fixed channel width per attention head.
596
+ :param num_heads_upsample: works with num_heads to set a different number
597
+ of heads for upsampling. Deprecated.
598
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
599
+ :param resblock_updown: use residual blocks for up/downsampling.
600
+ :param use_new_attention_order: use a different attention pattern for potentially
601
+ increased efficiency.
602
+ """
603
+
604
+ def __init__(
605
+ self,
606
+ image_size,
607
+ in_channels,
608
+ model_channels,
609
+ out_channels,
610
+ num_res_blocks,
611
+ attention_resolutions,
612
+ dropout=0,
613
+ channel_mult=(1, 2, 4, 8),
614
+ conv_resample=True,
615
+ dims=2,
616
+ num_classes=None,
617
+ use_checkpoint=False,
618
+ use_fp16=False,
619
+ num_heads=1,
620
+ num_head_channels=-1,
621
+ num_heads_upsample=-1,
622
+ use_scale_shift_norm=False,
623
+ resblock_updown=False,
624
+ use_new_attention_order=False,
625
+ ):
626
+ super().__init__()
627
+
628
+ if num_heads_upsample == -1:
629
+ num_heads_upsample = num_heads
630
+ in_channels=6
631
+ self.image_size = image_size
632
+ self.in_channels = in_channels
633
+ self.model_channels = model_channels
634
+ self.out_channels = out_channels
635
+ self.num_res_blocks = num_res_blocks
636
+ self.attention_resolutions = attention_resolutions
637
+ self.dropout = dropout
638
+ self.channel_mult = channel_mult
639
+ self.conv_resample = conv_resample
640
+ self.num_classes = num_classes
641
+ self.use_checkpoint = use_checkpoint
642
+ self.dtype = th.float16 if use_fp16 else th.float32
643
+ self.num_heads = num_heads
644
+ self.num_head_channels = num_head_channels
645
+ self.num_heads_upsample = num_heads_upsample
646
+ time_embed_dim = model_channels * 4
647
+ self.time_embed = nn.Sequential(
648
+ linear(model_channels, time_embed_dim),
649
+ nn.SiLU(),
650
+ linear(time_embed_dim, time_embed_dim),
651
+ )
652
+ self.input_encoder = EncoderUNetModel(
653
+ image_size,
654
+ 3,
655
+ model_channels,
656
+ out_channels,
657
+ num_res_blocks,
658
+ attention_resolutions,
659
+ dropout=0,
660
+ channel_mult=(1, 2, 4, 8),
661
+ conv_resample=True,
662
+ dims=2,
663
+ use_checkpoint=False,
664
+ use_fp16=False,
665
+ num_heads=1,
666
+ num_head_channels=-1,
667
+ num_heads_upsample=-1,
668
+ use_scale_shift_norm=False,
669
+ resblock_updown=False,
670
+ use_new_attention_order=False,
671
+ pool="spatial",
672
+ )
673
+ if self.num_classes is not None:
674
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
675
+
676
+ ch = input_ch = int(channel_mult[0] * model_channels)
677
+ # print(channel_mult,in_channels)
678
+ # in_channels=6
679
+ # print(in_channels)
680
+ self.input_transform_1 = conv_nd(2, 6, 3, 3, padding=1)
681
+ self.input_blocks = nn.ModuleList(
682
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
683
+ )
684
+ self._feature_size = ch
685
+ input_block_chans = [ch]
686
+ ds = 1
687
+ blah=0
688
+ for level, mult in enumerate(channel_mult):
689
+ for _ in range(num_res_blocks):
690
+ # print(level,mult,int(mult * model_channels))
691
+
692
+ layers = [
693
+ ResBlock(
694
+ ch,
695
+ time_embed_dim,
696
+ dropout,
697
+ out_channels=int(mult * model_channels),
698
+ dims=dims,
699
+ use_checkpoint=use_checkpoint,
700
+ use_scale_shift_norm=use_scale_shift_norm,
701
+ )
702
+ ]
703
+ ch = int(mult * model_channels)
704
+ if ds in attention_resolutions:
705
+ layers.append(
706
+ AttentionBlock(
707
+ ch,
708
+ use_checkpoint=use_checkpoint,
709
+ num_heads=num_heads,
710
+ num_head_channels=num_head_channels,
711
+ use_new_attention_order=use_new_attention_order,
712
+ )
713
+ )
714
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
715
+ self._feature_size += ch
716
+ input_block_chans.append(ch)
717
+ if level != len(channel_mult) - 1:
718
+ out_ch = ch
719
+ blah=blah+1
720
+ # if(blah==1):
721
+ # ch1=ch+64
722
+ # elif(blah==2):
723
+ # ch1=ch+128
724
+ # elif(blah==3):
725
+ # ch1=ch+256
726
+ # else:
727
+ # ch1=ch
728
+ ch1=ch
729
+ # print(resblock_updown)
730
+ self.input_blocks.append(
731
+ TimestepEmbedSequential(
732
+ ResBlock(
733
+ ch1,
734
+ time_embed_dim,
735
+ dropout,
736
+ out_channels=out_ch,
737
+ dims=dims,
738
+ use_checkpoint=use_checkpoint,
739
+ use_scale_shift_norm=use_scale_shift_norm,
740
+ down=True,
741
+ )
742
+ if resblock_updown
743
+ else Downsample(
744
+ ch, conv_resample, dims=dims, out_channels=out_ch
745
+ )
746
+ )
747
+ )
748
+ ch = out_ch
749
+ input_block_chans.append(ch)
750
+ ds *= 2
751
+ self._feature_size += ch
752
+ # print(input_block_chans)
753
+ self.middle_block = TimestepEmbedSequential(
754
+ ResBlock(
755
+ ch,
756
+ time_embed_dim,
757
+ dropout,
758
+ dims=dims,
759
+ use_checkpoint=use_checkpoint,
760
+ use_scale_shift_norm=use_scale_shift_norm,
761
+ ),
762
+ AttentionBlock(
763
+ ch,
764
+ use_checkpoint=use_checkpoint,
765
+ num_heads=num_heads,
766
+ num_head_channels=num_head_channels,
767
+ use_new_attention_order=use_new_attention_order,
768
+ ),
769
+ ResBlock(
770
+ ch,
771
+ time_embed_dim,
772
+ dropout,
773
+ dims=dims,
774
+ use_checkpoint=use_checkpoint,
775
+ use_scale_shift_norm=use_scale_shift_norm,
776
+ ),
777
+ )
778
+ self._feature_size += ch
779
+
780
+ self.output_blocks = nn.ModuleList([])
781
+ for level, mult in list(enumerate(channel_mult))[::-1]:
782
+ for i in range(num_res_blocks + 1):
783
+ ich = input_block_chans.pop()
784
+ layers = [
785
+ ResBlock(
786
+ ch + ich,
787
+ time_embed_dim,
788
+ dropout,
789
+ out_channels=int(model_channels * mult),
790
+ dims=dims,
791
+ use_checkpoint=use_checkpoint,
792
+ use_scale_shift_norm=use_scale_shift_norm,
793
+ )
794
+ ]
795
+ ch = int(model_channels * mult)
796
+ if ds in attention_resolutions:
797
+ layers.append(
798
+ AttentionBlock(
799
+ ch,
800
+ use_checkpoint=use_checkpoint,
801
+ num_heads=num_heads_upsample,
802
+ num_head_channels=num_head_channels,
803
+ use_new_attention_order=use_new_attention_order,
804
+ )
805
+ )
806
+ if level and i == num_res_blocks:
807
+ out_ch = ch
808
+ layers.append(
809
+ ResBlock(
810
+ ch,
811
+ time_embed_dim,
812
+ dropout,
813
+ out_channels=out_ch,
814
+ dims=dims,
815
+ use_checkpoint=use_checkpoint,
816
+ use_scale_shift_norm=use_scale_shift_norm,
817
+ up=True,
818
+ )
819
+ if resblock_updown
820
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
821
+ )
822
+ ds //= 2
823
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
824
+ self._feature_size += ch
825
+
826
+ self.vgg=VGG19()
827
+ self.conv_convert1 = ResBlock(
828
+ 320,
829
+ time_embed_dim,
830
+ dropout,
831
+ out_channels=256,
832
+ dims=dims,
833
+ use_checkpoint=use_checkpoint,
834
+ use_scale_shift_norm=use_scale_shift_norm,
835
+ )
836
+ self.conv_convert2 = ResBlock(
837
+ 384,
838
+ time_embed_dim,
839
+ dropout,
840
+ out_channels=256,
841
+ dims=dims,
842
+ use_checkpoint=use_checkpoint,
843
+ use_scale_shift_norm=use_scale_shift_norm,
844
+ )
845
+ self.conv_convert3 = ResBlock(
846
+ 768,
847
+ time_embed_dim,
848
+ dropout,
849
+ out_channels=512,
850
+ dims=dims,
851
+ use_checkpoint=use_checkpoint,
852
+ use_scale_shift_norm=use_scale_shift_norm,
853
+ )
854
+
855
+ self.out = nn.Sequential(
856
+ normalization(ch),
857
+ nn.SiLU(),
858
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
859
+ )
860
+ # print(input_ch,out_channels)
861
+ def convert_to_fp16(self):
862
+ """
863
+ Convert the torso of the model to float16.
864
+ """
865
+ self.vgg.apply(convert_module_to_f16)
866
+ self.input_blocks.apply(convert_module_to_f16)
867
+ self.middle_block.apply(convert_module_to_f16)
868
+ self.output_blocks.apply(convert_module_to_f16)
869
+ self.conv_convert1.apply(convert_module_to_f16)
870
+ self.conv_convert2.apply(convert_module_to_f16)
871
+ self.conv_convert3.apply(convert_module_to_f16)
872
+ self.input_transform_1.apply(convert_module_to_f16)
873
+ self.input_encoder.convert_to_fp16()
874
+
875
+
876
+ def convert_to_fp32(self):
877
+ """
878
+ Convert the torso of the model to float32.
879
+ """
880
+ self.vgg.apply(convert_module_to_f32)
881
+
882
+ self.input_blocks.apply(convert_module_to_f32)
883
+ self.middle_block.apply(convert_module_to_f32)
884
+ self.output_blocks.apply(convert_module_to_f32)
885
+
886
+ def forward(self, x, timesteps, low_res ,high_res, y=None,**kwargs):
887
+ """
888
+ Apply the model to an input batch.
889
+
890
+ :param x: an [N x C x ...] Tensor of inputs.
891
+ :param timesteps: a 1-D batch of timesteps.
892
+ :param y: an [N] Tensor of labels, if class-conditional.
893
+ :return: an [N x C x ...] Tensor of outputs.
894
+ """
895
+
896
+ hs = []
897
+ # x1 = th.cat([x,high_res],1).type(self.dtype)
898
+ # x1 = self.input_transform_1(x.type(self.dtype))
899
+ x1=x
900
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
901
+ input1=low_res
902
+ # vgg_feats = self.vgg(input1.type(self.dtype), emb)
903
+ # print(x.shape)
904
+ # print(emb.shape)
905
+ # vgg_feats=vgg_feats.type(self.dtype)
906
+ # print(vgg_feats[0].shape)
907
+ # print(emb.shape)
908
+ h = x1.type(self.dtype)
909
+ zsem= self.input_encoder(input1, timesteps)
910
+ for i , module in enumerate(self.input_blocks):
911
+ # print(i,module,h.shape)
912
+
913
+ # if(i==3):
914
+ # # print()
915
+ # h= th.cat([h,vgg_feats[0]],1)
916
+ # h = self.conv_convert1(h,emb)
917
+ # if(i==6):
918
+
919
+ # h= th.cat([h,vgg_feats[1]],1)
920
+ # h = self.conv_convert2(h,emb)
921
+
922
+ # elif(i==9):
923
+ # h= th.cat([h,vgg_feats[2]],1)
924
+ # h = self.conv_convert3(h,emb)
925
+ # print(h.shape)
926
+ # print(h.shape,emb.shape)
927
+ h = module(h, emb,zsem)
928
+
929
+ hs.append(h)
930
+ # print(h.shape)
931
+ h = self.middle_block(h, emb,zsem)
932
+ # stop
933
+ for module in self.output_blocks:
934
+ h = th.cat([h, hs.pop()], dim=1)
935
+ h = module(h, emb,zsem)
936
+ h = h.type(x.dtype)
937
+ out=self.out(h)
938
+ return out
939
+
940
+
941
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # Input channels double because the conditioning image is
        # concatenated with x before the trunk sees it.
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        # The conditioning image is carried in kwargs['SR']; the `low_res`
        # parameter itself is ignored. No interpolation is performed here —
        # input and output resolutions match in this despeckling setup.
        low_res = kwargs['SR']
        high_res = kwargs['SR']
        x = th.cat([x, low_res], dim=1)
        return super().forward(x, timesteps, low_res, high_res, **kwargs)
959
+
960
+
961
class EncoderUNetModel(nn.Module):
    """
    The half (encoder-only) UNet model with attention and timestep embedding.

    Used here as a conditioning encoder: with ``pool="spatial"`` the forward
    pass returns a pooled embedding of the input image; other pool modes act
    as a classifier-style head. For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # Activation dtype; kept in sync by convert_to_fp16/convert_to_fp32.
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential1(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock1(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential1(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between resolution levels (except after the last).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential1(
                        ResBlock1(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential1(
            ResBlock1(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock1(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, 512),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.out.apply(convert_module_to_f16)
        # Track the activation dtype so forward() casts inputs correctly.
        self.dtype = th.float16

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        # Also restore the output head (converted by convert_to_fp16).
        self.out.apply(convert_module_to_f32)
        self.dtype = th.float32

    def forward(self, x, timesteps):
        """
        Apply the encoder to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: with a "spatial" pool, an [N x K] embedding Tensor;
            otherwise the list of intermediate feature maps.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        # BUGFIX: was `x.type(th.cuda.HalfTensor)`, which crashed on CPU and
        # forced fp16 regardless of configuration. `self.dtype` is updated by
        # convert_to_fp16/convert_to_fp32, so the fp16 training path behaves
        # as before while fp32/CPU now works.
        h = x.type(self.dtype)

        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            if self.pool.startswith("spatial"):
                # Spatial pooling: mean over H, W of every stage's features.
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        hs.append(h)

        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            h = h.type(x.dtype)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return hs
scripts/sarddpm_test.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAR-DDPM Inference on real SAR images.
3
+ """
4
+
5
+ import argparse
6
+ import torch
7
+ import os
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch.nn.functional as F
12
+
13
+ from guided_diffusion import dist_util, logger
14
+ from guided_diffusion.image_datasets import load_data
15
+ from guided_diffusion.resample import create_named_schedule_sampler
16
+ from guided_diffusion.script_util import (
17
+ sr_model_and_diffusion_defaults,
18
+ sr_create_model_and_diffusion,
19
+ args_to_dict,
20
+ add_dict_to_argparser,
21
+ )
22
+ from guided_diffusion.train_util import TrainLoop
23
+ from torch.utils.data import DataLoader
24
+ from torch.optim import AdamW
25
+
26
+ from valdata import ValData, ValDataNew, ValDataNewReal
27
+ from skimage.metrics import peak_signal_noise_ratio as psnr
28
+ from skimage.metrics import structural_similarity as ssim
29
+
30
+
31
+
32
+ val_dir = 'path_to_validation_data/'
33
+ base_path = 'path_to_save_results/'
34
+ resume_checkpoint_clean = './weights/sar_ddpm.pt'
35
+
36
+
37
+
38
+
39
def main():
    """
    Despeckle each validation image by averaging N predictions obtained from
    cyclically shifted copies of the input (shift-and-average ensembling),
    then write the grayscale result to disk.
    """
    args = create_argparser().parse_args()

    print(args)

    model_clean, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )

    print(torch.device('cuda'))

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    val_data = DataLoader(
        ValDataNewReal(dataset_path=val_dir),
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    device0 = torch.device("cuda:0")
    model_clean.load_state_dict(
        torch.load(resume_checkpoint_clean, map_location="cuda:0")
    )
    model_clean.to(device0)

    params = list(model_clean.parameters())

    print('model clean device:')
    print(next(model_clean.parameters()).device)

    with torch.no_grad():
        number = 0

        for batch_id1, data_var in enumerate(val_data):
            number = number + 1
            clean_batch, model_kwargs1 = data_var

            single_img = model_kwargs1['SR'].to(dist_util.dev())

            count = 0
            t1, t2, max_r, max_c = single_img.size()

            # Number of shifted passes averaged per image (a 3x3 grid of
            # 100-pixel cyclic shifts for 256x256 inputs).
            N = 9

            val_inputv = single_img.clone()

            for row in range(0, max_r, 100):
                for col in range(0, max_c, 100):
                    # Cyclically shift the input by (row, col) pixels.
                    val_inputv[:, :, :row, :col] = single_img[:, :, max_r - row:, max_c - col:]
                    val_inputv[:, :, row:, col:] = single_img[:, :, :max_r - row, :max_c - col]
                    val_inputv[:, :, row:, :col] = single_img[:, :, :max_r - row, max_c - col:]
                    val_inputv[:, :, :row, col:] = single_img[:, :, max_r - row:, :max_c - col]

                    # Rebuild kwargs: 'SR' carries the shifted input, 'Index'
                    # holds the filename, everything else moves to the device.
                    model_kwargs = {}
                    for k, v in model_kwargs1.items():
                        if 'Index' in k:
                            img_name = v
                        elif 'SR' in k:
                            model_kwargs[k] = val_inputv.to(dist_util.dev())
                        else:
                            model_kwargs[k] = v.to(dist_util.dev())

                    sample = diffusion.p_sample_loop(
                        model_clean,
                        (clean_batch.shape[0], 3, 256, 256),
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )

                    if count == 0:
                        # First (unshifted) pass initialises the accumulator.
                        sample_new = (1.0 / N) * sample
                    else:
                        # Undo the cyclic shift while accumulating the average.
                        sample_new[:, :, max_r - row:, max_c - col:] += (1.0 / N) * sample[:, :, :row, :col]
                        sample_new[:, :, :max_r - row, :max_c - col] += (1.0 / N) * sample[:, :, row:, col:]
                        sample_new[:, :, :max_r - row, max_c - col:] += (1.0 / N) * sample[:, :, row:, :col]
                        sample_new[:, :, max_r - row:, :max_c - col] += (1.0 / N) * sample[:, :, :row, col:]

                    count += 1

            # [-1, 1] -> uint8, NCHW -> HWC, channel reverse, then grayscale.
            sample_new = ((sample_new + 1) * 127.5)
            sample_new = sample_new.clamp(0, 255).to(torch.uint8)
            sample_new = sample_new.permute(0, 2, 3, 1)
            sample_new = sample_new.contiguous().cpu().numpy()
            sample_new = sample_new[0][:, :, ::-1]

            sample_new = cv2.cvtColor(sample_new, cv2.COLOR_BGR2GRAY)
            print(img_name[0])
            cv2.imwrite(base_path + 'pred_' + img_name[0], sample_new)
144
+
145
+ def create_argparser():
146
+ defaults = dict(
147
+ data_dir= val_dir,
148
+ schedule_sampler="uniform",
149
+ lr=1e-4,
150
+ weight_decay=0.0,
151
+ lr_anneal_steps=0,
152
+ batch_size=2,
153
+ microbatch=1,
154
+ ema_rate="0.9999",
155
+ log_interval=100,
156
+ save_interval=200,
157
+ use_fp16=False,
158
+ fp16_scale_growth=1e-3,
159
+ )
160
+ defaults.update(sr_model_and_diffusion_defaults())
161
+ parser = argparse.ArgumentParser()
162
+ add_dict_to_argparser(parser, defaults)
163
+ return parser
164
+
165
+ if __name__ == "__main__":
166
+ main()
scripts/sarddpm_train.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train SAR-DDPM model.
3
+ """
4
+
5
+ import argparse
6
+
7
+ import torch.nn.functional as F
8
+
9
+ from guided_diffusion import dist_util, logger
10
+ from guided_diffusion.image_datasets import load_data
11
+ from guided_diffusion.resample import create_named_schedule_sampler
12
+ from guided_diffusion.script_util import (
13
+ sr_model_and_diffusion_defaults,
14
+ sr_create_model_and_diffusion,
15
+ args_to_dict,
16
+ add_dict_to_argparser,
17
+ )
18
+ from guided_diffusion.train_util import TrainLoop
19
+ from torch.utils.data import DataLoader
20
+ # from train_dataset import TrainData
21
+ from valdata import ValData, ValDataNew
22
+
23
+ train_dir = 'path_to_training_data/'
24
+
25
+ val_dir = 'path_to_validation_data/'
26
+
27
+ pretrained_weight_path = "./weights/64_256_upsampler.pt"
28
+
29
+
30
def main():
    """
    Train SAR-DDPM: build the SR model/diffusion pair, the training and
    validation loaders, and hand everything to TrainLoop.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    model, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    val_data = DataLoader(
        ValDataNew(dataset_path=val_dir),
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    print(args)
    # Training stream of (image, kwargs) pairs; 256x256 throughout.
    data = load_sar_data(
        args.data_dir,
        train_dir,
        args.batch_size,
        large_size=256,
        small_size=256,
        class_cond=False,
    )

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        val_dat=val_data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        args=args,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
    ).run_loop()
80
+
81
+
82
def load_sar_data(data_dir, gt_dirs, batch_size, large_size, small_size, class_cond=False):
    """
    Yield (image_batch, model_kwargs) pairs for SAR training.

    :param data_dir: directory of (noisy) input images.
    :param gt_dirs: directory of ground-truth images.
    :param batch_size: batch size passed to the underlying loader.
    :param large_size: image size used for loading.
    :param small_size: unused here; kept for signature compatibility.
    :param class_cond: whether to load class labels.
    """
    data = load_data(
        data_dir=data_dir,
        gt_dir=gt_dirs,
        batch_size=batch_size,
        image_size=large_size,
        # BUGFIX: was hard-coded to False, silently ignoring the parameter.
        class_cond=class_cond,
    )
    yield from data
92
+
93
+
94
def create_argparser():
    """
    Build the CLI parser for training: script defaults first, then the SR
    model/diffusion defaults layered on top (model defaults win on key
    collisions).
    """
    defaults = dict(
        data_dir=train_dir,
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=2,
        microbatch=1,
        ema_rate="0.9999",
        log_interval=1000,
        save_interval=10,
        resume_checkpoint=pretrained_weight_path,
        use_fp16=False,
        fp16_scale_growth=1e-3,
    )
    defaults.update(sr_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
115
+
116
+ if __name__ == "__main__":
117
+ main()
scripts/valdata.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ from PIL import Image
3
+ from random import randrange
4
+ from torchvision.transforms import Compose, ToTensor, Normalize
5
+ import re
6
+ from PIL import ImageFile
7
+ from os import path
8
+ import numpy as np
9
+ import torch
10
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
11
+ import os
12
+ # --- Training dataset --- #
13
+ import torch as th
14
+ import cv2
15
+ import math
16
+ import random
17
+ seed = np.random.RandomState(112311)
18
+
19
class ValData(data.Dataset):
    """
    Validation dataset of paired noisy/clean SAR images.

    Expects ``dataset_path`` to contain two sub-directories, ``noisy`` and
    ``clean``, holding identically named image files. Items are resized to
    256x256, mapped to [-1, 1], and returned CHW.
    """

    def __init__(self, dataset_path, crop_size=(256, 256)):
        # BUGFIX: default was a mutable list; use an immutable tuple.
        super().__init__()
        self.noisy_path = os.path.join(dataset_path, 'noisy')
        self.clean_path = os.path.join(dataset_path, 'clean')
        # Snapshot the file list once so __len__/__getitem__ stay consistent.
        self.images_list = os.listdir(self.noisy_path)
        self.crop_size = crop_size

    def __len__(self):
        # BUGFIX: use the cached list instead of re-reading the directory,
        # which could drift out of sync with images_list.
        return len(self.images_list)

    def __getitem__(self, idx):
        image_filename = self.images_list[idx]

        noisy_im = cv2.imread(os.path.join(self.noisy_path, image_filename))
        clean_im = cv2.imread(os.path.join(self.clean_path, image_filename))

        arr1 = np.array(clean_im)
        arr2 = np.array(noisy_im)
        # NOTE: the original also computed the noisy/clean ratio here, but it
        # was never returned; the dead computation has been removed.

        arr1 = cv2.resize(arr1, (256, 256), interpolation=cv2.INTER_LINEAR)
        arr2 = cv2.resize(arr2, (256, 256), interpolation=cv2.INTER_LINEAR)

        # Map uint8 [0, 255] to [-1, 1] as expected by the diffusion model.
        arr1 = arr1.astype(np.float32) / 127.5 - 1
        arr2 = arr2.astype(np.float32) / 127.5 - 1

        # HWC -> CHW
        arr2 = np.transpose(arr2, [2, 0, 1])
        arr1 = np.transpose(arr1, [2, 0, 1])

        return arr1, {'SR': arr2, 'HR': arr1, 'Index': image_filename}
104
+
105
+
106
class ValDataNew(data.Dataset):
    """
    Validation dataset that synthesizes speckled SAR images from clean ones.

    ``dataset_path`` is read directly (no sub-directories). Each clean image
    is converted to grayscale, replicated to 3 channels, and corrupted with
    multiplicative gamma (speckle) noise drawn from a fixed module-level
    RandomState for reproducibility.
    """

    def __init__(self, dataset_path, crop_size=(256, 256)):
        # BUGFIX: default was a mutable list; use an immutable tuple.
        super().__init__()
        self.noisy_path = dataset_path
        self.clean_path = dataset_path
        # Snapshot the file list once so __len__/__getitem__ stay consistent.
        self.images_list = os.listdir(self.noisy_path)
        self.crop_size = crop_size

    def __len__(self):
        # BUGFIX: use the cached list instead of re-reading the directory,
        # which could drift out of sync with images_list.
        return len(self.images_list)

    def __getitem__(self, idx):
        image_filename = self.images_list[idx]

        # Clean reference image, grayscale replicated to 3 channels.
        pil_image = cv2.imread(os.path.join(self.noisy_path, image_filename))
        pil_image = cv2.cvtColor(pil_image, cv2.COLOR_BGR2GRAY)
        pil_image = np.repeat(pil_image[:, :, np.newaxis], 3, axis=2)

        # Synthesize single-look speckle: intensity * Gamma(1,1) noise,
        # applied in the intensity domain, then back to amplitude.
        im1 = ((np.float32(pil_image) + 1.0) / 256.0) ** 2
        gamma_noise = seed.gamma(size=im1.shape, shape=1.0, scale=1.0).astype(im1.dtype)
        syn_sar = np.sqrt(im1 * gamma_noise)
        pil_image1 = syn_sar * 256 - 1  # noisy image back in [~0, 255]

        arr1 = np.array(pil_image)
        arr2 = np.array(pil_image1)

        arr1 = cv2.resize(arr1, (256, 256), interpolation=cv2.INTER_LINEAR)
        arr2 = cv2.resize(arr2, (256, 256), interpolation=cv2.INTER_LINEAR)

        # Map [0, 255] to [-1, 1] as expected by the diffusion model.
        arr1 = arr1.astype(np.float32) / 127.5 - 1
        arr2 = arr2.astype(np.float32) / 127.5 - 1

        # HWC -> CHW
        arr2 = np.transpose(arr2, [2, 0, 1])
        arr1 = np.transpose(arr1, [2, 0, 1])

        return arr1, {'SR': arr2, 'HR': arr1, 'Index': image_filename}
168
+
169
+
170
+
171
class ValDataNewReal(data.Dataset):
    """
    Validation dataset of real SAR images (no clean reference available).

    ``dataset_path`` is read directly (no sub-directories). The SAR image is
    returned both as conditioning ('SR') and as placeholder target ('HR').
    """

    def __init__(self, dataset_path, crop_size=(256, 256)):
        # BUGFIX: default was a mutable list; use an immutable tuple.
        super().__init__()
        self.noisy_path = dataset_path
        self.clean_path = dataset_path
        # Snapshot the file list once so __len__/__getitem__ stay consistent.
        self.images_list = os.listdir(self.noisy_path)
        self.crop_size = crop_size

    def __len__(self):
        # BUGFIX: use the cached list instead of re-reading the directory,
        # which could drift out of sync with images_list.
        return len(self.images_list)

    def __getitem__(self, idx):
        image_filename = self.images_list[idx]

        # Read as single-channel grayscale, then replicate to 3 channels.
        pil_image = cv2.imread(os.path.join(self.noisy_path, image_filename), 0)
        pil_image = np.repeat(pil_image[:, :, np.newaxis], 3, axis=2)

        arr1 = np.array(pil_image)
        arr2 = np.array(pil_image)
        # NOTE: the original also computed an unused element-wise ratio here;
        # the dead computation has been removed.

        arr1 = cv2.resize(arr1, (256, 256), interpolation=cv2.INTER_LINEAR)
        arr2 = cv2.resize(arr2, (256, 256), interpolation=cv2.INTER_LINEAR)

        # Map uint8 [0, 255] to [-1, 1] as expected by the diffusion model.
        arr1 = arr1.astype(np.float32) / 127.5 - 1
        arr2 = arr2.astype(np.float32) / 127.5 - 1

        # HWC -> CHW
        arr2 = np.transpose(arr2, [2, 0, 1])
        arr1 = np.transpose(arr1, [2, 0, 1])

        return arr1, {'SR': arr2, 'HR': arr1, 'Index': image_filename}
270
+
271
+
272
+
273
+
setup.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
"""Packaging script for the guided-diffusion module."""
from setuptools import setup

setup(
    name="guided-diffusion",
    py_modules=["guided_diffusion"],
    install_requires=[
        "blobfile>=1.0.5",
        "torch",
        "tqdm",
    ],
)
+ )