dineshsai07 commited on
Commit
0ccacae
·
verified ·
1 Parent(s): 46a8d8a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. results/versatile_diffusion/subj01/roi/4.png +3 -0
  3. versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc +0 -0
  4. versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc +0 -0
  5. versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc +0 -0
  6. versatile_diffusion/lib/data_factory/common/__init__.py +6 -0
  7. versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc +0 -0
  8. versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc +0 -0
  9. versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc +0 -0
  10. versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc +0 -0
  11. versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc +0 -0
  12. versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc +0 -0
  13. versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc +0 -0
  14. versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc +0 -0
  15. versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc +0 -0
  16. versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc +0 -0
  17. versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc +0 -0
  18. versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc +0 -0
  19. versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc +0 -0
  20. versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc +0 -0
  21. versatile_diffusion/lib/data_factory/common/ds_base.py +280 -0
  22. versatile_diffusion/lib/data_factory/common/ds_estimator.py +85 -0
  23. versatile_diffusion/lib/data_factory/common/ds_formatter.py +39 -0
  24. versatile_diffusion/lib/data_factory/common/ds_loader.py +97 -0
  25. versatile_diffusion/lib/data_factory/common/ds_sampler.py +273 -0
  26. versatile_diffusion/lib/data_factory/common/ds_transform.py +178 -0
  27. versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py +221 -0
  28. versatile_diffusion/lib/evaluator/__init__.py +1 -0
  29. versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc +0 -0
  30. versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc +0 -0
  31. versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc +0 -0
  32. versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc +0 -0
  33. versatile_diffusion/lib/evaluator/eva_base.py +293 -0
  34. versatile_diffusion/lib/evaluator/eva_null.py +26 -0
  35. versatile_diffusion/lib/experiments/__init__.py +0 -0
  36. versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
  37. versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-38.pyc +0 -0
  38. versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc +0 -0
  39. versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc +0 -0
  40. versatile_diffusion/lib/experiments/sd_default.py +441 -0
  41. versatile_diffusion/lib/experiments/vd_default.py +549 -0
  42. versatile_diffusion/lib/model_zoo/__init__.py +4 -0
  43. versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-310.pyc +0 -0
  44. versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-38.pyc +0 -0
  45. versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc +0 -0
  46. versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc +0 -0
  47. versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc +0 -0
  48. versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc +0 -0
  49. versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc +0 -0
  50. versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -2984,3 +2984,4 @@ results/versatile_diffusion/subj01/97.png filter=lfs diff=lfs merge=lfs -text
2984
  results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
2985
  results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
2986
  results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
 
 
2984
  results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
2985
  results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
2986
  results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
2987
+ results/versatile_diffusion/subj01/roi/4.png filter=lfs diff=lfs merge=lfs -text
results/versatile_diffusion/subj01/roi/4.png ADDED

Git LFS Details

  • SHA256: ec73857ae27b4acd80809ddda20ac07a2f986075639bb301e7a9fdfe4fa0367f
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc ADDED
Binary file (4.96 kB). View file
 
versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (535 Bytes). View file
 
versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (495 Bytes). View file
 
versatile_diffusion/lib/data_factory/common/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .ds_base import ds_base, collate, register as regdataset
2
+ from .ds_loader import pre_loader_checkings, register as regloader
3
+ from .ds_transform import TBase, have, register as regtrans
4
+ from .ds_estimator import register as regestmat
5
+ from .ds_formatter import register as regformat
6
+ from .ds_sampler import register as regsampler
versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (551 Bytes). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (511 Bytes). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc ADDED
Binary file (7.95 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc ADDED
Binary file (7.92 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc ADDED
Binary file (3.4 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc ADDED
Binary file (1.68 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc ADDED
Binary file (1.63 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc ADDED
Binary file (3.27 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc ADDED
Binary file (3.25 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc ADDED
Binary file (8.97 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc ADDED
Binary file (8.92 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc ADDED
Binary file (5.43 kB). View file
 
versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc ADDED
Binary file (5.41 kB). View file
 
versatile_diffusion/lib/data_factory/common/ds_base.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import numpy as np
4
+ import numpy.random as npr
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torchvision
8
+ import copy
9
+ import itertools
10
+
11
+ from ... import sync
12
+ from ...cfg_holder import cfg_unique_holder as cfguh
13
+ from ...log_service import print_log
14
+
15
+ import torch.distributed as dist
16
+ from multiprocessing import shared_memory
17
+
18
+ # import multiprocessing
19
+ # if hasattr(multiprocessing, "shared_memory"):
20
+ # from multiprocessing import shared_memory
21
+ # else:
22
+ # # workaround for single gpu inference on colab
23
+ # shared_memory = None
24
+
25
+ import pickle
26
+ import hashlib
27
+ import random
28
+
29
class ds_base(torch.utils.data.Dataset):
    """
    Base dataset driven by a config object plus pluggable stages:
    loader (path -> data), estimator (filters/reorders load_info),
    transforms (augmentation) and formatter (final output shaping).

    Subclasses implement init_load_info() to populate self.load_info with a
    list of dicts, each carrying at least a 'unique_id' key.
    """
    def __init__(self,
                 cfg,
                 loader=None,
                 estimator=None,
                 transforms=None,
                 formatter=None):
        self.cfg = cfg
        self.load_info = None
        self.init_load_info()
        self.loader = loader
        self.transforms = transforms
        self.formatter = formatter

        if self.load_info is not None:
            load_info_order_by = getattr(self.cfg, 'load_info_order_by', 'default')
            if load_info_order_by == 'default':
                self.load_info = sorted(self.load_info, key=lambda x: x['unique_id'])
            else:
                # Config form "key|reverse" sorts descending by that key.
                try:
                    load_info_order_by, reverse = load_info_order_by.split('|')
                    reverse = reverse == 'reverse'
                except Exception:
                    reverse = False
                self.load_info = sorted(
                    self.load_info, key=lambda x: x[load_info_order_by], reverse=reverse)

        load_info_add_idx = getattr(self.cfg, 'load_info_add_idx', True)
        if (self.load_info is not None) and load_info_add_idx:
            for idx, info in enumerate(self.load_info):
                info['idx'] = idx

        if estimator is not None:
            self.load_info = estimator(self.load_info)

        # try_sample is either an int (keep the first n) or a (start, end) pair.
        self.try_sample = getattr(self.cfg, 'try_sample', None)
        if self.try_sample is not None:
            try:
                start, end = self.try_sample
            except Exception:
                start, end = 0, self.try_sample
            self.load_info = self.load_info[start:end]

        self.repeat = getattr(self.cfg, 'repeat', 1)

        pick = getattr(self.cfg, 'pick', None)
        if pick is not None:
            self.load_info = [i for i in self.load_info if i['filename'] in pick]

        #########
        # cache #
        #########

        self.cache_sm = getattr(self.cfg, 'cache_sm', False)
        self.cache_cnt = 0
        if self.cache_sm:
            self.cache_pct = getattr(self.cfg, 'cache_pct', 0)
            # Node-synchronized random id so every rank derives the same
            # shared-memory segment names.
            cache_unique_id = sync.nodewise_sync().random_sync_id()
            self.cache_unique_id = hashlib.sha256(pickle.dumps(cache_unique_id)).hexdigest()
            self.__cache__(self.cache_pct)

        #######
        # log #
        #######

        if self.load_info is not None:
            console_info = '{}: '.format(self.__class__.__name__)
            console_info += 'total {} unique images, '.format(len(self.load_info))
            console_info += 'total {} unique sample. Cached {}. Repeat {} times.'.format(
                len(self.load_info), self.cache_cnt, self.repeat)
        else:
            console_info = '{}: load_info not ready.'.format(self.__class__.__name__)
        print_log(console_info)

    def init_load_info(self):
        """Populate self.load_info; implemented by subclasses."""
        pass

    def __len__(self):
        return len(self.load_info) * self.repeat

    def __cache__(self, pct):
        """
        Pre-load the first `pct` fraction of load_info.

        Without cache_sm the loaded elements replace the load_info entries
        in-process. With cache_sm each entry is pickled into a POSIX
        shared-memory segment (written by one local rank, read by all) and
        the entry is replaced by the segment name.

        NOTE(review): relies on self.local_world_size / self.local_rank,
        which this base class never sets — presumably provided by a subclass
        or assigned externally before cache_sm is enabled; confirm.
        """
        if pct == 0:
            self.cache_cnt = 0
            return
        self.cache_cnt = int(len(self.load_info) * pct)
        if not self.cache_sm:
            for i in range(self.cache_cnt):
                self.load_info[i] = self.loader(self.load_info[i])
            return

        for i in range(self.cache_cnt):
            shm_name = str(self.load_info[i]['unique_id']) + '_' + self.cache_unique_id
            # Work is striped across local ranks; every rank records the name.
            if i % self.local_world_size == self.local_rank:
                data = pickle.dumps(self.loader(self.load_info[i]))
                datan = len(data)
                # self.print_smname_to_file(shm_name)
                shm = shared_memory.SharedMemory(
                    name=shm_name, create=True, size=datan)
                shm.buf[0:datan] = data[0:datan]
                shm.close()
            self.load_info[i] = shm_name
        dist.barrier()

    def __getitem__(self, idx):
        # repeat > 1 wraps around the underlying unique samples
        idx = idx % len(self.load_info)
        element = self.load_info[idx]
        if isinstance(element, str):
            # Entry was cached to shared memory; element is the segment name.
            shm = shared_memory.SharedMemory(name=element)
            element = pickle.loads(shm.buf)
            shm.close()
        else:
            # Deep-copy so loader/transforms cannot mutate the master record.
            # (The original deep-copied twice here; once is sufficient.)
            element = copy.deepcopy(element)
            element['load_info_ptr'] = self.load_info

        if idx >= self.cache_cnt:
            element = self.loader(element)
        if self.transforms is not None:
            element = self.transforms(element)
        if self.formatter is not None:
            return self.formatter(element)
        else:
            return element

    def __del__(self):
        """Unlink shared-memory segments (local rank 0 only)."""
        # Guard attribute access: __del__ may run after a failed/partial
        # __init__, where load_info is still None or local_rank is unset.
        load_info = getattr(self, 'load_info', None)
        if load_info is None:
            return
        local_rank = getattr(self, 'local_rank', None)
        for infoi in load_info:
            if isinstance(infoi, str) and (local_rank == 0):
                shm = shared_memory.SharedMemory(name=infoi)
                shm.close()
                shm.unlink()

    def print_smname_to_file(self, smname):
        """Append a shared-memory segment name next to the run's log file."""
        try:
            log_file = cfguh().cfg.train.log_file
        except Exception:
            try:
                log_file = cfguh().cfg.eval.log_file
            except Exception:
                raise ValueError('no log_file found in cfg.train or cfg.eval')
        # a trick to reuse the log_file path
        sm_file = log_file.replace('.log', '.smname')
        with open(sm_file, 'a') as f:
            f.write(smname + '\n')
180
+
181
def singleton(class_):
    """Decorator: construct `class_` at most once and reuse that instance."""
    _instances = {}
    def _get(*args, **kwargs):
        try:
            return _instances[class_]
        except KeyError:
            _instances[class_] = class_(*args, **kwargs)
            return _instances[class_]
    return _get
188
+
189
+ from .ds_loader import get_loader
190
+ from .ds_transform import get_transform
191
+ from .ds_estimator import get_estimator
192
+ from .ds_formatter import get_formatter
193
+
194
+ @singleton
195
class get_dataset(object):
    """Singleton registry of dataset classes, keyed by class name."""

    def __init__(self):
        self.dataset = {}

    def register(self, ds):
        """Add a dataset class under its own __name__."""
        self.dataset[ds.__name__] = ds

    def __call__(self, cfg):
        """Build the dataset described by cfg; None-safe on cfg and cfg.type."""
        if cfg is None:
            return None
        t = cfg.type
        if t is None:
            return None
        # Importing the concrete module registers its dataset classes.
        elif t in ['laion2b', 'laion2b_dummy',
                   'laion2b_webdataset',
                   'laion2b_webdataset_sdofficial', ]:
            from .. import ds_laion2b
        elif t in ['coyo', 'coyo_dummy',
                   'coyo_webdataset', ]:
            from .. import ds_coyo_webdataset
        elif t in ['laionart', 'laionart_dummy',
                   'laionart_webdataset', ]:
            from .. import ds_laionart
        elif t in ['celeba']:
            from .. import ds_celeba
        elif t in ['div2k']:
            from .. import ds_div2k
        elif t in ['pafc']:
            from .. import ds_pafc
        elif t in ['coco_caption']:
            from .. import ds_coco
        else:
            raise ValueError

        loader = get_loader()(cfg.get('loader', None))
        transform = get_transform()(cfg.get('transform', None))
        estimator = get_estimator()(cfg.get('estimator', None))
        formatter = get_formatter()(cfg.get('formatter', None))

        return self.dataset[t](
            cfg, loader, estimator,
            transform, formatter)
237
+
238
def register():
    """Class decorator factory: adds the class to the get_dataset() registry."""
    def _add(class_):
        registry = get_dataset()
        registry.register(class_)
        return class_
    return _add
243
+
244
+ # some other helpers
245
+
246
class collate(object):
    """
    Modified from torch.utils.data._utils.collate.

    Handles list fields differently from the default collate:
    list entries are concatenated (appended) across the batch instead of
    being stacked into a tensor.
    """
    def __init__(self):
        self.default_collate = \
            torch.utils.data._utils.collate.default_collate

    def __call__(self, batch):
        """
        Args:
            batch: [data, data] -or- [(data1, data2, ...), (data1, data2, ...)]
        This function will not be used as induction function.
        """
        elem = batch[0]
        # BUGFIX: the original read `if not (elem, (tuple, list))`, which
        # builds an always-truthy tuple, so the default_collate fallback was
        # unreachable; `isinstance` was clearly intended.
        if not isinstance(elem, (tuple, list)):
            return self.default_collate(batch)

        rv = []
        # transposed: iterate per-field across the batch
        for i in zip(*batch):
            if isinstance(i[0], list):
                if len(i[0]) != 1:
                    raise ValueError
                # Best effort: collate-and-squeeze each singleton list;
                # leave as-is when the contents are not collatable.
                try:
                    i = [[self.default_collate(ii).squeeze(0)] for ii in i]
                except Exception:
                    pass
                rvi = list(itertools.chain.from_iterable(i))
                rv.append(rvi)  # list concat
            else:
                rv.append(self.default_collate(i))
        return rv
versatile_diffusion/lib/data_factory/common/ds_estimator.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import numpy.random as npr
4
+ import PIL
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ import xml.etree.ElementTree as ET
10
+ import json
11
+ import copy
12
+ import math
13
+
14
def singleton(class_):
    """Decorator making `class_` a lazily-created single instance."""
    _instances = {}
    def _get(*args, **kwargs):
        try:
            return _instances[class_]
        except KeyError:
            _instances[class_] = class_(*args, **kwargs)
            return _instances[class_]
    return _get
21
+
22
+ @singleton
23
class get_estimator(object):
    """Singleton registry mapping estimator class names to classes."""

    def __init__(self):
        self.estimator = {}

    def register(self, estimf):
        # Keyed by class name so cfg.type can refer to it.
        self.estimator[estimf.__name__] = estimf

    def __call__(self, cfg):
        """Instantiate the estimator described by cfg; None passes through."""
        if cfg is None:
            return None
        return self.estimator[cfg.type](**cfg.args)
35
+
36
def register():
    """Class decorator factory: adds the class to the get_estimator() registry."""
    def _add(class_):
        get_estimator().register(class_)
        return class_
    return _add
41
+
42
+ @register()
43
class PickFileEstimator(object):
    """
    Estimator that filters load_info down to the entries whose image base
    name (without extension) appears in the provided filelist.
    """
    def __init__(self,
                 filelist=None,
                 repeat_n=1):
        """
        Args:
            filelist: a list of string gives the name of images
                we would like to visualize, evaluate or train.
            repeat_n: int, times these images will be repeated
        """
        self.filelist = filelist
        self.repeat_n = repeat_n

    def __call__(self, load_info):
        # BUGFIX: the original called os.path.basename, but this module only
        # imports os.path as osp, so every call raised NameError; use the
        # osp alias instead.
        load_info_new = []
        for info in load_info:
            if osp.basename(info['image_path']).split('.')[0] in self.filelist:
                load_info_new.append(info)
        return load_info_new * self.repeat_n
66
+
67
+ @register()
68
class PickIndexEstimator(object):
    """
    Estimator that keeps only the load_info entries at the given indices.
    """
    def __init__(self,
                 indexlist=None,
                 **kwargs):
        """
        Args:
            indexlist: [] of int, the indices to keep (in this order).
        """
        self.indexlist = indexlist

    def __call__(self, load_info):
        return [load_info[idx] for idx in self.indexlist]
versatile_diffusion/lib/data_factory/common/ds_formatter.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import numpy as np
4
+ import numpy.random as npr
5
+ import torch
6
+ import cv2
7
+ import scipy.ndimage
8
+ from PIL import Image
9
+ import copy
10
+ import gc
11
+ import itertools
12
+
13
def singleton(class_):
    """Decorator: at most one shared instance of `class_` is ever created."""
    _instances = {}
    def _get(*args, **kwargs):
        if class_ not in _instances:
            inst = class_(*args, **kwargs)
            _instances[class_] = inst
        return _instances[class_]
    return _get
20
+
21
+ @singleton
22
class get_formatter(object):
    """Singleton registry mapping formatter class names to classes."""

    def __init__(self):
        self.formatter = {}

    def register(self, formatf):
        self.formatter[formatf.__name__] = formatf

    def __call__(self, cfg):
        """Instantiate the formatter described by cfg; None passes through."""
        if cfg is None:
            return None
        return self.formatter[cfg.type](**cfg.args)
34
+
35
def register():
    """Class decorator factory: adds the class to the get_formatter() registry."""
    def _add(class_):
        get_formatter().register(class_)
        return class_
    return _add
versatile_diffusion/lib/data_factory/common/ds_loader.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import numpy.random as npr
4
+ import PIL
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ import xml.etree.ElementTree as ET
10
+ import json
11
+ import copy
12
+
13
+ from ...cfg_holder import cfg_unique_holder as cfguh
14
+
15
def singleton(class_):
    """Decorator ensuring a single shared instance of `class_`."""
    _instances = {}
    def _get(*args, **kwargs):
        if class_ not in _instances:
            _instances[class_] = class_(*args, **kwargs)
        return _instances[class_]
    return _get
22
+
23
+ @singleton
24
class get_loader(object):
    """Singleton registry mapping loader class names to classes."""

    def __init__(self):
        self.loader = {}

    def register(self, loadf):
        self.loader[loadf.__name__] = loadf

    def __call__(self, cfg):
        """
        Instantiate the loader(s) described by cfg.
        A list cfg produces a `compose` chaining one loader per entry.
        """
        if cfg is None:
            return None
        if isinstance(cfg, list):
            return compose([self.loader[ci.type](**ci.args) for ci in cfg])
        return self.loader[cfg.type](**cfg.args)
42
+
43
class compose(object):
    """Chain several loaders; each is applied to the element in order."""

    def __init__(self, loaders):
        self.loaders = loaders

    def __call__(self, element):
        result = element
        for loader_i in self.loaders:
            result = loader_i(result)
        return result

    def __getitem__(self, idx):
        # Expose the individual loaders by position.
        return self.loaders[idx]
54
+
55
def register():
    """Class decorator factory: adds the class to the get_loader() registry."""
    def _add(class_):
        get_loader().register(class_)
        return class_
    return _add
60
+
61
+ def pre_loader_checkings(ltype):
62
+ lpath = ltype+'_path'
63
+ # cache feature added on 20201021
64
+ lcache = ltype+'_cache'
65
+ def wrapper(func):
66
+ def inner(self, element):
67
+ if lcache in element:
68
+ # cache feature added on 20201021
69
+ data = element[lcache]
70
+ else:
71
+ if ltype in element:
72
+ raise ValueError
73
+ if lpath not in element:
74
+ raise ValueError
75
+
76
+ if element[lpath] is None:
77
+ data = None
78
+ else:
79
+ data = func(self, element[lpath], element)
80
+ element[ltype] = data
81
+
82
+ if ltype == 'image':
83
+ if isinstance(data, np.ndarray):
84
+ imsize = data.shape[-2:]
85
+ elif isinstance(data, PIL.Image.Image):
86
+ imsize = data.size[::-1]
87
+ elif isinstance(data, torch.Tensor):
88
+ imsize = [data.size(-2), data.size(-1)]
89
+ elif data is None:
90
+ imsize = None
91
+ else:
92
+ raise ValueError
93
+ element['imsize'] = imsize
94
+ element['imsize_current'] = copy.deepcopy(imsize)
95
+ return element
96
+ return inner
97
+ return wrapper
versatile_diffusion/lib/data_factory/common/ds_sampler.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenize import group
2
+ import torch
3
+ import numpy as np
4
+ import numpy.random as npr
5
+ import torch.distributed as dist
6
+ import math
7
+
8
+ from ...log_service import print_log
9
+ from ... import sync
10
+
11
def singleton(class_):
    """Decorator: construct `class_` once; later calls return the same object."""
    _instances = {}
    def _get(*args, **kwargs):
        if class_ not in _instances:
            _instances[class_] = class_(*args, **kwargs)
        return _instances[class_]
    return _get
18
+
19
+ @singleton
20
class get_sampler(object):
    """Singleton registry mapping sampler class names to classes."""

    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        """Build a sampler; two shorthand strings map to the global sampler."""
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        if cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        return self.sampler[cfg.type](dataset=dataset, **cfg.args)
35
+
36
def register():
    """Class decorator factory: adds the class to the get_sampler() registry."""
    def _add(class_):
        get_sampler().register(class_)
        return class_
    return _add
41
+
42
+ ######################
43
+ # DistributedSampler #
44
+ ######################
45
+
46
+ @register()
47
class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    Distributed sampler synchronized across all gpus and all nodes.

    Arguments:
        dataset: Dataset used for sampling.
        shuffle: if True, the shared order is a synced random permutation.
        extend: if True, pad the index list so it divides evenly by
            world_size; otherwise truncate it to an even length.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend

        n, remainder = divmod(len(dataset), self.world_size)
        if extend and remainder != 0:
            n += 1
        self.num_samples = n
        self.total_size = n * self.world_size

    def __iter__(self):
        order = self.get_sync_order()
        if self.extend:
            # pad with the front indices up to total_size
            order = order + order[0:self.total_size - len(order)]
        else:
            # truncate to an even multiple of world_size
            order = order[0:self.total_size]
        # round-robin subsample for this rank
        return iter(order[self.rank::self.world_size])

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        """Return the epoch's index order, identical on every rank."""
        if not self.shuffle:
            order = list(range(len(self.dataset)))
        else:
            # rank 0's permutation is broadcast so all ranks agree
            perm = torch.randperm(len(self.dataset)).to(self.rank)
            if self.ddp:
                dist.broadcast(perm, src=0)
            order = perm.to('cpu').tolist()
        print_log('Sampler : {}'.format(str(order[0:5])))
        return order
100
+
101
+ @register()
102
class LocalDistributedSampler(GlobalDistributedSampler):
    """
    Distributed sampler synchronized across the gpus of one node only;
    different nodes draw independent orders.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        # Override to the node-local process group.
        # NOTE(review): num_samples/total_size were already computed by the
        # parent using the GLOBAL world size — confirm that is intended.
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

    def get_sync_order(self):
        """Return an index order shared within the node (rank-0 broadcast)."""
        if self.shuffle:
            if self.rank == 0:
                order = list(npr.permutation(len(self.dataset)))
                sync.nodewise_sync().broadcast_r0(order)
            else:
                order = sync.nodewise_sync().broadcast_r0(None)
        else:
            order = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(order[0:5])))
        return order
126
+
127
+ ############################
128
+ # random sample with group #
129
+ ############################
130
+ # Deprecated
131
+
132
+ @register()
133
class GroupSampler(torch.utils.data.Sampler):
    """
    Deprecated. Distributed sampler that hands out indices in contiguous
    groups of `group_size`, distributing whole groups across replicas.
    i.e.
    if group_size=3, num_replicas=2, train mode:
        0..10 ==> groups [0,1,2],[3,4,5],[6,7,8],[9,10]
              ==> evenly distributable groups are shuffled across ranks,
                  the leftover is regrouped per-rank (dropped if it would
                  produce batches of size 1).
    if group_size=3, num_replicas=2, eval mode:
        0..10 ==> padded with the last index to 0..10,10
              ==> groups distributed round-robin across ranks.
    """

    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError(
                    'torch.distributed unavailable; pass num_replicas explicitly')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError(
                    'torch.distributed unavailable; pass rank explicitly')
            rank = dist.get_rank()

        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode
        len_dataset = self.len_dataset

        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # drop the non-aligned tail
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # pad with the last index so every rank sees the same count
            extend = np.array([len_dataset - 1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([range(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset

        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas

        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)

        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # avoid batchsize = 1
            self.leftover_groups = None

        # BUGFIX: __len__ returned self.num_samples, which was never assigned
        # (AttributeError). Compute the per-rank sample count here; it equals
        # the number of indices __iter__ yields in both train and eval mode.
        self.num_samples = (len(self.regular_groups) // num_replicas) * group_size
        if self.leftover_groups is not None:
            self.num_samples += self.leftover_groups.shape[-1]

        # an ugly way to modify dataset.load_info according to the grouping:
        # each member of a group gets the image_size of the group's middle
        # element as 'ref_size'.
        for groupi in self.regular_groups:
            for idx in groupi:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd + idx_upperbd) // 2
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                for idx in groupi:
                    idx_lowerbd = groupi[0]
                    idx_upperbd = groupi[-1]
                    idx_reference = (idx_lowerbd + idx_upperbd) // 2
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']

    def concat(self, nparrays, axis=0):
        # a helper for safe concatenation (skips empty arrays)
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)

    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        """Return this rank's index order (group-shuffled in train mode)."""
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)

        if mode == 'train':
            # rank 0's group permutation is broadcast so all ranks agree
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank * rank: num_groups_per_rank * (rank + 1)]
            indices = groups.flatten()

            if self.leftover_groups is not None:
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError("mode must be 'train' or 'eval'")

        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size + 1])))
        return indices
versatile_diffusion/lib/data_factory/common/ds_transform.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import numpy.random as npr
4
+ import PIL
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ import xml.etree.ElementTree as ET
10
+ import json
11
+ import copy
12
+ import math
13
+
14
def singleton(class_):
    """Wrap *class_* so every call returns one shared, cached instance."""
    cache = {}
    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]
    return getinstance
21
+
22
@singleton
class get_transform(object):
    """Singleton registry mapping transform class names to classes.

    Calling the registry with a config (or a list of configs) builds the
    corresponding transform instance(s); a list gets wrapped in ``compose``.
    """
    def __init__(self):
        self.transform = {}

    def register(self, transf):
        # Keyed by class name; configs refer to transforms via cfg.type.
        self.transform[transf.__name__] = transf

    def __call__(self, cfg):
        if cfg is None:
            return None
        if isinstance(cfg, list):
            stages = [self.transform[ci.type](**ci.args) for ci in cfg]
            return compose(stages)
        return self.transform[cfg.type](**cfg.args)
41
+
42
def register():
    """Class decorator that records the transform in the global registry."""
    def _decorate(class_):
        get_transform().register(class_)
        return class_
    return _decorate
47
+
48
def have(must=[], may=[]):
    """
    The nextgen decorator that have two list of
    input tells what category the transform
    will operate on.
    Args:
        must: [] of str,
            the names of the items that must be included
            inside the element.
            If element[name] exist: do the transform
            If element[name] is None: raise Exception.
            If element[name] not exist: raise Exception.
        may: [] of str,
            the names of the items that may be contained
            inside the element for transform.
            If element[name] exist: do the transform
            If element[name] is None: ignore it.
            If element[name] not exist: ignore it.
    """
    def route(self, item, e, d):
        """
        Route the element to a proper function
        for calculation.
        Args:
            self: object,
                the transform functor.
            item: str,
                the item name of the data.
            e: {},
                the element
            d: nparray, tensor or PIL.Image,
                the data to transform.
        """
        # Dispatch tag derived from the runtime type of the payload.
        if isinstance(d, np.ndarray):
            dtype = 'nparray'
        elif isinstance(d, torch.Tensor):
            dtype = 'tensor'
        elif isinstance(d, PIL.Image.Image):
            dtype = 'pilimage'
        else:
            raise ValueError

        # find function by order
        # Most specific handler wins: exec_<item>_<dtype>, then
        # exec_<item>, then exec_<dtype>, then plain exec.
        f = None
        for attrname in [
                'exec_{}_{}'.format(item, dtype),
                'exec_{}'.format(item),
                'exec_{}'.format(dtype),
                'exec']:
            f = getattr(self, attrname, None)
            if f is not None:
                break
        d, e = f(d, e)
        e[item] = d
        return e

    def wrapper(func):
        def inner(self, e):
            # Record the image size seen before this transform under a
            # unique tag (suffixing a counter when the same transform
            # class appears more than once in the pipeline).
            e['imsize_previous'] = e['imsize_current']
            imsize_tag_cnt = 0
            imsize_tag = 'imsize_before_' + self.__class__.__name__
            while True:
                if imsize_tag_cnt != 0:
                    tag = imsize_tag + str(imsize_tag_cnt)
                else:
                    tag = imsize_tag
                if not tag in e:
                    e[tag] = e['imsize_current']
                    break
                imsize_tag_cnt += 1

            e = func(self, e)
            # must transform list
            for item in must:
                try:
                    d = e[item]
                except:
                    raise ValueError
                if d is None:
                    raise ValueError
                e = route(self, item, e, d)
            # may transform list
            for item in may:
                try:
                    d = e[item]
                except:
                    d = None
                if d is not None:
                    e = route(self, item, e, d)
            return e
        return inner
    return wrapper
140
+
141
class compose(object):
    """Apply a sequence of transforms to an element, in order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, element):
        for stage in self.transforms:
            element = stage(element)
        return element
149
+
150
class TBase(object):
    """Base class for element transforms; subclasses override ``exec``."""

    def __init__(self):
        pass

    def exec(self, data, element):
        # Subclasses must provide the actual transform implementation.
        raise ValueError

    def rand(self,
             uid,
             tag,
             rand_f,
             *args,
             **kwargs):
        """
        Args:
            uid: string element['unique_id']
            tag: string tells the tag uses when tracking the random number.
                Or the tag to restore the tracked random number.
            rand_f: the random function use to generate random number.
            **kwargs: the argument for the given random function.
        """
        # The record/replay machinery (rnduh) is currently disabled, so
        # this simply evaluates the random function and returns the draw.
        return rand_f(*args, **kwargs)
versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import numpy as np
4
+ import numpy.random as npr
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torchvision.transforms as tvtrans
8
+ import PIL.Image
9
+ PIL.Image.MAX_IMAGE_PIXELS = None
10
+ import math
11
+ import json
12
+ import copy
13
+ import pickle
14
+ from multiprocessing import shared_memory
15
+ import time
16
+ from .common import *
17
+ from ..log_service import print_log
18
+
19
+ from lib import visual_service as vis
20
+ from .. import sync
21
+
22
+ import webdataset as wds
23
+
24
+ ###################################################
25
+ # this is a special ds that use webdataset mainly #
26
+ ###################################################
27
+
28
@regdataset()
class laion2b_dummy(ds_base):
    # Placeholder dataset with zero samples; satisfies the dataset
    # registry when data actually streams through webdataset instead.
    def init_load_info(self):
        # ds_base expects this hook; nothing to enumerate here.
        self.load_info = []
32
+
33
@regdataset()
class laion2b_webdataset(ds_base):
    """LAION-2B streamed through webdataset tar shards.

    The index-based ds_base machinery is unused (load_info stays empty);
    batches come from the WebLoader pipeline built in ``make_loader``.
    """
    def init_load_info(self):
        # No per-sample index; samples are streamed from tar shards.
        self.load_info = []

    def make_loader(self, batch_size, num_workers, train=True):
        """Build a WebLoader over the tar shards under <root_dir>/data.

        Args:
            batch_size: int, samples per batch (batched inside the pipeline).
            num_workers: int, loader worker processes.
            train: bool, random crop + shuffle when True, center crop when False.
        """
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(),]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(),]

        trans = tvtrans.Compose(trans)

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess

        shuffle = cfg.get('shuffle', 10000)
        shardshuffle = shuffle > 0
        node_world_size = sync.get_world_size('node')
        # NOTE(review): split_by_node is selected when there is a single node
        # and single_node_only when there are many -- this looks inverted;
        # confirm against the intended multi-node sharding behavior.
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size==1 else wds.shardlists.single_node_only

        tars = [osp.join(self.root_dir, 'data', i) for i in os.listdir(osp.join(self.root_dir, 'data'))
                if osp.splitext(i)[1]=='.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print_log(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # BUGFIX: .batched() returns a new pipeline rather than mutating in
        # place; the original discarded the result, so the loader yielded
        # unbatched samples.
        dset = dset.batched(batch_size, partial=False)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader

    def filter_size(self, x):
        """Keep samples satisfying min_size / max_pwatermark thresholds.

        Missing metadata counts as a rejection; any unexpected error
        rejects the sample instead of crashing the stream.
        """
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and \
                        x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
        """Keep only samples that carry both an image and a caption."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        # TODO(review): make_loader expects (batch_size, num_workers) but a
        # single argument is passed here; confirm how self.train is populated.
        return self.make_loader(self.train)

    def val_dataloader(self):
        # TODO(review): same arity mismatch as train_dataloader.
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        # TODO(review): same arity mismatch as train_dataloader.
        return self.make_loader(self.test, train=False)
128
+
129
def customized_postprocess(element):
    """Map a decoded webdataset sample to (image in [-1, 1], caption, key)."""
    image = element['jpg']*2-1
    return image, element['txt'], element['__key__']
131
+
132
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into one dict of batched values.

    Only keys present in every sample are kept. Scalars become a numpy
    array, torch tensors are stacked, numpy arrays are stacked into one
    array, and anything else stays a plain list. When a combine flag is
    False the corresponding keys are dropped from the result.
    """
    shared_keys = set.intersection(*[set(s.keys()) for s in samples])
    columns = {k: [s[k] for s in samples] for k in shared_keys}

    result = {}
    for key, values in columns.items():
        first = values[0]
        if isinstance(first, (int, float)):
            if combine_scalars:
                result[key] = np.array(values)
        elif isinstance(first, torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(values)
        elif isinstance(first, np.ndarray):
            if combine_tensors:
                result[key] = np.array(values)
        else:
            result[key] = list(values)
    return result
153
+
154
+ ###################
155
+ # for sd official #
156
+ ###################
157
+
158
def customized_postprocess_sdofficial(element):
    """SD-official postprocess: keep dict shape, rescale image to [-1, 1]."""
    rescaled = element['jpg']*2-1
    return {'jpg': rescaled, 'txt': element['txt'], }
162
+
163
@regdataset()
class laion2b_webdataset_sdofficial(laion2b_webdataset):
    """Variant mirroring the official SD data pipeline: dict-shaped batches
    via dict_collation_fn, fixed shuffle buffer, single-node shard split."""

    def make_loader(self, batch_size, num_workers, train=True):
        """Build a WebLoader yielding {'jpg': ..., 'txt': ...} dict batches."""
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(),]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(),]

        trans = tvtrans.Compose(trans)

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess_sdofficial

        # Fixed shuffle buffer and single-node sharding, matching the
        # official SD training recipe.
        shuffle = 10000
        shardshuffle = shuffle > 0
        node_world_size = 1
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size==1 else wds.shardlists.single_node_only

        tars = [osp.join(self.root_dir, 'data', i) for i in os.listdir(osp.join(self.root_dir, 'data'))
                if osp.splitext(i)[1]=='.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # BUGFIX: assign the batched pipeline; the original discarded the
        # return value of .batched(), leaving the stream unbatched and the
        # collation function unused.
        dset = dset.batched(batch_size, partial=False, collation_fn=dict_collation_fn)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader
versatile_diffusion/lib/evaluator/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .eva_base import get_evaluator
versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (235 Bytes). View file
 
versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (195 Bytes). View file
 
versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc ADDED
Binary file (8.67 kB). View file
 
versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc ADDED
Binary file (8.82 kB). View file
 
versatile_diffusion/lib/evaluator/eva_base.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ import os
5
+ import os.path as osp
6
+ import numpy as np
7
+ import cv2
8
+ import copy
9
+ import json
10
+
11
+ from ..log_service import print_log
12
+
13
def singleton(class_):
    """Return a factory that always yields one shared instance of *class_*."""
    made = {}
    def getinstance(*args, **kwargs):
        if class_ not in made:
            made[class_] = class_(*args, **kwargs)
        return made[class_]
    return getinstance
20
+
21
@singleton
class get_evaluator(object):
    """Singleton registry of evaluator classes, keyed by metric name.

    Metric modules register themselves on import, so the matching
    ``eva_*`` module is imported lazily right before lookup.
    """
    def __init__(self):
        self.evaluator = {}

    def register(self, evaf, name):
        self.evaluator[name] = evaf

    def __call__(self, pipeline_cfg=None):
        # No config: fall back to the no-op 'null' evaluator.
        if pipeline_cfg is None:
            from . import eva_null
            return self.evaluator['null']()

        # Single config: import the matching module, build one evaluator.
        if not isinstance(pipeline_cfg, list):
            t = pipeline_cfg.type
            if t == 'miou':
                from . import eva_miou
            if t == 'psnr':
                from . import eva_psnr
            if t == 'ssim':
                from . import eva_ssim
            if t == 'lpips':
                from . import eva_lpips
            if t == 'fid':
                from . import eva_fid
            return self.evaluator[t](**pipeline_cfg.args)

        # List of configs: build each evaluator and chain them in compose.
        evaluator = []
        for ci in pipeline_cfg:
            t = ci.type
            if t == 'miou':
                from . import eva_miou
            if t == 'psnr':
                from . import eva_psnr
            if t == 'ssim':
                from . import eva_ssim
            if t == 'lpips':
                from . import eva_lpips
            if t == 'fid':
                from . import eva_fid
            evaluator.append(
                self.evaluator[t](**ci.args))
        if len(evaluator) == 0:
            return None
        else:
            return compose(evaluator)
67
+
68
def register(name):
    """Class decorator registering an evaluator under *name*."""
    def _decorate(class_):
        get_evaluator().register(class_, name)
        return class_
    return _decorate
73
+
74
class base_evaluator(object):
    """Distributed-aware evaluator base class.

    Holds the rank/world-size context, provides helpers to broadcast
    numpy arrays and strings across ranks (``sync``/``sync_``), and
    re-interleaves per-rank results back into global sample order
    (``zipzap_arrange``). Subclasses implement add_batch/compute/clear_data.
    """

    def __init__(self,
                 **args):
        '''
        Args:
            sample_n, int,
                the total number of sample. used in
                distributed sync
        '''
        if not dist.is_available():
            raise ValueError
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.sample_n = None
        self.final = {}

    def sync(self, data):
        """
        Args:
            data: any,
                the data needs to be broadcasted
        """
        if data is None:
            return None

        if isinstance(data, tuple):
            data = list(data)

        if isinstance(data, list):
            # Sync each item, then transpose so entry i groups item i
            # from every rank.
            data_list = []
            for datai in data:
                data_list.append(self.sync(datai))
            data = [[*i] for i in zip(*data_list)]
            return data

        # Gather this value from every rank, in rank order.
        data = [
            self.sync_(data, ranki)
            for ranki in range(self.world_size)
        ]
        return data

    def sync_(self, data, rank):
        # Broadcast `data` (np.ndarray or str) from `rank`; the source rank
        # returns its local value, every other rank the received copy.
        t = type(data)
        is_broadcast = rank == self.rank

        if t is np.ndarray:
            dtrans = data
            dt = data.dtype
            if dt in [
                    int,
                    np.bool_,  # BUGFIX: the np.bool alias was removed in NumPy 1.24
                    np.uint8,
                    np.int8,
                    np.int16,
                    np.int32,
                    np.int64,]:
                dtt = torch.int64
            elif dt in [
                    float,
                    np.float16,
                    np.float32,
                    np.float64,]:
                dtt = torch.float64

        elif t is str:
            # Strings travel as arrays of character codepoints.
            dtrans = np.array(
                [ord(c) for c in data],
                dtype = np.int64
            )
            dt = np.int64
            dtt = torch.int64
        else:
            raise ValueError

        if is_broadcast:
            # Send: ndim, then shape, then payload.
            n = len(dtrans.shape)
            n = torch.tensor(n).long()
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)

            n = list(dtrans.shape)
            n = torch.tensor(n).long()
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)

            n = torch.tensor(dtrans, dtype=dtt)
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)
            return data

        # Receive: ndim, shape, payload -- mirroring the sends above.
        n = torch.tensor(0).long()
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = n.item()

        n = torch.zeros(n).long()
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = list(n.to('cpu').numpy())

        n = torch.zeros(n, dtype=dtt)
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = n.to('cpu').numpy().astype(dt)

        if t is np.ndarray:
            return n
        elif t is str:
            n = ''.join([chr(c) for c in n])
            return n

    def zipzap_arrange(self, data):
        '''
        Order the data so it range like this:
        input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...]
        '''
        if isinstance(data[0], list):
            data_new = []
            maxlen = max([len(i) for i in data])
            totlen = sum([len(i) for i in data])
            cnt = 0
            for idx in range(maxlen):
                for datai in data:
                    # BUGFIX: skip exhausted sublists instead of raising
                    # IndexError when lengths are uneven.
                    if idx >= len(datai):
                        continue
                    data_new += [datai[idx]]
                    cnt += 1
                    if cnt >= totlen:
                        break
            return data_new

        elif isinstance(data[0], np.ndarray):
            maxlen = max([i.shape[0] for i in data])
            totlen = sum([i.shape[0] for i in data])
            datai_shape = data[0].shape[1:]
            # BUGFIX: the original called np.concatenate(a, b, axis=0) with
            # no sequence and np.zeros(n, *shape) where the extra positional
            # is interpreted as dtype; both raised. Pad short arrays with
            # zeros so every rank contributes maxlen rows.
            data = [
                np.concatenate(
                    [datai,
                     np.zeros((maxlen-datai.shape[0], *datai_shape), dtype=datai.dtype)],
                    axis=0)
                if datai.shape[0] < maxlen else datai
                for datai in data
            ]  # even the array
            data = np.stack(data, axis=1).reshape(-1, *datai_shape)
            data = data[:totlen]
            return data

        else:
            raise NotImplementedError

    def add_batch(self, **args):
        # Accumulate one batch of predictions/targets; subclass hook.
        raise NotImplementedError

    def set_sample_n(self, sample_n):
        # Total number of samples, used for distributed sync bookkeeping.
        self.sample_n = sample_n

    def compute(self):
        # Produce the final metric value(s); subclass hook.
        raise NotImplementedError

    # Function needed in training to judge which
    # evaluated number is better
    def isbetter(self, old, new):
        return new>old

    def one_line_summary(self):
        print_log('Evaluator display')

    def save(self, path):
        # Dump the accumulated results as <path>/result.json.
        if not osp.exists(path):
            os.makedirs(path)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        raise NotImplementedError
247
+
248
class compose(object):
    """Run several evaluators as one, fanning every call out in order."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        for stage in self.pipeline:
            stage.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n
        for stage in self.pipeline:
            stage.set_sample_n(sample_n)

    def compute(self):
        # Collect each stage's metric under its symbol; mirror the
        # per-stage final dicts for later saving.
        results = {}
        for stage in self.pipeline:
            results[stage.symbol] = stage.compute()
            self.final[stage.symbol] = stage.final
        return results

    def isbetter(self, old, new):
        # Majority vote over the pipeline's individual judgements.
        votes = sum(1 for stage in self.pipeline if stage.isbetter(old, new))
        return votes / len(self.pipeline) > 0.5

    def one_line_summary(self):
        for stage in self.pipeline:
            stage.one_line_summary()

    def save(self, path):
        if not osp.exists(path):
            os.makedirs(path)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        for stage in self.pipeline:
            stage.clear_data()
versatile_diffusion/lib/evaluator/eva_null.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import lpips
4
+
5
+ from .. import nputils
6
+ from ..log_service import print_log
7
+
8
+ from .eva_base import base_evaluator, register
9
+
10
@register('null')
class null_evaluator(base_evaluator):
    # No-op evaluator used when no metric pipeline is configured.
    def __init__(self, **dummy):
        super().__init__()

    def add_batch(self,
                  **dummy):
        # Nothing to accumulate.
        pass

    def compute(self):
        # No metric to report.
        return None

    def one_line_summary(self):
        print_log('Evaluator null')

    def clear_data(self):
        pass
versatile_diffusion/lib/experiments/__init__.py ADDED
File without changes
versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (191 Bytes). View file
 
versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (151 Bytes). View file
 
versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc ADDED
Binary file (13.5 kB). View file
 
versatile_diffusion/lib/experiments/sd_default.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torchvision import transforms as tvtrans
4
+ import os
5
+ import os.path as osp
6
+ import time
7
+ import timeit
8
+ import copy
9
+ import json
10
+ import pickle
11
+ import PIL.Image
12
+ import numpy as np
13
+ from datetime import datetime
14
+ from easydict import EasyDict as edict
15
+ from collections import OrderedDict
16
+
17
+ from lib.cfg_holder import cfg_unique_holder as cfguh
18
+ from lib.data_factory import get_dataset, get_sampler, collate
19
+ from lib.model_zoo import \
20
+ get_model, get_optimizer, get_scheduler
21
+ from lib.log_service import print_log
22
+
23
+ from ..utils import train as train_base
24
+ from ..utils import eval as eval_base
25
+ from ..utils import train_stage as tsbase
26
+ from ..utils import eval_stage as esbase
27
+ from .. import sync
28
+
29
+ ###############
30
+ # some helper #
31
+ ###############
32
+
33
def atomic_save(cfg, net, opt, step, path):
    """Serialize a training checkpoint and write it in one shot.

    The state_dict is slimmed by dropping first_stage_model / cond_stage_model
    weights (frozen submodules reloaded from their own checkpoints). The
    checkpoint is staged into an in-memory buffer first so the target file is
    written with a single fsspec write rather than incrementally.
    """
    # Unwrap (D)DP so keys are saved without the 'module.' prefix.
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
        if ki.find('first_stage_model')!=0 and ki.find('cond_stage_model')!=0]

    checkpoint = {
        "config" : cfg,
        "state_dict" : OrderedDict(slimmed_sd),
        "step" : step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    # Serialize fully in memory, then flush once to the destination.
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())
55
+
56
def load_state_dict(net, cfg):
    """Load pretrained weights into *net* from exactly one config source.

    Mutually exclusive options (asserted against each other):
      pretrained_pth_full / pretrained_ckpt_full : full model weights
      pretrained_pth / pretrained_ckpt           : weights without the
          first_stage/cond_stage submodules (kept from the live net)
      pretrained_pth_dm                          : diffusion UNet only
      pretrained_pth_ema                         : EMA UNet only
    ``.ckpt`` files store the weights under their 'state_dict' key.
    """
    pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth = cfg.get('pretrained_pth' , None)
    pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
    pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth is not None or pretrained_ckpt is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # Keep the live net's first_stage/cond_stage weights (the
        # checkpoint was saved without them -- see atomic_save).
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('first_stage_model')==0 or ki.find('cond_stage_model')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_dm is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_ema is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)
120
+
121
def auto_merge_imlist(imlist, max=64):
    """Tile up to *max* equally sized HxWx3 uint8 images into one canvas.

    Rows = floor(sqrt(n)); columns = ceil(n / rows). Cells left over in
    the grid stay black.
    """
    imlist = imlist[0:max]
    h, w = imlist[0].shape[0:2]
    n = len(imlist)
    rows = int(np.sqrt(n))
    cols = -(-n // rows)  # ceiling division
    canvas = np.zeros([rows * h, cols * w, 3], dtype=np.uint8)
    for k, im in enumerate(imlist):
        r, c = divmod(k, cols)
        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w, :] = im
    return canvas
133
+
134
def latent2im(net, latent):
    """Decode latent(s) to PIL image(s); a 3-D input yields one image."""
    squeeze = len(latent.shape) == 3
    if squeeze:
        latent = latent[None]
    decoded = net.decode_image(latent.to(net.device))
    decoded = torch.clamp((decoded+1.0)/2.0, min=0.0, max=1.0)
    images = [tvtrans.ToPILImage()(i) for i in decoded]
    return images[0] if squeeze else images
144
+
145
def im2latent(net, im):
    """Encode PIL image(s) to latent(s); a single image yields one latent."""
    squeeze = not isinstance(im, list)
    if squeeze:
        im = [im]
    batch = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
    batch = (batch*2-1).to(net.device)
    z = net.encode_image(batch)
    return z[0] if squeeze else z
155
+
156
class color_adjust(object):
    """Transfer the color statistics of one reference image onto another.

    Combines channel-wise mean/std matching (preprocess) with optional
    per-channel 1-D pdf (CDF) transfer in __call__.
    """
    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        # Mean/std-matched source pixels and raw target pixels, both
        # flattened to (N, 3) for the per-channel pdf transfer.
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        # Accepts a file path, PIL image, tensor (assumed in [0,1] --
        # clamped before conversion), or ndarray; returns the float
        # pixel array plus per-channel mean and std.
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        # Shift/scale each channel from ref_from's stats to ref_to's.
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x-m0)/s0)*s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        """Adjust *xin* toward the target colors.

        Args:
            xin: image in any form accepted by get_data_and_stat.
            keep: blend factor in [0,1] toward the untouched input.
            simple: if True, use only mean/std matching (skip pdf transfer).
        """
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = (x*(1-keep) + xin*keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        # Per-channel distribution transfer against the cached references.
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)

        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        # Histogram-based CDF matching: map values distributed like
        # arr_fo onto the distribution of arr_to, using n bins.
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]
        # compute probability distribution
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
        # transfer
        t_d = np.interp(d_fo, d_to, xs)
        # Clamp outside the target CDF's support to its extremes.
        t_d[d_fo <= d_to[ 0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out
231
+
232
+ ########
233
+ # main #
234
+ ########
235
+
236
class eval(eval_base):
    """Evaluation entry point: build the model, load weights, wrap in DDP."""
    def prepare_model(self):
        # Instantiate the model from the global config holder.
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
        load_state_dict(net, cfg.eval) #<--- added
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[self.local_rank],
            find_unused_parameters=True)
        net.eval()
        return {'net' : net,}
248
+
249
class eval_stage(esbase):
    """
    This is eval stage that can check comprehensive results
    """
    def __init__(self):
        # Sampler class is stored (not instantiated) here; an instance is
        # created per-network inside __call__.
        from ..model_zoo.ddim import DDIMSampler
        self.sampler = DDIMSampler

    def get_net(self, paras):
        """Extract the network from the stage parameter dict."""
        return paras['net']

    def get_image_path(self):
        """Return the demo-image output directory under the active log dir
        (training log dir if a train config exists, else the eval log dir)."""
        if 'train' in cfguh().cfg:
            log_dir = cfguh().cfg.train.log_dir
        else:
            log_dir = cfguh().cfg.eval.log_dir
        return os.path.join(log_dir, "udemo")

    @torch.no_grad()
    def sample(self, net, sampler, prompt, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """Run DDIM sampling conditioned on a text prompt.

        Args:
            net: diffusion model exposing ``get_learned_conditioning``.
            sampler: DDIM sampler instance.
            prompt: text prompt (replicated n_samples times).
            output_dim: (h, w) of the output image in pixels.
            scale: classifier-free guidance scale; 1.0 disables the
                unconditional branch.
            n_samples / ddim_steps / ddim_eta: batch size and DDIM settings.

        Returns:
            whatever ``sampler.sample`` returns (latents plus extras).
        """
        h, w = output_dim
        uc = None
        if scale != 1.0:
            # Unconditional conditioning from the empty prompt, for guidance.
            uc = net.get_learned_conditioning(n_samples * [""])
        c = net.get_learned_conditioning(n_samples * [prompt])
        # Latent-space shape; the /8 matches the autoencoder downsampling
        # implied by latent2im decoding elsewhere — assumed, confirm.
        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def save_images(self, pil_list, name, path, suffix=''):
        """Merge a list of PIL images into one canvas and save it as
        ``<name><suffix>.png`` under ``path``."""
        canvas = auto_merge_imlist([np.array(i) for i in pil_list])
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def __call__(self, **paras):
        """Generate demo images for every prompt in cfg.eval.conditioning,
        sharded across local ranks; returns an empty dict."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        # NOTE(review): eval_cnt is fetched a second time here — redundant
        # with the lookup above but harmless.
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_itern'+str(eval_cnt)

        # Unwrap (D)DP so we call methods on the bare module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        # Round-robin shard the (possibly replicated) prompt list over local
        # ranks; seed_increment keeps a globally unique seed per prompt slot.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for prompti, seedi in zip(conditioning_local, seed_increment):
            if prompti == 'SKIP':
                continue
            draw_filename = prompti.strip().replace(' ', '-')
            if fix_seed:
                # Deterministic per-prompt seeding; the seed is echoed in the
                # filename suffix for reproducibility.
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                # Sample under EMA weights when the model carries them.
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)

            demo_image = latent2im(netm, x)
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
345
+
346
+ ##################
347
+ # eval variation #
348
+ ##################
349
+
350
class eval_stage_variation(eval_stage):
    """Image-variation evaluation: conditioning items are image file paths
    instead of text prompts; optionally color-matches outputs to the input."""

    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM sampling conditioned on an image.

        Args:
            visual_hint: path to the conditioning image (opened here).
            Remaining arguments are as in ``eval_stage.sample``.
        """
        h, w = output_dim
        # Load the hint as a 1xCxHxW tensor on the model's device.
        vh = tvtrans.ToTensor()(PIL.Image.open(visual_hint))[None].to(net.device)
        c = net.get_learned_conditioning(vh)
        # Broadcast the single conditioning over the sample batch.
        c = c.repeat(n_samples, 1, 1)
        uc = None
        if scale != 1.0:
            # Unconditional branch: conditioning computed from an all-zero
            # image of the same shape.
            dummy = torch.zeros_like(vh)
            uc = net.get_learned_conditioning(dummy)
            uc = uc.repeat(n_samples, 1, 1)

        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        """Generate variations for every image path in cfg.eval.conditioning,
        sharded across local ranks; returns an empty dict."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (D)DP so we call methods on the bare module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        # Optional histogram-based color matching of outputs to the input.
        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        # Round-robin shard conditioning over local ranks (see eval_stage).
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            # Output name derived from the hint image's basename (no ext).
            draw_filename = osp.splitext(osp.basename(ci))[0]

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                # Match each output's color distribution to the hint image.
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
versatile_diffusion/lib/experiments/vd_default.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torchvision import transforms as tvtrans
4
+ import os
5
+ import os.path as osp
6
+ import time
7
+ import timeit
8
+ import copy
9
+ import json
10
+ import pickle
11
+ import PIL.Image
12
+ import numpy as np
13
+ from datetime import datetime
14
+ from easydict import EasyDict as edict
15
+ from collections import OrderedDict
16
+
17
+ from lib.cfg_holder import cfg_unique_holder as cfguh
18
+ from lib.data_factory import get_dataset, get_sampler, collate
19
+ from lib.model_zoo import \
20
+ get_model, get_optimizer, get_scheduler
21
+ from lib.log_service import print_log
22
+
23
+ from ..utils import train as train_base
24
+ from ..utils import eval as eval_base
25
+ from ..utils import train_stage as tsbase
26
+ from ..utils import eval_stage as esbase
27
+ from .. import sync
28
+
29
+ from .sd_default import auto_merge_imlist, latent2im, color_adjust
30
+ from .sd_default import eval as eval_base
31
+ from .sd_default import eval_stage as eval_stage_base
32
+
33
+ ###############
34
+ # some helper #
35
+ ###############
36
+
37
def atomic_save(cfg, net, opt, step, path):
    """Serialize a slimmed checkpoint and write it in one shot.

    The state dict is stripped of the frozen sub-model weights (keys
    starting with 'autokl', 'optimus', or 'clip') before saving. The
    checkpoint is first serialized to an in-memory buffer and then written
    through fsspec in a single write, so a partially-written file is never
    left behind by an interrupted save.

    Args:
        cfg: config object stored verbatim under "config".
        net: model (possibly DP/DDP-wrapped; unwrapped here).
        opt: optimizer whose state is saved when not None.
        step: training step recorded under "step".
        path: destination path/URL understood by fsspec.
    """
    wrapped = isinstance(net, (torch.nn.DataParallel,
                               torch.nn.parallel.DistributedDataParallel))
    bare_net = net.module if wrapped else net

    # Drop weights of the frozen encoders/decoders; they are reloaded from
    # their own pretrained files (see load_state_dict).
    skip_prefixes = ('autokl', 'optimus', 'clip')
    slim = OrderedDict(
        (key, val) for key, val in bare_net.state_dict().items()
        if not key.startswith(skip_prefixes))

    checkpoint = {
        "config": cfg,
        "state_dict": slim,
        "step": step,
    }
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()

    import io
    import fsspec
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)
    with fsspec.open(path, "wb") as fp:
        fp.write(buffer.getvalue())
59
+
60
def load_state_dict(net, cfg):
    """Load pretrained weights into ``net`` from exactly one source.

    Four mutually exclusive loading modes, selected by which keys are set
    in ``cfg`` (asserts fire if more than one is given):
      * pretrained_pth_full / pretrained_ckpt_full — full state dict.
      * pretrained_pth / pretrained_ckpt — slimmed state dict (the frozen
        'autokl'/'optimus'/'clip' weights are taken from the current net).
      * pretrained_pth_dm — diffusion-model (UNet) weights only.
      * pretrained_pth_ema — EMA weights only.
    ``.pth`` files hold a raw state dict; ``.ckpt`` files hold it under a
    'state_dict' key. ``strict_sd`` is forwarded to ``load_state_dict``.
    """
    pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth = cfg.get('pretrained_pth' , None)
    pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
    pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        # Full-model load; no other source may be set.
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth is not None or pretrained_ckpt is not None:
        # Slimmed load: checkpoint lacks the frozen sub-model weights
        # (stripped by atomic_save), so refill them from the live net.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('autokl')==0 or ki.find('optimus')==0 or ki.find('clip')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_dm is not None:
        # UNet-only load into net.model.diffusion_model.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_ema is not None:
        # EMA-shadow-weights-only load into net.model_ema.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)
124
+
125
+ ###################
126
+ # official stages #
127
+ ###################
128
+
129
class eval(eval_base):
    """Versatile-diffusion eval entry point; reuses sd_default's ``eval``
    (imported above as ``eval_base``) without modification."""
    pass
131
+
132
class eval_stage(eval_stage_base):
    """
    Evaluation of both prompt and vision
    """
    def __init__(self):
        # Versatile-diffusion sampler supporting both image and text outputs.
        from ..model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD

    @torch.no_grad()
    def sample(
            self, net, sampler, context, otype, ctype, image_output_dim, text_latent_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Sample latents for any (context type, output type) combination.

        Args:
            context: the prompt string ('prompt') or image tensor ('vision').
            otype: output type, 'image' or 'text'.
            ctype: conditioning type, 'prompt' or 'vision'.
            image_output_dim: (h, w) used when otype == 'image'.
            text_latent_dim: latent width used when otype == 'text'.
            scale / n_samples / ddim_steps / ddim_eta: guidance and DDIM
                settings; scale == 1.0 disables the unconditional branch.
        """
        if ctype == 'prompt':
            c = net.clip_encode_text(n_samples * [context])
            uc = None
            if scale != 1.0:
                uc = net.clip_encode_text(n_samples * [""])
        elif ctype == 'vision':
            # Replicate the single hint image over the batch, then encode;
            # the unconditional branch uses an all-zero image.
            context = context[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(context)
            uc = None
            if scale != 1.0:
                dummy = torch.zeros_like(context)
                uc = net.clip_encode_vision(dummy)

        if otype == 'image':
            h, w = image_output_dim
            shape = [n_samples, 4, h//8, w//8]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
        elif otype == 'text':
            # Text output lives in a flat latent of width text_latent_dim.
            n = text_latent_dim
            shape = [n_samples, n]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)

        return rv

    def decode_and_save(
            self, netm, z, xtype, ctype, path, name, suffix,
            color_adj=False, color_adj_to=None):
        """Decode latents to images or text and write them to disk.

        Output filenames are prefixed by modality pair: t2i/v2i for images,
        t2t/v2t for text, depending on ctype.
        """
        if xtype == 'image':
            x = netm.autokl_decode(z)
            name = 't2i_'+name if ctype == 'prompt' else 'v2i_'+name
            if color_adj and (ctype=='vision'):
                # Color-match decoded images to the conditioning image;
                # (xi+1)/2 maps the decoder's output into [0, 1] — assumes
                # the decoder emits values in [-1, 1], as save_images does.
                keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
                simple = cfguh().cfg.eval.get('color_adj_simple', True)
                x_adj = []
                for xi in x:
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=keep_ratio, simple=simple)
                    x_adj.append(xi_adj)
                x = x_adj
            self.save_images(x, name, path, suffix=suffix)
        elif xtype == 'text':
            prompt_temperature = cfguh().cfg.eval.get('prompt_temperature', 1.0)
            x = netm.optimus_decode(z, temperature=prompt_temperature)
            name = 't2t_'+name if ctype == 'prompt' else 'v2t_'+name
            # Optionally collapse immediate word repetitions ("a a cat" ->
            # "a cat") in the decoded prompts.
            prompt_merge_same_adj_word = cfguh().cfg.eval.get('prompt_merge_same_adj_word', False)
            if prompt_merge_same_adj_word:
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            self.save_text(x, name, path, suffix=suffix)

    def save_images(self, x, name, path, suffix=''):
        """Save a batch of images (tensor in [-1, 1] or list of arrays/PILs)
        as a single merged PNG named ``<name><suffix>.png``."""
        if isinstance(x, torch.Tensor):
            single_input = len(x.shape) == 3
            if single_input:
                x = x[None]
            # Map [-1, 1] tensors to [0, 1] and convert to PIL.
            x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
            x = [tvtrans.ToPILImage()(xi) for xi in x]
            xlist = [np.array(xi) for xi in x]
        elif isinstance(x, list):
            xlist = x
        canvas = auto_merge_imlist(xlist)
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def save_text(self, x, name, path, suffix=''):
        """Write one line per decoded string to ``<name><suffix>.txt``."""
        file_name = '{}{}.txt'.format(name, suffix)
        with open(osp.join(path, file_name) ,'w') as f:
            for xi in x:
                f.write(xi+'\n')

    def __call__(self, **paras):
        """Run the eval loop over cfg.eval.conditioning entries of the form
        (context, output_type); file paths are treated as vision contexts,
        everything else as text prompts. Sharded across local ranks."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        output_path = self.get_image_path()
        self.create_dir(output_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (D)DP so we call methods on the bare module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)

        # Round-robin shard conditioning over local ranks; seed_increment
        # gives each slot a globally unique seed offset.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for conditioningi, seedi in zip(conditioning_local, seed_increment):
            if conditioningi == 'SKIP':
                continue

            ci, otypei = conditioningi

            if osp.isfile(ci):
                # is vision
                output_name = osp.splitext(osp.basename(ci))[0]
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                # Rescale [0, 1] -> [-1, 1] to match the encoder's range.
                ci = ci*2 - 1
                ctypei = 'vision'
            else:
                # is prompt
                output_name = ci.strip().replace(' ', '-')
                ctypei = 'prompt'

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)
            else:
                z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)

            # color_adj_to is the original conditioning entry's first field
            # (the file path, for vision contexts).
            self.decode_and_save(
                netm, z, otypei, ctypei, output_path, output_name, suffixi,
                color_adj=color_adj, color_adj_to=conditioningi[0],)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
308
+
309
+ ################
310
+ # basic stages #
311
+ ################
312
+
313
class eval_stage_basic(eval_stage_base):
    """Image-to-image eval using clip_encode_vision on PIL inputs; inherits
    the sd_default eval_stage machinery (sampler, paths, save_images)."""

    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM sampling conditioned on an image file path.

        The hint is passed to ``clip_encode_vision`` as a list of PIL images;
        the unconditional branch uses an all-zero tensor of the same size.
        """
        h, w = output_dim
        vh = PIL.Image.open(visual_hint)
        c = net.clip_encode_vision(n_samples * [vh])
        uc = None
        if scale != 1.0:
            dummy = torch.zeros_like(tvtrans.ToTensor()(vh))
            uc = net.clip_encode_vision(n_samples * [dummy])

        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        """Generate outputs for every image path in cfg.eval.conditioning,
        sharded across local ranks, with optional color matching."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (D)DP so we call methods on the bare module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        # Optional color matching of outputs to the hint image.
        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        # Round-robin shard conditioning over local ranks.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue
            draw_filename = osp.splitext(osp.basename(ci))[0]
            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
401
+
402
+ #######################
403
+ # dual context stages #
404
+ #######################
405
+
406
class eval_stage_dc(eval_stage_base):
    """Dual-context eval: each conditioning entry is dispatched to either
    text or vision sampling via the DualContext DDIM sampler."""

    def __init__(self):
        from ..model_zoo.ddim_dualcontext import DDIMSampler_DualContext
        self.sampler = DDIMSampler_DualContext

    @torch.no_grad()
    def sample(
            self, net, sampler, conditioning, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Dispatch sampling by conditioning type.

        Args:
            conditioning: (ctype, cvalue) tuple where ctype is 'prompt' or
                'vision'; any other value raises ValueError.
        """
        ctype, cvalue =conditioning
        if ctype == 'prompt':
            return self.sample_text(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        elif ctype == 'vision':
            return self.sample_vision(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        else:
            raise ValueError

    @torch.no_grad()
    def sample_text(
            self, net, sampler, prompt, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Text-conditioned sampling; empty-prompt encoding is the
        unconditional branch when guidance is active."""
        h, w = output_dim
        uc = None
        if scale != 1.0:
            uc = net.clip_encode_text(n_samples * [""])
        c = net.clip_encode_text(n_samples * [prompt])
        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_text(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    @torch.no_grad()
    def sample_vision(
            self, net, sampler, visual_hint, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Vision-conditioned sampling; an all-zero image is the
        unconditional branch when guidance is active.

        Raises:
            ValueError: if visual_hint is not a single CxHxW tensor.
        """
        h, w = output_dim
        if len(visual_hint.shape) == 3:
            # Replicate the single hint over the sample batch.
            visual_hint=visual_hint[None].repeat(n_samples, 1, 1, 1)
        else:
            raise ValueError

        c = net.clip_encode_vision(visual_hint)
        uc = None
        if scale != 1.0:
            visual_hint_blank = torch.zeros_like(visual_hint)
            uc = net.clip_encode_vision(visual_hint_blank)

        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_vision(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    def __call__(self, **paras):
        """Run the eval loop; file-path entries become ('vision', tensor)
        conditioning with a 'v2i_' filename prefix, other entries become
        ('prompt', str) with 't2i_'. Sharded across local ranks."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (D)DP so we call methods on the bare module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        # Round-robin shard conditioning over local ranks.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            if osp.isfile(ci):
                # is vision
                draw_filename = 'v2i_' + osp.splitext(osp.basename(ci))[0]
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                # Rescale [0, 1] -> [-1, 1] to match the encoder's range.
                ci = ci*2 - 1
                ci = ('vision', ci)
            else:
                # is prompt
                draw_filename = 't2i_' + ci.strip().replace(' ', '-')
                ci = ('prompt', ci)

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj and ci[0] == 'vision':
                # Color-match outputs to the hint tensor (ci[1]).
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci[1])
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
549
+
versatile_diffusion/lib/model_zoo/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .common.get_model import get_model
2
+ from .common.get_optimizer import get_optimizer
3
+ from .common.get_scheduler import get_scheduler
4
+ from .common.utils import get_unit
versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (400 Bytes). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (360 Bytes). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc ADDED
Binary file (12.6 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (5.63 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (5.61 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc ADDED
Binary file (7.26 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc ADDED
Binary file (7.22 kB). View file