Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- results/versatile_diffusion/subj01/roi/4.png +3 -0
- versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__init__.py +6 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc +0 -0
- versatile_diffusion/lib/data_factory/common/ds_base.py +280 -0
- versatile_diffusion/lib/data_factory/common/ds_estimator.py +85 -0
- versatile_diffusion/lib/data_factory/common/ds_formatter.py +39 -0
- versatile_diffusion/lib/data_factory/common/ds_loader.py +97 -0
- versatile_diffusion/lib/data_factory/common/ds_sampler.py +273 -0
- versatile_diffusion/lib/data_factory/common/ds_transform.py +178 -0
- versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py +221 -0
- versatile_diffusion/lib/evaluator/__init__.py +1 -0
- versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc +0 -0
- versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc +0 -0
- versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc +0 -0
- versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc +0 -0
- versatile_diffusion/lib/evaluator/eva_base.py +293 -0
- versatile_diffusion/lib/evaluator/eva_null.py +26 -0
- versatile_diffusion/lib/experiments/__init__.py +0 -0
- versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
- versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-38.pyc +0 -0
- versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc +0 -0
- versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc +0 -0
- versatile_diffusion/lib/experiments/sd_default.py +441 -0
- versatile_diffusion/lib/experiments/vd_default.py +549 -0
- versatile_diffusion/lib/model_zoo/__init__.py +4 -0
- versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc +0 -0
.gitattributes
CHANGED
|
@@ -2984,3 +2984,4 @@ results/versatile_diffusion/subj01/97.png filter=lfs diff=lfs merge=lfs -text
|
|
| 2984 |
results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
|
| 2985 |
results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
|
| 2986 |
results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 2984 |
results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
|
| 2985 |
results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
|
| 2986 |
results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
|
| 2987 |
+
results/versatile_diffusion/subj01/roi/4.png filter=lfs diff=lfs merge=lfs -text
|
results/versatile_diffusion/subj01/roi/4.png
ADDED
|
Git LFS Details
|
versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (535 Bytes). View file
|
|
|
versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (495 Bytes). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ds_base import ds_base, collate, register as regdataset
|
| 2 |
+
from .ds_loader import pre_loader_checkings, register as regloader
|
| 3 |
+
from .ds_transform import TBase, have, register as regtrans
|
| 4 |
+
from .ds_estimator import register as regestmat
|
| 5 |
+
from .ds_formatter import register as regformat
|
| 6 |
+
from .ds_sampler import register as regsampler
|
versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (551 Bytes). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (511 Bytes). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc
ADDED
|
Binary file (7.95 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc
ADDED
|
Binary file (7.92 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc
ADDED
|
Binary file (3.46 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc
ADDED
|
Binary file (1.68 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc
ADDED
|
Binary file (3.27 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc
ADDED
|
Binary file (3.25 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc
ADDED
|
Binary file (8.97 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc
ADDED
|
Binary file (8.92 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc
ADDED
|
Binary file (5.41 kB). View file
|
|
|
versatile_diffusion/lib/data_factory/common/ds_base.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.random as npr
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import torchvision
|
| 8 |
+
import copy
|
| 9 |
+
import itertools
|
| 10 |
+
|
| 11 |
+
from ... import sync
|
| 12 |
+
from ...cfg_holder import cfg_unique_holder as cfguh
|
| 13 |
+
from ...log_service import print_log
|
| 14 |
+
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from multiprocessing import shared_memory
|
| 17 |
+
|
| 18 |
+
# import multiprocessing
|
| 19 |
+
# if hasattr(multiprocessing, "shared_memory"):
|
| 20 |
+
# from multiprocessing import shared_memory
|
| 21 |
+
# else:
|
| 22 |
+
# # workaround for single gpu inference on colab
|
| 23 |
+
# shared_memory = None
|
| 24 |
+
|
| 25 |
+
import pickle
|
| 26 |
+
import hashlib
|
| 27 |
+
import random
|
| 28 |
+
|
| 29 |
+
class ds_base(torch.utils.data.Dataset):
    """Common dataset base class.

    Subclasses fill ``self.load_info`` (a list of per-sample info dicts) in
    ``init_load_info``; this base class then handles ordering, filtering,
    repetition, optional (shared-memory) caching, and the
    load -> transform -> format pipeline in ``__getitem__``.
    """
    def __init__(self,
                 cfg,
                 loader = None,
                 estimator = None,
                 transforms = None,
                 formatter = None):
        # cfg: config node; optional fields read here are
        #   load_info_order_by, load_info_add_idx, try_sample, repeat,
        #   pick, cache_sm, cache_pct.
        # loader / transforms / formatter: callables applied per sample;
        # estimator: callable applied once to the whole load_info list.
        self.cfg = cfg
        self.load_info = None
        self.init_load_info()  # subclass hook that populates self.load_info
        self.loader = loader
        self.transforms = transforms
        self.formatter = formatter

        if self.load_info is not None:
            load_info_order_by = getattr(self.cfg, 'load_info_order_by', 'default')
            if load_info_order_by == 'default':
                self.load_info = sorted(self.load_info, key=lambda x:x['unique_id'])
            else:
                # "key|reverse" sorts descending on 'key'; plain "key" ascending.
                try:
                    load_info_order_by, reverse = load_info_order_by.split('|')
                    reverse = reverse == 'reverse'
                except:
                    reverse = False
                self.load_info = sorted(
                    self.load_info, key=lambda x:x[load_info_order_by], reverse=reverse)

        load_info_add_idx = getattr(self.cfg, 'load_info_add_idx', True)
        if (self.load_info is not None) and load_info_add_idx:
            # Record each entry's post-sort position for later stages.
            for idx, info in enumerate(self.load_info):
                info['idx'] = idx

        if estimator is not None:
            # The estimator may filter / reorder / repeat the whole list.
            self.load_info = estimator(self.load_info)

        self.try_sample = getattr(self.cfg, 'try_sample', None)
        if self.try_sample is not None:
            # try_sample is either a (start, end) pair or a single end index.
            try:
                start, end = self.try_sample
            except:
                start, end = 0, self.try_sample
            self.load_info = self.load_info[start:end]

        self.repeat = getattr(self.cfg, 'repeat', 1)

        pick = getattr(self.cfg, 'pick', None)
        if pick is not None:
            # Keep only entries whose 'filename' is listed in pick.
            self.load_info = [i for i in self.load_info if i['filename'] in pick]

        #########
        # cache #
        #########

        self.cache_sm = getattr(self.cfg, 'cache_sm', False)
        self.cache_cnt = 0
        if self.cache_sm:
            self.cache_pct = getattr(self.cfg, 'cache_pct', 0)
            # A node-wide random id keeps shared-memory segment names unique
            # per run while staying identical across a node's processes.
            cache_unique_id = sync.nodewise_sync().random_sync_id()
            self.cache_unique_id = hashlib.sha256(pickle.dumps(cache_unique_id)).hexdigest()
            self.__cache__(self.cache_pct)

        #######
        # log #
        #######

        if self.load_info is not None:
            console_info = '{}: '.format(self.__class__.__name__)
            console_info += 'total {} unique images, '.format(len(self.load_info))
            console_info += 'total {} unique sample. Cached {}. Repeat {} times.'.format(
                len(self.load_info), self.cache_cnt, self.repeat)
        else:
            console_info = '{}: load_info not ready.'.format(self.__class__.__name__)
        print_log(console_info)

    def init_load_info(self):
        # implement by sub class
        pass

    def __len__(self):
        # Virtual length: every unique sample is served `repeat` times.
        return len(self.load_info)*self.repeat

    def __cache__(self, pct):
        """Pre-load the first ``pct`` fraction of samples.

        Without shared memory, the loaded element simply replaces its info
        dict in ``load_info``. With shared memory (cache_sm), the pickled
        element is written into a named SharedMemory segment and the entry
        is replaced by the segment name (a str).
        """
        if pct == 0:
            self.cache_cnt = 0
            return
        self.cache_cnt = int(len(self.load_info)*pct)
        if not self.cache_sm:
            for i in range(self.cache_cnt):
                self.load_info[i] = self.loader(self.load_info[i])
            return

        # NOTE(review): self.local_world_size / self.local_rank are never set
        # in this class -- presumably injected by a subclass or expected from
        # the sync module; confirm before relying on the shared-memory path.
        for i in range(self.cache_cnt):
            shm_name = str(self.load_info[i]['unique_id']) + '_' + self.cache_unique_id
            if i % self.local_world_size == self.local_rank:
                # This rank owns sample i: load, pickle, publish to shm.
                data = pickle.dumps(self.loader(self.load_info[i]))
                datan = len(data)
                # self.print_smname_to_file(shm_name)
                shm = shared_memory.SharedMemory(
                    name=shm_name, create=True, size=datan)
                shm.buf[0:datan] = data[0:datan]
                shm.close()
                self.load_info[i] = shm_name
            else:
                # Other ranks only remember the segment name.
                self.load_info[i] = shm_name
        # Ensure every segment exists before any rank starts reading.
        dist.barrier()

    def __getitem__(self, idx):
        # Map the repeated index space back onto the unique samples.
        idx = idx%len(self.load_info)
        # element = copy.deepcopy(self.load_info[idx])

        # 0730 try shared memory
        element = copy.deepcopy(self.load_info[idx])
        if isinstance(element, str):
            # Cached in shared memory: the entry is the segment name.
            shm = shared_memory.SharedMemory(name=element)
            element = pickle.loads(shm.buf)
            shm.close()
        else:
            # NOTE(review): element was already deepcopied above, so this
            # second deepcopy is redundant work.
            element = copy.deepcopy(element)
            element['load_info_ptr'] = self.load_info

        if idx >= self.cache_cnt:
            # Not cached: load on the fly.
            element = self.loader(element)
        if self.transforms is not None:
            element = self.transforms(element)
        if self.formatter is not None:
            return self.formatter(element)
        else:
            return element

    # 0730 try shared memory
    def __del__(self):
        # Clean the shared memory
        # NOTE(review): relies on self.local_rank (see __cache__); only local
        # rank 0 unlinks the segments.
        for infoi in self.load_info:
            if isinstance(infoi, str) and (self.local_rank==0):
                shm = shared_memory.SharedMemory(name=infoi)
                shm.close()
                shm.unlink()

    def print_smname_to_file(self, smname):
        """Append a shared-memory segment name to a '.smname' sidecar file
        derived from the run's log file path (debug helper, currently
        commented out at its call site)."""
        try:
            log_file = cfguh().cfg.train.log_file
        except:
            try:
                log_file = cfguh().cfg.eval.log_file
            except:
                raise ValueError
        # a trick to use the log_file path
        sm_file = log_file.replace('.log', '.smname')
        with open(sm_file, 'a') as f:
            f.write(smname + '\n')
|
| 180 |
+
|
| 181 |
+
def singleton(class_):
    """Decorator: make ``class_`` yield one shared instance.

    The first call constructs the object with the arguments it receives;
    every later call returns that same object (its arguments are ignored).
    """
    holder = []

    def getinstance(*args, **kwargs):
        if not holder:
            holder.append(class_(*args, **kwargs))
        return holder[0]

    return getinstance
|
| 188 |
+
|
| 189 |
+
from .ds_loader import get_loader
|
| 190 |
+
from .ds_transform import get_transform
|
| 191 |
+
from .ds_estimator import get_estimator
|
| 192 |
+
from .ds_formatter import get_formatter
|
| 193 |
+
|
| 194 |
+
@singleton
class get_dataset(object):
    """Singleton registry mapping dataset type names to dataset classes.

    Calling the instance with a config builds a dataset: the sibling module
    that defines the requested type is imported lazily (its import side
    effect registers the class here), then the loader / transform /
    estimator / formatter are built from sub-configs and passed to the
    dataset constructor.
    """
    def __init__(self):
        self.dataset = {}

    def register(self, ds):
        # Registered under the class name; looked up via cfg.type.
        self.dataset[ds.__name__] = ds

    def __call__(self, cfg):
        if cfg is None:
            return None
        t = cfg.type
        if t is None:
            return None
        # Lazy imports: each branch pulls in the sibling module whose import
        # registers the dataset class handling type `t`.
        elif t in ['laion2b', 'laion2b_dummy',
                   'laion2b_webdataset',
                   'laion2b_webdataset_sdofficial', ]:
            from .. import ds_laion2b
        elif t in ['coyo', 'coyo_dummy',
                   'coyo_webdataset', ]:
            from .. import ds_coyo_webdataset
        elif t in ['laionart', 'laionart_dummy',
                   'laionart_webdataset', ]:
            from .. import ds_laionart
        elif t in ['celeba']:
            from .. import ds_celeba
        elif t in ['div2k']:
            from .. import ds_div2k
        elif t in ['pafc']:
            from .. import ds_pafc
        elif t in ['coco_caption']:
            from .. import ds_coco
        else:
            raise ValueError

        # Build the optional pipeline pieces from their sub-configs.
        loader = get_loader() (cfg.get('loader' , None))
        transform = get_transform()(cfg.get('transform', None))
        estimator = get_estimator()(cfg.get('estimator', None))
        formatter = get_formatter()(cfg.get('formatter', None))

        return self.dataset[t](
            cfg, loader, estimator,
            transform, formatter)
|
| 237 |
+
|
| 238 |
+
def register():
    """Decorator factory: add the decorated dataset class to the global
    ``get_dataset`` registry and hand it back unchanged."""
    def _decorator(cls):
        get_dataset().register(cls)
        return cls
    return _decorator
|
| 243 |
+
|
| 244 |
+
# some other helpers
|
| 245 |
+
|
| 246 |
+
class collate(object):
    """
    Modified from torch.utils.data._utils.collate.
    It handles list fields differently from the default: list fields are
    collated by concatenation (appended to each other) instead of stacking.
    """
    def __init__(self):
        self.default_collate = \
            torch.utils.data._utils.collate.default_collate

    def __call__(self, batch):
        """
        Args:
            batch: [data, data] -or- [(data1, data2, ...), (data1, data2, ...)]
        This function will not be used as induction function
        """
        elem = batch[0]
        # Bugfix: the original guard `if not (elem, (tuple, list)):` was
        # missing `isinstance`, so it tested a (truthy) tuple and never took
        # this branch; non-tuple batches were wrongly unpacked by zip(*batch).
        if not isinstance(elem, (tuple, list)):
            return self.default_collate(batch)

        rv = []
        # transposed: handle one field of the batch at a time
        for i in zip(*batch):
            if isinstance(i[0], list):
                # List fields must be singletons; collate each, then re-wrap.
                if len(i[0]) != 1:
                    raise ValueError
                try:
                    i = [[self.default_collate(ii).squeeze(0)] for ii in i]
                except:
                    # Best effort: leave the raw lists if they can't collate.
                    pass
                rvi = list(itertools.chain.from_iterable(i))
                rv.append(rvi) # list concat
            else:
                rv.append(self.default_collate(i))
        return rv
|
versatile_diffusion/lib/data_factory/common/ds_estimator.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import numpy.random as npr
|
| 4 |
+
import PIL
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
import xml.etree.ElementTree as ET
|
| 10 |
+
import json
|
| 11 |
+
import copy
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
def singleton(class_):
    """Decorator: the decorated class yields one shared instance
    (constructed on first call; later call arguments are ignored)."""
    holder = []
    def getinstance(*args, **kwargs):
        if not holder:
            holder.append(class_(*args, **kwargs))
        return holder[0]
    return getinstance

@singleton
class get_estimator(object):
    """Singleton registry of estimator classes, keyed by class name."""
    def __init__(self):
        self.estimator = {}

    def register(self, estimf):
        self.estimator[estimf.__name__] = estimf

    def __call__(self, cfg):
        """Instantiate the estimator named by cfg.type with cfg.args;
        a None cfg yields None."""
        if cfg is None:
            return None
        return self.estimator[cfg.type](**cfg.args)

def register():
    """Decorator factory: record the decorated estimator class in the
    ``get_estimator`` registry and hand it back unchanged."""
    def _decorator(cls):
        get_estimator().register(cls)
        return cls
    return _decorator
|
| 41 |
+
|
| 42 |
+
@register()
class PickFileEstimator(object):
    """
    Estimator that filters load_info down to the entries whose image file
    stem (basename without extension) appears in the provided filelist,
    optionally repeating the kept entries.
    """
    def __init__(self,
                 filelist = None,
                 repeat_n = 1):
        """
        Args:
            filelist: a list of string gives the name of images
                we would like to visualize, evaluate or train.
            repeat_n: int, times these images will be repeated
        """
        self.filelist = filelist
        self.repeat_n = repeat_n

    def __call__(self, load_info):
        # Bugfix: this module imports `os.path as osp` (not `os`), so the
        # original `os.path.basename` reference raised NameError at call
        # time; use the module's osp alias instead.
        load_info_new = []
        for info in load_info:
            if osp.basename(info['image_path']).split('.')[0] in self.filelist:
                load_info_new.append(info)
        return load_info_new * self.repeat_n
|
| 66 |
+
|
| 67 |
+
@register()
class PickIndexEstimator(object):
    """
    Estimator that keeps only the load_info entries located at the
    provided indices, in the order the indices are given.
    """
    def __init__(self,
                 indexlist = None,
                 **kwargs):
        """
        Args:
            indexlist: [] of int.
                the indices to be filtered out.
        """
        self.indexlist = indexlist

    def __call__(self, load_info):
        picked = []
        for position in self.indexlist:
            picked.append(load_info[position])
        return picked
|
versatile_diffusion/lib/data_factory/common/ds_formatter.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.random as npr
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
import scipy.ndimage
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import copy
|
| 10 |
+
import gc
|
| 11 |
+
import itertools
|
| 12 |
+
|
| 13 |
+
def singleton(class_):
    """Decorator: the decorated class yields one shared instance
    (constructed on first call; later call arguments are ignored)."""
    holder = []
    def getinstance(*args, **kwargs):
        if not holder:
            holder.append(class_(*args, **kwargs))
        return holder[0]
    return getinstance

@singleton
class get_formatter(object):
    """Singleton registry of formatter classes, keyed by class name."""
    def __init__(self):
        self.formatter = {}

    def register(self, formatf):
        self.formatter[formatf.__name__] = formatf

    def __call__(self, cfg):
        """Instantiate the formatter named by cfg.type with cfg.args;
        a None cfg yields None."""
        if cfg is None:
            return None
        return self.formatter[cfg.type](**cfg.args)

def register():
    """Decorator factory: record the decorated formatter class in the
    ``get_formatter`` registry and hand it back unchanged."""
    def _decorator(cls):
        get_formatter().register(cls)
        return cls
    return _decorator
|
versatile_diffusion/lib/data_factory/common/ds_loader.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import numpy.random as npr
|
| 4 |
+
import PIL
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
import xml.etree.ElementTree as ET
|
| 10 |
+
import json
|
| 11 |
+
import copy
|
| 12 |
+
|
| 13 |
+
from ...cfg_holder import cfg_unique_holder as cfguh
|
| 14 |
+
|
| 15 |
+
def singleton(class_):
    """Decorator: the decorated class yields one shared instance
    (constructed on first call; later call arguments are ignored)."""
    holder = []
    def getinstance(*args, **kwargs):
        if not holder:
            holder.append(class_(*args, **kwargs))
        return holder[0]
    return getinstance

@singleton
class get_loader(object):
    """Singleton registry of loader classes, keyed by class name.

    A list config builds one loader per entry and chains them with
    ``compose``; a single config builds one loader; None yields None.
    """
    def __init__(self):
        self.loader = {}

    def register(self, loadf):
        self.loader[loadf.__name__] = loadf

    def __call__(self, cfg):
        if cfg is None:
            return None
        if isinstance(cfg, list):
            stages = [self.loader[ci.type](**ci.args) for ci in cfg]
            return compose(stages)
        return self.loader[cfg.type](**cfg.args)

class compose(object):
    """Apply a sequence of loaders to an element, left to right."""
    def __init__(self, loaders):
        self.loaders = loaders

    def __call__(self, element):
        for stage in self.loaders:
            element = stage(element)
        return element

    def __getitem__(self, idx):
        return self.loaders[idx]

def register():
    """Decorator factory: record the decorated loader class in the
    ``get_loader`` registry and hand it back unchanged."""
    def _decorator(cls):
        get_loader().register(cls)
        return cls
    return _decorator
|
| 60 |
+
|
| 61 |
+
def pre_loader_checkings(ltype):
    """Decorator factory for loader methods of a given data type.

    The wrapped ``func(self, path, element)`` is called with
    ``element[ltype + '_path']`` and its result is stored at
    ``element[ltype]``. A pre-cached value under ``ltype + '_cache'``
    short-circuits loading. For images, 'imsize' and 'imsize_current'
    (height, width) are also recorded.

    Raises ValueError if the data slot is already filled, the path key is
    missing, or a loaded image has an unsupported type.
    """
    lpath = ltype + '_path'
    # cache feature added on 20201021
    lcache = ltype + '_cache'

    def wrapper(func):
        def inner(self, element):
            if lcache in element:
                # cache feature added on 20201021
                data = element[lcache]
            else:
                if ltype in element:
                    raise ValueError
                if lpath not in element:
                    raise ValueError
                source = element[lpath]
                data = func(self, source, element) if source is not None else None
            element[ltype] = data

            if ltype == 'image':
                if isinstance(data, np.ndarray):
                    imsize = data.shape[-2:]
                elif isinstance(data, PIL.Image.Image):
                    imsize = data.size[::-1]
                elif isinstance(data, torch.Tensor):
                    imsize = [data.size(-2), data.size(-1)]
                elif data is None:
                    imsize = None
                else:
                    raise ValueError
                element['imsize'] = imsize
                element['imsize_current'] = copy.deepcopy(imsize)
            return element
        return inner
    return wrapper
|
versatile_diffusion/lib/data_factory/common/ds_sampler.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenize import group
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.random as npr
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
from ...log_service import print_log
|
| 9 |
+
from ... import sync
|
| 10 |
+
|
| 11 |
+
def singleton(class_):
    """Decorator: the decorated class yields one shared instance
    (constructed on first call; later call arguments are ignored)."""
    holder = []
    def getinstance(*args, **kwargs):
        if not holder:
            holder.append(class_(*args, **kwargs))
        return holder[0]
    return getinstance

@singleton
class get_sampler(object):
    """Singleton registry of sampler classes.

    The string configs 'default_train' / 'default_eval' map to
    pre-configured ``GlobalDistributedSampler`` instances; any other config
    instantiates the registered class named by cfg.type with cfg.args.
    """
    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        if cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        return self.sampler[cfg.type](dataset=dataset, **cfg.args)

def register():
    """Decorator factory: record the decorated sampler class in the
    ``get_sampler`` registry and hand it back unchanged."""
    def _decorator(cls):
        get_sampler().register(cls)
        return cls
    return _decorator
|
| 41 |
+
|
| 42 |
+
######################
|
| 43 |
+
# DistributedSampler #
|
| 44 |
+
######################
|
| 45 |
+
|
| 46 |
+
@register()
class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    This is a distributed sampler that syncs across gpus and nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        """
        Arguments:
            dataset: Dataset used for sampling.
            shuffle: If true, sampler will shuffle the indices
            extend: If true, sampler will extend the indices that can be even distributed by ranks
                otherwise sampler will truncate the indices to make it even.
        """
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend

        # Per-rank sample count; with `extend`, a remainder bumps every rank
        # up by one so total_size covers the whole dataset (with wrap-around).
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset)%self.world_size != 0):
            num_samples+=1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size

    def __iter__(self):
        indices = self.get_sync_order()
        if self.extend:
            # extend using the front indices
            indices = indices+indices[0:self.total_size-len(indices)]
        else:
            # truncate
            indices = indices[0:self.total_size]
        # subsample: rank r takes every world_size-th index starting at r
        indices = indices[self.rank : len(indices) : self.world_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        """Return the epoch's index order, identical on every rank."""
        if self.shuffle:
            # NOTE(review): .to(self.rank) presumably moves the permutation
            # onto this rank's CUDA device so dist.broadcast can ship it
            # (e.g. over NCCL) -- confirm backend assumptions.
            indices = torch.randperm(len(self.dataset)).to(self.rank)
            if self.ddp:
                # Every rank adopts rank 0's permutation.
                dist.broadcast(indices, src=0)
            indices = indices.to('cpu').tolist()
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])) )
        return indices
|
| 100 |
+
|
| 101 |
+
@register()
class LocalDistributedSampler(GlobalDistributedSampler):
    """
    A distributed sampler that synchronizes the index order across GPUs
    within a node, but not across nodes (each node shuffles independently).
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        # Override the global rank/world_size with node-local values so the
        # strided subsampling in __iter__ partitions within the node only.
        # NOTE(review): num_samples/total_size were computed by the parent
        # using the *global* world size and are not recomputed here — confirm
        # this is intended when local and global world sizes differ.
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

    def get_sync_order(self):
        """Produce the same index order on every local rank (local rank 0 decides)."""
        if self.shuffle:
            if self.rank == 0:
                indices = list(npr.permutation(len(self.dataset)))
                # Local rank 0 shares its permutation with the rest of the node.
                sync.nodewise_sync().broadcast_r0(indices)
            else:
                indices = sync.nodewise_sync().broadcast_r0(None)
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])) )
        return indices
|
| 126 |
+
|
| 127 |
+
############################
|
| 128 |
+
# random sample with group #
|
| 129 |
+
############################
|
| 130 |
+
# Deprecated
|
| 131 |
+
|
| 132 |
+
@register()
class GroupSampler(torch.utils.data.Sampler):
    """
    This is a new DistributedSampler that sample all index according to group.
    i.e.
        if group_size=3, num_replicas=2, train mode:
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
            ==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
                             process1: [0, 1, 2]
            ==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
                                 process1: [0, 1, 2]
            ==> (distribute) process0: [3, 4, 5], [6, 7] (remove 10)
                             process1: [0, 1, 2], [8, 9]

        it will avoid_batchsize=1:
            0, 1, 2, 3, 4, 5, 6, 7, 8,
            ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8]
            ==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8])
                             process1: [0, 1, 2]
            ==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
                                 process1: [0, 1, 2]
            ==> (distribute) process0: [3, 4, 5], (remove 6, 7, 8) (because distribute make batchsize 1)
                             process1: [0, 1, 2]

        if group_size=3, num_replicas=2, eval mode:
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ==> (extend) 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
            ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
            ==> (distribute) process0: [0, 1, 2], [6, 7, 8],
                             process1: [3, 4, 5], [9, 10, 10]
    """

    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        """
        Args:
            dataset: dataset with a ``load_info`` list of dicts (each carrying
                'image_size'); this sampler mutates it in place (see below).
            group_size: number of consecutive indices bundled together.
            num_replicas: number of distributed processes (defaults to the
                torch.distributed world size).
            rank: this process's rank (defaults to torch.distributed rank).
            mode: 'train' (shuffle, drop remainder) or 'eval' (extend to even).
        """
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError
            rank = dist.get_rank()

        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode
        len_dataset = self.len_dataset

        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # drop the non-aligned tail so each replica sees the same count
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # pad by repeating the last index so every sample is evaluated
            extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([range(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset

        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas

        # regular_groups: (n_groups, group_size); leftover split evenly by replica.
        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)

        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1]==1) and (mode == 'train'):
            # avoid bs=1
            self.leftover_groups = None

        # BUGFIX: self.num_samples was never assigned, so __len__ raised
        # AttributeError. Each rank yields an equal share of the regular
        # groups plus one leftover row (when present) — identical count in
        # both train and eval modes.
        self.num_samples = (self.regular_groups.shape[0] // num_replicas) * group_size
        if self.leftover_groups is not None:
            self.num_samples += self.leftover_groups.shape[-1]

        # An ugly way to modify dataset.load_info according to the grouping:
        # every member of a group shares the 'image_size' of the group's
        # middle element as its 'ref_size'.
        for groupi in self.regular_groups:
            for idx in groupi:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd+idx_upperbd)//2
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                for idx in groupi:
                    idx_lowerbd = groupi[0]
                    idx_upperbd = groupi[-1]
                    idx_reference = (idx_lowerbd+idx_upperbd)//2
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']

    def concat(self, nparrays, axis=0):
        # a helper for safe concatenation: drops empty arrays first
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)

    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        """Return this rank's index order; the group permutation is decided
        by rank 0 and broadcast so all ranks agree."""
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)

        if mode == 'train':
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
            indices = groups.flatten()

            if self.leftover_groups is not None:
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            # Deterministic round-robin assignment of groups to ranks.
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError

        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
        return indices
|
versatile_diffusion/lib/data_factory/common/ds_transform.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import numpy.random as npr
|
| 4 |
+
import PIL
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
import xml.etree.ElementTree as ET
|
| 10 |
+
import json
|
| 11 |
+
import copy
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
def singleton(class_):
    """Decorator that makes *class_* behave as a singleton: the first call
    constructs the instance, every later call returns that same instance
    (constructor arguments of later calls are ignored)."""
    cache = {}
    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]
    return getinstance
|
| 21 |
+
|
| 22 |
+
@singleton
class get_transform(object):
    """Singleton registry/factory for transform classes.

    A config with ``type``/``args`` builds one transform; a list of such
    configs builds a `compose` pipeline; ``None`` yields ``None``.
    """
    def __init__(self):
        self.transform = {}

    def register(self, transf):
        # Registry is keyed by the transform class's own name.
        self.transform[transf.__name__] = transf

    def __call__(self, cfg):
        if cfg is None:
            return None
        if isinstance(cfg, list):
            stages = [self.transform[ci.type](**ci.args) for ci in cfg]
            return compose(stages)
        return self.transform[cfg.type](**cfg.args)
|
| 41 |
+
|
| 42 |
+
def register():
    """Class decorator that adds the decorated transform class to the
    singleton `get_transform` registry and returns it unchanged."""
    def wrapper(class_):
        get_transform().register(class_)
        return class_
    return wrapper
|
| 47 |
+
|
| 48 |
+
def have(must=[], may=[]):
    """
    The nextgen decorator that have two list of
    input tells what category the transform
    will operate on.
    Args:
        must: [] of str,
            the names of the items that must be included
            inside the element.
            If element[name] exist: do the transform
            If element[name] is None: raise Exception.
            If element[name] not exist: raise Exception.
        may: [] of str,
            the names of the items that may be contained
            inside the element for transform.
            If element[name] exist: do the transform
            If element[name] is None: ignore it.
            If element[name] not exist: ignore it.
    """
    # NOTE(review): `must`/`may` are mutable default arguments; they are only
    # iterated here (never mutated), so this is safe as written — but callers
    # should not rely on mutating them.
    def route(self, item, e, d):
        """
        Route the element to a proper function
        for calculation.
        Args:
            self: object,
                the transform functor.
            item: str,
                the item name of the data.
            e: {},
                the element
            d: nparray, tensor or PIL.Image,
                the data to transform.
        """
        if isinstance(d, np.ndarray):
            dtype = 'nparray'
        elif isinstance(d, torch.Tensor):
            dtype = 'tensor'
        elif isinstance(d, PIL.Image.Image):
            dtype = 'pilimage'
        else:
            raise ValueError

        # find function by order: most specific (item+dtype) to generic 'exec'
        f = None
        for attrname in [
                'exec_{}_{}'.format(item, dtype),
                'exec_{}'.format(item),
                'exec_{}'.format(dtype),
                'exec']:
            f = getattr(self, attrname, None)
            if f is not None:
                break
        # NOTE(review): if the transform defines none of the exec* methods,
        # f stays None and the next line raises TypeError — confirm every
        # registered transform provides at least `exec`.
        d, e = f(d, e)
        e[item] = d
        return e

    def wrapper(func):
        def inner(self, e):
            # Record the pre-transform image size under a per-transform tag;
            # repeated applications of the same transform get numbered tags.
            e['imsize_previous'] = e['imsize_current']
            imsize_tag_cnt = 0
            imsize_tag = 'imsize_before_' + self.__class__.__name__
            while True:
                if imsize_tag_cnt != 0:
                    tag = imsize_tag + str(imsize_tag_cnt)
                else:
                    tag = imsize_tag
                if not tag in e:
                    e[tag] = e['imsize_current']
                    break
                imsize_tag_cnt += 1

            e = func(self, e)
            # must transform list: missing or None items are hard errors
            for item in must:
                try:
                    d = e[item]
                except:
                    raise ValueError
                if d is None:
                    raise ValueError
                e = route(self, item, e, d)
            # may transform list: missing or None items are skipped silently
            for item in may:
                try:
                    d = e[item]
                except:
                    d = None
                if d is not None:
                    e = route(self, item, e, d)
            return e
        return inner
    return wrapper
|
| 140 |
+
|
| 141 |
+
class compose(object):
    """Chain several transforms: the element is threaded through each
    transform in order and the final result is returned."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, element):
        for stage in self.transforms:
            element = stage(element)
        return element
|
| 149 |
+
|
| 150 |
+
class TBase(object):
    """Base class for transforms. Subclasses implement exec* methods; the
    `rand` helper centralizes random-number generation (the history
    record/replay hooks are currently disabled)."""

    def __init__(self):
        pass

    def exec(self, data, element):
        # Subclasses must override; the base implementation is an error.
        raise ValueError

    def rand(self,
             uid,
             tag,
             rand_f,
             *args,
             **kwargs):
        """
        Args:
            uid: string element['unique_id']
            tag: string tells the tag uses when tracking the random number.
                Or the tag to restore the tracked random number.
            rand_f: the random function use to generate random number.
            **kwargs: the argument for the given random function.
        """
        # Record/replay via rnduh() is disabled; currently this just draws
        # a fresh value from rand_f.
        value = rand_f(*args, **kwargs)
        return value
|
versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.random as npr
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import torchvision.transforms as tvtrans
|
| 8 |
+
import PIL.Image
|
| 9 |
+
PIL.Image.MAX_IMAGE_PIXELS = None
|
| 10 |
+
import math
|
| 11 |
+
import json
|
| 12 |
+
import copy
|
| 13 |
+
import pickle
|
| 14 |
+
from multiprocessing import shared_memory
|
| 15 |
+
import time
|
| 16 |
+
from .common import *
|
| 17 |
+
from ..log_service import print_log
|
| 18 |
+
|
| 19 |
+
from lib import visual_service as vis
|
| 20 |
+
from .. import sync
|
| 21 |
+
|
| 22 |
+
import webdataset as wds
|
| 23 |
+
|
| 24 |
+
###################################################
|
| 25 |
+
# this is a special ds that use webdataset mainly #
|
| 26 |
+
###################################################
|
| 27 |
+
|
| 28 |
+
@regdataset()
class laion2b_dummy(ds_base):
    # Placeholder dataset: registers under the laion2b name but exposes no
    # samples (empty load_info), satisfying the ds_base interface.
    def init_load_info(self):
        self.load_info = []
|
| 32 |
+
|
| 33 |
+
@regdataset()
class laion2b_webdataset(ds_base):
    """LAION-2B dataset backed by webdataset tar shards.

    Unlike the indexed ds_base datasets, this one exposes no load_info; it
    builds a streaming wds.WebLoader over the .tar shards in <root_dir>/data.
    """
    def init_load_info(self):
        # Streaming dataset: no per-sample index is materialized.
        self.load_info = []

    def make_loader(self, batch_size, num_workers, train=True):
        """Build a streaming loader over the tar shards.

        Args:
            batch_size: samples per batch.
            num_workers: WebLoader worker processes.
            train: random crop + shard shuffle when True, center crop otherwise.
        """
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(),]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(),]

        trans = tvtrans.Compose(trans)

        # Only the 'jpg' field is transformed; 'txt' passes through untouched.
        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess

        shuffle = cfg.get('shuffle', 10000)
        shardshuffle = shuffle > 0
        node_world_size = sync.get_world_size('node')
        # NOTE(review): this condition looks inverted — split_by_node is
        # selected when there is only one node, single_node_only otherwise.
        # Verify against the intended multi-node sharding behavior.
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size==1 else wds.shardlists.single_node_only

        tars = [osp.join(self.root_dir, 'data', i) for i in os.listdir(osp.join(self.root_dir, 'data'))
                if osp.splitext(i)[1]=='.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print_log(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        # Pipeline: key filter -> PIL decode -> size/watermark filter -> transform.
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # NOTE(review): the return value of .batched() is discarded here; if
        # this webdataset version returns a new pipeline rather than mutating
        # in place, batching is silently dropped — verify.
        dset.batched(batch_size, partial=False)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader

    def filter_size(self, x):
        """Predicate: keep samples meeting min_size / max_pwatermark limits;
        any metadata error rejects the sample."""
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and \
                        x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
        """Predicate: keep samples that carry both an image and a caption."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    # NOTE(review): the three dataloader helpers below pass a single argument
    # to make_loader(batch_size, num_workers, ...), which requires two — and
    # self.train / self.validation / self.test are not set anywhere visible
    # in this class. These look unused or broken; confirm before relying on
    # them.
    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        return self.make_loader(self.test, train=False)
|
| 128 |
+
|
| 129 |
+
def customized_postprocess(element):
    """Map a decoded webdataset sample to an (image, caption, key) tuple,
    rescaling the image from [0, 1] to [-1, 1]."""
    image = element['jpg']*2-1
    caption = element['txt']
    key = element['__key__']
    return image, caption, key
|
| 131 |
+
|
| 132 |
+
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into a single batch dict.

    Only keys present in *every* sample are kept. Per key, values are
    combined by type: int/float -> np.array (if combine_scalars),
    torch.Tensor -> torch.stack (if combine_tensors),
    np.ndarray -> np.array (if combine_tensors), anything else -> list.

    Args:
        samples: list of dicts, one per sample.
        combine_tensors: stack tensor/ndarray values when True.
        combine_scalars: pack numeric scalars into an np.array when True.

    Returns:
        dict mapping each shared key to its batched values.
    """
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    # Build the per-key value lists directly instead of the original
    # side-effect list comprehension.
    batched = {key: [sample[key] for sample in samples] for key in keys}

    result = {}
    for key, values in batched.items():
        first = values[0]
        if isinstance(first, (int, float)) and combine_scalars:
            result[key] = np.array(values)
        elif isinstance(first, torch.Tensor) and combine_tensors:
            result[key] = torch.stack(values)
        elif isinstance(first, np.ndarray) and combine_tensors:
            result[key] = np.array(values)
        else:
            # BUGFIX: previously, when a combine flag was False the key was
            # silently dropped from the result; now it falls back to a plain
            # list like any other type.
            result[key] = values
    return result
|
| 153 |
+
|
| 154 |
+
###################
|
| 155 |
+
# for sd official #
|
| 156 |
+
###################
|
| 157 |
+
|
| 158 |
+
def customized_postprocess_sdofficial(element):
    """Map a decoded webdataset sample to the {'jpg', 'txt'} dict format used
    by the SD-official training loop, rescaling the image to [-1, 1]."""
    rescaled = element['jpg']*2-1
    return {
        'jpg': rescaled,
        'txt': element['txt'], }
|
| 162 |
+
|
| 163 |
+
@regdataset()
class laion2b_webdataset_sdofficial(laion2b_webdataset):
    # Variant matching the official Stable Diffusion data pipeline: fixed
    # shuffle buffer, single-node sharding, dict-style postprocess, and a
    # dict collation function.
    def make_loader(self, batch_size, num_workers, train=True):
        """Build a streaming loader; same shape as the parent's make_loader
        but with SD-official defaults hard-coded."""
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(),]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(),]

        trans = tvtrans.Compose(trans)

        trans_dict = {'jpg': trans}
        # Dict-form postprocess (keeps {'jpg', 'txt'} keys).
        postprocess = customized_postprocess_sdofficial

        # Hard-coded: always shuffle with a 10000-sample buffer, single node.
        shuffle = 10000
        shardshuffle = shuffle > 0
        node_world_size = 1
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size==1 else wds.shardlists.single_node_only

        tars = [osp.join(self.root_dir, 'data', i) for i in os.listdir(osp.join(self.root_dir, 'data'))
                if osp.splitext(i)[1]=='.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # NOTE(review): the return value of .batched() is discarded here; if
        # this webdataset version returns a new pipeline rather than mutating
        # in place, batching (and dict_collation_fn) is silently dropped —
        # verify.
        dset.batched(batch_size, partial=False, collation_fn=dict_collation_fn)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader
|
versatile_diffusion/lib/evaluator/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .eva_base import get_evaluator
|
versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (235 Bytes). View file
|
|
|
versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (195 Bytes). View file
|
|
|
versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc
ADDED
|
Binary file (8.67 kB). View file
|
|
|
versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc
ADDED
|
Binary file (8.82 kB). View file
|
|
|
versatile_diffusion/lib/evaluator/eva_base.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
from ..log_service import print_log
|
| 12 |
+
|
| 13 |
+
def singleton(class_):
    """Decorator that makes *class_* a singleton: the first call builds the
    instance, subsequent calls return it (their arguments are ignored)."""
    cache = {}
    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]
    return getinstance
|
| 20 |
+
|
| 21 |
+
@singleton
class get_evaluator(object):
    """Singleton registry/factory for evaluators.

    Builtin evaluator modules register themselves (via the module-level
    `register` decorator) as a side effect of being imported, so the factory
    imports the matching submodule lazily before looking up the registry.
    """
    def __init__(self):
        self.evaluator = {}

    def register(self, evaf, name):
        self.evaluator[name] = evaf

    def _import_builtin(self, t):
        """Import the submodule for a builtin evaluator type *t* so its
        @register(...) decoration populates self.evaluator.

        The imported names are intentionally unused: the import's side
        effect is the point. (Previously this chain was duplicated verbatim
        in both branches of __call__.)
        """
        if t == 'miou':
            from . import eva_miou
        if t == 'psnr':
            from . import eva_psnr
        if t == 'ssim':
            from . import eva_ssim
        if t == 'lpips':
            from . import eva_lpips
        if t == 'fid':
            from . import eva_fid

    def __call__(self, pipeline_cfg=None):
        """Build an evaluator from a config, a list of configs (composed),
        or the null evaluator when no config is given."""
        if pipeline_cfg is None:
            from . import eva_null
            return self.evaluator['null']()

        if not isinstance(pipeline_cfg, list):
            t = pipeline_cfg.type
            self._import_builtin(t)
            return self.evaluator[t](**pipeline_cfg.args)

        evaluator = []
        for ci in pipeline_cfg:
            t = ci.type
            self._import_builtin(t)
            evaluator.append(
                self.evaluator[t](**ci.args))
        if len(evaluator) == 0:
            return None
        else:
            return compose(evaluator)
|
| 67 |
+
|
| 68 |
+
def register(name):
    """Class decorator that registers the decorated evaluator class under
    *name* in the singleton `get_evaluator` registry."""
    def wrapper(class_):
        get_evaluator().register(class_, name)
        return class_
    return wrapper
|
| 73 |
+
|
| 74 |
+
class base_evaluator(object):
|
| 75 |
+
    def __init__(self,
                 **args):
        '''
        Args:
            sample_n, int,
                the total number of sample. used in
                distributed sync
        '''
        # Evaluators require torch.distributed: rank/world_size drive the
        # sync machinery below.
        if not dist.is_available():
            raise ValueError
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        # sample_n is assigned later via set_sample_n(); `final` collects the
        # computed results that save() writes out.
        self.sample_n = None
        self.final = {}
|
| 89 |
+
|
| 90 |
+
    def sync(self, data):
        """
        Gather *data* from every rank: returns a list with one entry per
        rank (or, for list/tuple input, a per-item list of per-rank values).
        Args:
            data: any,
                the data needs to be broadcasted
        """
        if data is None:
            return None

        # Tuples are normalized to lists so the recursive branch applies.
        if isinstance(data, tuple):
            data = list(data)

        if isinstance(data, list):
            # Sync each item separately, then transpose so the result is
            # indexed [item][rank] rather than [rank][item].
            data_list = []
            for datai in data:
                data_list.append(self.sync(datai))
            data = [[*i] for i in zip(*data_list)]
            return data

        # Leaf case: every rank broadcasts its value in turn.
        data = [
            self.sync_(data, ranki)
            for ranki in range(self.world_size)
        ]
        return data
|
| 114 |
+
|
| 115 |
+
def sync_(self, data, rank):
|
| 116 |
+
|
| 117 |
+
t = type(data)
|
| 118 |
+
is_broadcast = rank == self.rank
|
| 119 |
+
|
| 120 |
+
if t is np.ndarray:
|
| 121 |
+
dtrans = data
|
| 122 |
+
dt = data.dtype
|
| 123 |
+
if dt in [
|
| 124 |
+
int,
|
| 125 |
+
np.bool,
|
| 126 |
+
np.uint8,
|
| 127 |
+
np.int8,
|
| 128 |
+
np.int16,
|
| 129 |
+
np.int32,
|
| 130 |
+
np.int64,]:
|
| 131 |
+
dtt = torch.int64
|
| 132 |
+
elif dt in [
|
| 133 |
+
float,
|
| 134 |
+
np.float16,
|
| 135 |
+
np.float32,
|
| 136 |
+
np.float64,]:
|
| 137 |
+
dtt = torch.float64
|
| 138 |
+
|
| 139 |
+
elif t is str:
|
| 140 |
+
dtrans = np.array(
|
| 141 |
+
[ord(c) for c in data],
|
| 142 |
+
dtype = np.int64
|
| 143 |
+
)
|
| 144 |
+
dt = np.int64
|
| 145 |
+
dtt = torch.int64
|
| 146 |
+
else:
|
| 147 |
+
raise ValueError
|
| 148 |
+
|
| 149 |
+
if is_broadcast:
|
| 150 |
+
n = len(dtrans.shape)
|
| 151 |
+
n = torch.tensor(n).long()
|
| 152 |
+
|
| 153 |
+
n = n.to(self.rank)
|
| 154 |
+
dist.broadcast(n, src=rank)
|
| 155 |
+
|
| 156 |
+
n = list(dtrans.shape)
|
| 157 |
+
n = torch.tensor(n).long()
|
| 158 |
+
n = n.to(self.rank)
|
| 159 |
+
dist.broadcast(n, src=rank)
|
| 160 |
+
|
| 161 |
+
n = torch.tensor(dtrans, dtype=dtt)
|
| 162 |
+
n = n.to(self.rank)
|
| 163 |
+
dist.broadcast(n, src=rank)
|
| 164 |
+
return data
|
| 165 |
+
|
| 166 |
+
n = torch.tensor(0).long()
|
| 167 |
+
n = n.to(self.rank)
|
| 168 |
+
dist.broadcast(n, src=rank)
|
| 169 |
+
n = n.item()
|
| 170 |
+
|
| 171 |
+
n = torch.zeros(n).long()
|
| 172 |
+
n = n.to(self.rank)
|
| 173 |
+
dist.broadcast(n, src=rank)
|
| 174 |
+
n = list(n.to('cpu').numpy())
|
| 175 |
+
|
| 176 |
+
n = torch.zeros(n, dtype=dtt)
|
| 177 |
+
n = n.to(self.rank)
|
| 178 |
+
dist.broadcast(n, src=rank)
|
| 179 |
+
n = n.to('cpu').numpy().astype(dt)
|
| 180 |
+
|
| 181 |
+
if t is np.ndarray:
|
| 182 |
+
return n
|
| 183 |
+
elif t is str:
|
| 184 |
+
n = ''.join([chr(c) for c in n])
|
| 185 |
+
return n
|
| 186 |
+
|
| 187 |
+
def zipzap_arrange(self, data):
    '''
    Round-robin interleave a list of sequences:
        input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...]
    Sequences may have unequal lengths; exhausted ones are skipped, so the
    output contains exactly the union of the inputs in round-robin order.
    '''
    if isinstance(data[0], list):
        maxlen = max(len(i) for i in data)
        data_new = []
        for idx in range(maxlen):
            for datai in data:
                # Skip sequences that have run out. The original code
                # indexed unconditionally and raised IndexError whenever
                # sublists had different lengths.
                if idx < len(datai):
                    data_new.append(datai[idx])
        return data_new

    elif isinstance(data[0], np.ndarray):
        # Same round-robin walk along each array's leading axis.
        # The original pad-stack-truncate version called np.concatenate
        # with arrays as positional args (must be a sequence) and
        # np.zeros with an unpacked shape (must be a tuple), and could
        # leave zero padding inside the result when a shorter array
        # preceded a longer one.
        maxlen = max(i.shape[0] for i in data)
        pieces = []
        for idx in range(maxlen):
            for datai in data:
                if idx < datai.shape[0]:
                    pieces.append(datai[idx:idx + 1])
        return np.concatenate(pieces, axis=0)

    else:
        raise NotImplementedError
|
| 220 |
+
|
| 221 |
+
def add_batch(self, **args):
    # Accumulate one batch of model outputs/targets for later metric
    # computation. Subclasses must override.
    raise NotImplementedError
|
| 223 |
+
|
| 224 |
+
def set_sample_n(self, sample_n):
    # Record the total number of evaluation samples expected.
    self.sample_n = sample_n
|
| 226 |
+
|
| 227 |
+
def compute(self):
    # Reduce the accumulated batches into the final metric value(s).
    # Subclasses must override.
    raise NotImplementedError
|
| 229 |
+
|
| 230 |
+
# Used during training to decide whether a newly evaluated
# number improves on the previous best (higher is better).
def isbetter(self, old, new):
    """Return True when ``new`` beats ``old``."""
    return old < new
|
| 234 |
+
|
| 235 |
+
def one_line_summary(self):
    # Print a short one-line description of this evaluator; subclasses
    # typically override with their metric value.
    print_log('Evaluator display')
|
| 237 |
+
|
| 238 |
+
def save(self, path):
    """Dump the accumulated results (``self.final``) to ``<path>/result.json``,
    creating the directory if needed."""
    if not osp.exists(path):
        os.makedirs(path)
    target = osp.join(path, 'result.json')
    with open(target, 'w') as fp:
        json.dump(self.final, fp, indent=4)
|
| 244 |
+
|
| 245 |
+
def clear_data(self):
    # Drop all accumulated batches so the evaluator can be reused.
    # Subclasses must override.
    raise NotImplementedError
|
| 247 |
+
|
| 248 |
+
class compose(object):
    """Bundle several evaluators and drive them through a single
    evaluator-shaped interface.

    ``isbetter`` is a majority vote across the contained evaluators;
    ``compute`` collects per-evaluator results keyed by their ``symbol``.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        # Fan the batch out to every contained evaluator.
        for stage in self.pipeline:
            stage.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n
        for stage in self.pipeline:
            stage.set_sample_n(sample_n)

    def compute(self):
        # Gather each evaluator's result and its `final` payload.
        results = {}
        for stage in self.pipeline:
            results[stage.symbol] = stage.compute()
            self.final[stage.symbol] = stage.final
        return results

    def isbetter(self, old, new):
        # Strict-majority vote across the pipeline.
        votes = sum(1 for stage in self.pipeline if stage.isbetter(old, new))
        return votes / len(self.pipeline) > 0.5

    def one_line_summary(self):
        for stage in self.pipeline:
            stage.one_line_summary()

    def save(self, path):
        """Dump the combined results to ``<path>/result.json``."""
        if not osp.exists(path):
            os.makedirs(path)
        target = osp.join(path, 'result.json')
        with open(target, 'w') as fp:
            json.dump(self.final, fp, indent=4)

    def clear_data(self):
        for stage in self.pipeline:
            stage.clear_data()
|
versatile_diffusion/lib/evaluator/eva_null.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import lpips
|
| 4 |
+
|
| 5 |
+
from .. import nputils
|
| 6 |
+
from ..log_service import print_log
|
| 7 |
+
|
| 8 |
+
from .eva_base import base_evaluator, register
|
| 9 |
+
|
| 10 |
+
@register('null')
class null_evaluator(base_evaluator):
    """A do-nothing evaluator.

    Registered under the name 'null'; used when a run needs an evaluator
    object but no metric should actually be computed.
    """
    def __init__(self, **dummy):
        # Extra config keys are accepted and ignored.
        super().__init__()

    def add_batch(self,
                  **dummy):
        # Discard the batch; nothing is accumulated.
        pass

    def compute(self):
        # No metric to report.
        return None

    def one_line_summary(self):
        print_log('Evaluator null')

    def clear_data(self):
        # Nothing accumulated, nothing to clear.
        pass
|
versatile_diffusion/lib/experiments/__init__.py
ADDED
|
File without changes
|
versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
versatile_diffusion/lib/experiments/sd_default.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torchvision import transforms as tvtrans
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import time
|
| 7 |
+
import timeit
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
import pickle
|
| 11 |
+
import PIL.Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from easydict import EasyDict as edict
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
|
| 17 |
+
from lib.cfg_holder import cfg_unique_holder as cfguh
|
| 18 |
+
from lib.data_factory import get_dataset, get_sampler, collate
|
| 19 |
+
from lib.model_zoo import \
|
| 20 |
+
get_model, get_optimizer, get_scheduler
|
| 21 |
+
from lib.log_service import print_log
|
| 22 |
+
|
| 23 |
+
from ..utils import train as train_base
|
| 24 |
+
from ..utils import eval as eval_base
|
| 25 |
+
from ..utils import train_stage as tsbase
|
| 26 |
+
from ..utils import eval_stage as esbase
|
| 27 |
+
from .. import sync
|
| 28 |
+
|
| 29 |
+
###############
|
| 30 |
+
# some helper #
|
| 31 |
+
###############
|
| 32 |
+
|
| 33 |
+
def atomic_save(cfg, net, opt, step, path):
    """Write a training checkpoint (config, slimmed state_dict, step, and
    optionally optimizer state) to ``path``.

    Weights belonging to the frozen first/cond stage models are dropped
    from the state_dict. The checkpoint is serialized into memory first
    and then written in a single pass via fsspec.
    """
    parallel_types = (torch.nn.DataParallel,
                      torch.nn.parallel.DistributedDataParallel)
    module = net.module if isinstance(net, parallel_types) else net
    full_sd = module.state_dict()
    # Keep only the trainable diffusion weights.
    kept = [(key, val) for key, val in full_sd.items()
            if not (key.startswith('first_stage_model')
                    or key.startswith('cond_stage_model'))]

    checkpoint = {
        "config" : cfg,
        "state_dict" : OrderedDict(kept),
        "step" : step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    # Serialize to an in-memory buffer, then flush it out in one write.
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)
    with fsspec.open(path, "wb") as fout:
        fout.write(buffer.getvalue())
|
| 55 |
+
|
| 56 |
+
def load_state_dict(net, cfg):
    """Load pretrained weights into ``net`` from exactly one of several
    mutually exclusive sources configured in ``cfg``:

    - pretrained_pth_full / pretrained_ckpt_full: a full-model state_dict
      (.pth) or checkpoint (.ckpt, weights under 'state_dict').
    - pretrained_pth / pretrained_ckpt: a slimmed state_dict; first/cond
      stage weights are filled in from ``net``'s current values.
    - pretrained_pth_dm: weights for net.model.diffusion_model only.
    - pretrained_pth_ema: weights for net.model_ema only.

    Asserts fire if more than one group is configured at once.
    ``strict_sd`` is forwarded to ``load_state_dict``.
    """
    pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth = cfg.get('pretrained_pth' , None)
    pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
    pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        # Full-model load: every other source must be unset.
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            # .ckpt files nest the weights under 'state_dict'.
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth is not None or pretrained_ckpt is not None:
        # Slimmed load: checkpoint lacks first/cond stage weights.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # Re-use the network's current first/cond stage weights so the
        # merged dict covers the whole model.
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('first_stage_model')==0 or ki.find('cond_stage_model')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_dm is not None:
        # Diffusion-model-only load.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_ema is not None:
        # EMA-weights-only load.
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)
|
| 120 |
+
|
| 121 |
+
def auto_merge_imlist(imlist, max=64):
    """Tile up to ``max`` equally-sized HxWx3 uint8 images into a single
    near-square grid canvas (row-major order) and return it as an ndarray.
    Unused grid cells stay black.
    """
    imlist = imlist[0:max]
    h, w = imlist[0].shape[0:2]
    count = len(imlist)
    rows = int(np.sqrt(count))
    # Enough columns to hold every image given `rows` rows.
    cols = count // rows + (1 if count % rows != 0 else 0)
    canvas = np.zeros([rows * h, cols * w, 3], dtype=np.uint8)
    for k, image in enumerate(imlist):
        r, c = divmod(k, cols)
        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w, :] = image
    return canvas
|
| 133 |
+
|
| 134 |
+
def latent2im(net, latent):
    """Decode latent code(s) through ``net`` into PIL image(s).

    A 3-D latent is treated as a single sample and yields one image;
    a batched latent yields a list of images. Decoded values in [-1, 1]
    are rescaled to [0, 1] before conversion.
    """
    batched = len(latent.shape) != 3
    if not batched:
        latent = latent[None]
    decoded = net.decode_image(latent.to(net.device))
    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    images = [tvtrans.ToPILImage()(frame) for frame in decoded]
    return images if batched else images[0]
|
| 144 |
+
|
| 145 |
+
def im2latent(net, im):
    """Encode PIL image(s) into latent code(s) via ``net``.

    A single image yields one latent; a list yields a batch. Pixel values
    in [0, 1] are rescaled to [-1, 1] before encoding.
    """
    batched = isinstance(im, list)
    pil_list = im if batched else [im]
    batch = torch.stack([tvtrans.ToTensor()(p) for p in pil_list], dim=0)
    batch = (batch * 2 - 1).to(net.device)
    latents = net.encode_image(batch)
    return latents if batched else latents[0]
|
| 155 |
+
|
| 156 |
+
class color_adjust(object):
    """Transfer color statistics from a reference image onto inputs.

    Two modes are supported by ``__call__``: a "simple" per-channel
    mean/std match, and a full per-channel histogram (PDF) transfer.
    """
    def __init__(self, ref_from, ref_to):
        # Stats must be stored before preprocess() is called below,
        # because preprocess() reads both stat tuples.
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        # Flattened (N, 3) reference samples used by pdf_transfer_1d;
        # the "from" reference is first mean/std-matched to the target.
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        # Normalize any supported input (path, PIL image, CHW tensor in
        # [0, 1], or ndarray) into a float HxWx3 array plus its
        # per-channel mean and std.
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        # Per-channel affine map taking the "from" statistics onto the
        # "to" statistics: standardize with (m0, s0), rescale to (m1, s1).
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x-m0)/s0)*s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        """Color-match ``xin`` to the target reference.

        keep: fraction of the original image blended back in (0 = full
            adjustment, 1 = unchanged).
        simple: if True, only the mean/std match is applied; otherwise a
            full per-channel histogram transfer is performed.
        Returns a uint8 HxWx3 array.
        """
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = (x*(1-keep) + xin*keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        # Full PDF transfer, one channel at a time over flattened pixels.
        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)

        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        # Blend with the original according to `keep`, then quantize.
        y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        # Classic 1-D histogram matching: map values of arr_in so that the
        # empirical CDF of arr_fo is carried onto the CDF of arr_to,
        # using an n-bin shared grid over both references.
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]
        # compute probability distribution
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
        # transfer: invert the target CDF at the source CDF values,
        # clamping the tails to the target's value range.
        t_d = np.interp(d_fo, d_to, xs)
        t_d[d_fo <= d_to[ 0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out
|
| 231 |
+
|
| 232 |
+
########
|
| 233 |
+
# main #
|
| 234 |
+
########
|
| 235 |
+
|
| 236 |
+
class eval(eval_base):
    def prepare_model(self):
        """Build the model, move it to the local GPU, load pretrained
        weights, wrap in DistributedDataParallel, and switch to eval mode.

        NOTE: the weights are loaded *before* the DDP wrap so every rank
        starts from identical parameters.
        """
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
        load_state_dict(net, cfg.eval)  # must precede the DDP wrap
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[self.local_rank],
            find_unused_parameters=True)
        net.eval()
        return {'net' : net,}
|
| 248 |
+
|
| 249 |
+
class eval_stage(esbase):
    """
    This is eval stage that can check comprehensive results.

    Generates text-conditioned demo images with a DDIM sampler; prompts
    are sharded across local ranks and each rank saves its own grids.
    """
    def __init__(self):
        # Store the sampler *class*; an instance is built per-call with
        # the concrete network.
        from ..model_zoo.ddim import DDIMSampler
        self.sampler = DDIMSampler

    def get_net(self, paras):
        # Network handle produced by prepare_model().
        return paras['net']

    def get_image_path(self):
        # Demo images go under <log_dir>/udemo; prefer the train log dir
        # when evaluation runs inside a training session.
        if 'train' in cfguh().cfg:
            log_dir = cfguh().cfg.train.log_dir
        else:
            log_dir = cfguh().cfg.eval.log_dir
        return os.path.join(log_dir, "udemo")

    @torch.no_grad()
    def sample(self, net, sampler, prompt, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """Run text-conditioned DDIM sampling with classifier-free guidance.

        Returns whatever sampler.sample returns (latents plus extras).
        """
        h, w = output_dim
        uc = None
        if scale != 1.0:
            # Empty-prompt conditioning for the unconditional branch.
            uc = net.get_learned_conditioning(n_samples * [""])
        c = net.get_learned_conditioning(n_samples * [prompt])
        # Latent space is 4 channels at 1/8 spatial resolution.
        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def save_images(self, pil_list, name, path, suffix=''):
        # Merge the PIL images into one grid canvas and save as
        # <path>/<name><suffix>.png.
        canvas = auto_merge_imlist([np.array(i) for i in pil_list])
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def __call__(self, **paras):
        """Generate demo grids for every configured prompt.

        Prompts (cfg.eval.conditioning, replicated cfg.eval.replicate
        times) are strided across local ranks; 'SKIP' entries are ignored.
        """
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)
        # Tag outputs with the training iteration when available.
        suffix='' if eval_cnt is None else '_itern'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        # Trick: the sampler reads net.device; point it at the local rank.
        setattr(netm, 'device', LRANK)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        # Shard prompts (and matching seed offsets) across local ranks.
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for prompti, seedi in zip(conditioning_local, seed_increment):
            if prompti == 'SKIP':
                continue
            draw_filename = prompti.strip().replace(' ', '-')
            if fix_seed:
                # Per-prompt deterministic seeding (torch offset by 100
                # so numpy/torch streams differ).
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                # Sample with EMA weights swapped in.
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)

            demo_image = latent2im(netm, x)
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
|
| 345 |
+
|
| 346 |
+
##################
|
| 347 |
+
# eval variation #
|
| 348 |
+
##################
|
| 349 |
+
|
| 350 |
+
class eval_stage_variation(eval_stage):
    """Eval stage producing image *variations*: conditioning comes from a
    reference image file instead of a text prompt, with optional color
    re-adjustment of the outputs toward the reference.
    """
    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """Image-conditioned DDIM sampling; ``visual_hint`` is a path to
        the reference image."""
        h, w = output_dim
        vh = tvtrans.ToTensor()(PIL.Image.open(visual_hint))[None].to(net.device)
        c = net.get_learned_conditioning(vh)
        c = c.repeat(n_samples, 1, 1)
        uc = None
        if scale != 1.0:
            # All-zero image as the unconditional branch for
            # classifier-free guidance.
            dummy = torch.zeros_like(vh)
            uc = net.get_learned_conditioning(dummy)
            uc = uc.repeat(n_samples, 1, 1)

        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        """Generate variation grids for every configured reference image.

        cfg.eval.conditioning lists image paths ('SKIP' entries ignored),
        sharded across local ranks like the base class.
        """
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        # Trick: the sampler reads net.device; point it at the local rank.
        setattr(netm, 'device', LRANK)

        # Optional color matching of outputs back toward the reference.
        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            # Output name derived from the reference image filename.
            draw_filename = osp.splitext(osp.basename(ci))[0]

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                # Re-map each output's colors toward the reference image.
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
|
versatile_diffusion/lib/experiments/vd_default.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torchvision import transforms as tvtrans
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import time
|
| 7 |
+
import timeit
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
import pickle
|
| 11 |
+
import PIL.Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from easydict import EasyDict as edict
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
|
| 17 |
+
from lib.cfg_holder import cfg_unique_holder as cfguh
|
| 18 |
+
from lib.data_factory import get_dataset, get_sampler, collate
|
| 19 |
+
from lib.model_zoo import \
|
| 20 |
+
get_model, get_optimizer, get_scheduler
|
| 21 |
+
from lib.log_service import print_log
|
| 22 |
+
|
| 23 |
+
from ..utils import train as train_base
|
| 24 |
+
from ..utils import eval as eval_base
|
| 25 |
+
from ..utils import train_stage as tsbase
|
| 26 |
+
from ..utils import eval_stage as esbase
|
| 27 |
+
from .. import sync
|
| 28 |
+
|
| 29 |
+
from .sd_default import auto_merge_imlist, latent2im, color_adjust
|
| 30 |
+
from .sd_default import eval as eval_base
|
| 31 |
+
from .sd_default import eval_stage as eval_stage_base
|
| 32 |
+
|
| 33 |
+
###############
|
| 34 |
+
# some helper #
|
| 35 |
+
###############
|
| 36 |
+
|
| 37 |
+
def atomic_save(cfg, net, opt, step, path):
    """Write a Versatile Diffusion checkpoint (config, slimmed state_dict,
    step, and optionally optimizer state) to ``path``.

    Weights of the frozen encoders (autokl / optimus / clip) are dropped
    from the state_dict. The checkpoint is serialized into memory first,
    then written in a single pass via fsspec.
    """
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    # Keep only the trainable weights.
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
        if ki.find('autokl')!=0 and ki.find('optimus')!=0 and ki.find('clip')!=0]

    checkpoint = {
        "config" : cfg,
        "state_dict" : OrderedDict(slimmed_sd),
        "step" : step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    # Serialize into a buffer, then flush it out in one write.
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())
|
| 59 |
+
|
| 60 |
+
def load_state_dict(net, cfg):
    """Load pretrained weights into ``net`` from one of several sources.

    Exactly one "family" of config keys may be set at a time; mixing
    families trips an assertion (the "overlapped" error below):

      - ``pretrained_pth_full`` / ``pretrained_ckpt_full``: full model
        state_dict loaded into ``net`` directly.
      - ``pretrained_pth`` / ``pretrained_ckpt``: slimmed weights (see
        ``atomic_save``); the frozen autokl/optimus/clip entries already
        present in ``net`` are merged back in before loading.
      - ``pretrained_pth_dm``: weights for ``net.model.diffusion_model`` only.
      - ``pretrained_pth_ema``: weights for ``net.model_ema`` only.

    ``*_pth`` files hold a raw state_dict; ``*_ckpt`` files hold a dict
    with a ``'state_dict'`` entry. ``strict_sd`` is forwarded as the
    ``strict`` flag of ``load_state_dict`` (default False).
    """
    pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth = cfg.get('pretrained_pth' , None)
    pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
    pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    # --- full-model family -------------------------------------------------
    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            # ckpt files wrap the weights in a checkpoint dict.
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    # --- slimmed-model family ----------------------------------------------
    if pretrained_pth is not None or pretrained_ckpt is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # Re-attach the frozen sub-model weights that atomic_save stripped,
        # taking them from the freshly initialized ``net`` itself.
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('autokl')==0 or ki.find('optimus')==0 or ki.find('clip')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    # --- diffusion-model-only weights --------------------------------------
    if pretrained_pth_dm is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    # --- EMA-only weights ---------------------------------------------------
    if pretrained_pth_ema is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)
|
| 124 |
+
|
| 125 |
+
###################
|
| 126 |
+
# official stages #
|
| 127 |
+
###################
|
| 128 |
+
|
| 129 |
+
class eval(eval_base):
    # NOTE(review): the class name shadows the builtin ``eval``; kept as-is
    # because external callers import it by this name.
    """Evaluation entry point; inherits all behavior from the default
    ``sd_default.eval`` unchanged."""
    pass
|
| 131 |
+
|
| 132 |
+
class eval_stage(eval_stage_base):
    """Evaluation stage for Versatile Diffusion with dual input/output types.

    Supports the four conditioning/output combinations: text-to-image,
    vision-to-image, text-to-text and vision-to-text, using the
    ``DDIMSampler_VD`` sampler.
    """
    def __init__(self):
        # Imported lazily to avoid pulling the sampler module in unless
        # this stage is actually instantiated.
        from ..model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD

    @torch.no_grad()
    def sample(
            self, net, sampler, context, otype, ctype, image_output_dim, text_latent_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Run DDIM sampling for one conditioning input.

        ``ctype`` selects the conditioning encoder ('prompt' → CLIP text,
        'vision' → CLIP vision); ``otype`` selects the latent shape
        ('image' → [n, 4, h/8, w/8], 'text' → [n, text_latent_dim]).
        ``scale`` != 1.0 enables classifier-free guidance via an
        unconditional embedding (empty string / zero image).
        Returns whatever ``sampler.sample`` returns (latents plus extras).
        """
        if ctype == 'prompt':
            c = net.clip_encode_text(n_samples * [context])
            uc = None
            if scale != 1.0:
                uc = net.clip_encode_text(n_samples * [""])
        elif ctype == 'vision':
            # Assumes ``context`` is a CHW tensor in [-1, 1]; the caller
            # (__call__) produces it via ToTensor()*2-1 — TODO confirm for
            # other callers.
            context = context[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(context)
            uc = None
            if scale != 1.0:
                dummy = torch.zeros_like(context)
                uc = net.clip_encode_vision(dummy)

        if otype == 'image':
            h, w = image_output_dim
            # Latent space is 8x downsampled with 4 channels.
            shape = [n_samples, 4, h//8, w//8]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
        elif otype == 'text':
            n = text_latent_dim
            shape = [n_samples, n]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)

        return rv

    def decode_and_save(
            self, netm, z, xtype, ctype, path, name, suffix,
            color_adj=False, color_adj_to=None):
        """Decode latents ``z`` and write the result to ``path``.

        For ``xtype == 'image'`` the latents are decoded through the
        autoencoder and saved as a PNG grid; an optional color adjustment
        toward ``color_adj_to`` is applied for vision conditioning.
        For ``xtype == 'text'`` the latents are decoded through Optimus and
        saved as a text file, optionally collapsing immediately repeated
        words. The output name is prefixed t2i_/v2i_/t2t_/v2t_ according to
        the conditioning/output pair.
        """
        if xtype == 'image':
            x = netm.autokl_decode(z)
            name = 't2i_'+name if ctype == 'prompt' else 'v2i_'+name
            if color_adj and (ctype=='vision'):
                keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
                simple = cfguh().cfg.eval.get('color_adj_simple', True)
                x_adj = []
                for xi in x:
                    # (xi+1)/2 maps the decoder output from [-1, 1] to [0, 1].
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=keep_ratio, simple=simple)
                    x_adj.append(xi_adj)
                x = x_adj
            self.save_images(x, name, path, suffix=suffix)
        elif xtype == 'text':
            prompt_temperature = cfguh().cfg.eval.get('prompt_temperature', 1.0)
            x = netm.optimus_decode(z, temperature=prompt_temperature)
            name = 't2t_'+name if ctype == 'prompt' else 'v2t_'+name
            prompt_merge_same_adj_word = cfguh().cfg.eval.get('prompt_merge_same_adj_word', False)
            if prompt_merge_same_adj_word:
                # Collapse runs of the identical adjacent word in each
                # decoded sentence ("a a cat" -> "a cat").
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            self.save_text(x, name, path, suffix=suffix)

    def save_images(self, x, name, path, suffix=''):
        """Merge a batch of images into one canvas and save it as a PNG.

        ``x`` is either a tensor in [-1, 1] (3-dim single image or 4-dim
        batch) or a list of already-converted numpy images.
        """
        if isinstance(x, torch.Tensor):
            single_input = len(x.shape) == 3
            if single_input:
                x = x[None]
            # Map from [-1, 1] to [0, 1] before PIL conversion.
            x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
            x = [tvtrans.ToPILImage()(xi) for xi in x]
            xlist = [np.array(xi) for xi in x]
        elif isinstance(x, list):
            xlist = x
        # NOTE(review): any other input type leaves ``xlist`` unbound and
        # raises NameError below.
        canvas = auto_merge_imlist(xlist)
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def save_text(self, x, name, path, suffix=''):
        """Write the iterable of strings ``x`` to ``<path>/<name><suffix>.txt``,
        one entry per line."""
        file_name = '{}{}.txt'.format(name, suffix)
        with open(osp.join(path, file_name) ,'w') as f:
            for xi in x:
                f.write(xi+'\n')

    def __call__(self, **paras):
        """Run one evaluation pass over ``cfg.eval.conditioning``.

        Each conditioning entry is a ``(input, output_type)`` pair where
        the input is either an image file path (vision) or a prompt string
        (text); entries equal to 'SKIP' are ignored. Work is sharded across
        local ranks by striding the conditioning list. Returns an empty dict.
        """
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        output_path = self.get_image_path()
        self.create_dir(output_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (Distributed)DataParallel for direct attribute access.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)

        # Stride the (optionally replicated) conditioning list across local
        # ranks; seed_increment keeps per-item seeds globally consistent.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for conditioningi, seedi in zip(conditioning_local, seed_increment):
            if conditioningi == 'SKIP':
                continue

            ci, otypei = conditioningi

            if osp.isfile(ci):
                # is vision
                output_name = osp.splitext(osp.basename(ci))[0]
                # Map the image from [0, 1] to the [-1, 1] range expected
                # downstream.
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                ci = ci*2 - 1
                ctypei = 'vision'
            else:
                # is prompt
                output_name = ci.strip().replace(' ', '-')
                ctypei = 'prompt'

            if fix_seed:
                # Deterministic per-item seeding; the seed value is also
                # recorded in the output file suffix.
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)
            else:
                z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)

            # conditioningi[0] is the ORIGINAL input (file path or prompt),
            # used as the color-adjust reference for vision conditioning.
            self.decode_and_save(
                netm, z, otypei, ctypei, output_path, output_name, suffixi,
                color_adj=color_adj, color_adj_to=conditioningi[0],)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
|
| 308 |
+
|
| 309 |
+
################
|
| 310 |
+
# basic stages #
|
| 311 |
+
################
|
| 312 |
+
|
| 313 |
+
class eval_stage_basic(eval_stage_base):
    """Basic vision-conditioned image evaluation stage.

    Relies on ``self.sampler`` supplied by the base class (no __init__
    override here) and takes each conditioning entry as an image file path.
    """
    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample ``n_samples`` image latents conditioned on the image
        file at ``visual_hint``.

        ``scale`` != 1.0 enables classifier-free guidance against a
        zero-image embedding. Returns the sampler's result (latents plus
        extras).
        """
        h, w = output_dim
        vh = PIL.Image.open(visual_hint)
        c = net.clip_encode_vision(n_samples * [vh])
        uc = None
        if scale != 1.0:
            # Zero tensor with the same shape as the hint image serves as
            # the unconditional reference.
            dummy = torch.zeros_like(tvtrans.ToTensor()(vh))
            uc = net.clip_encode_vision(n_samples * [dummy])

        # Per-sample latent shape; batch size is passed separately here
        # (this sampler uses the S/batch_size keyword convention).
        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        """Run one evaluation pass over ``cfg.eval.conditioning`` (image
        paths; 'SKIP' entries are ignored), sharded across local ranks.
        Saves decoded (optionally color-adjusted) images and returns an
        empty dict.
        """
        cfg = cfguh().cfg
        cfgv = cfg.eval

        # NOTE(review): reads paras['net'] directly, unlike the other stages
        # which use self.get_net(paras) — confirm this asymmetry is intended.
        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (Distributed)DataParallel for direct attribute access.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        # Stride the (optionally replicated) conditioning list across local
        # ranks; seed_increment keeps per-item seeds globally consistent.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue
            draw_filename = osp.splitext(osp.basename(ci))[0]
            if fix_seed:
                # Deterministic per-item seeding; seed value is recorded in
                # the output file suffix.
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                # Shift output colors toward the conditioning image ``ci``.
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
|
| 401 |
+
|
| 402 |
+
#######################
|
| 403 |
+
# dual context stages #
|
| 404 |
+
#######################
|
| 405 |
+
|
| 406 |
+
class eval_stage_dc(eval_stage_base):
    """Dual-context evaluation stage: image generation conditioned on
    either a text prompt or a vision hint, via ``DDIMSampler_DualContext``."""
    def __init__(self):
        # Imported lazily to avoid pulling the sampler module in unless
        # this stage is actually instantiated.
        from ..model_zoo.ddim_dualcontext import DDIMSampler_DualContext
        self.sampler = DDIMSampler_DualContext

    @torch.no_grad()
    def sample(
            self, net, sampler, conditioning, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Dispatch to text or vision sampling.

        ``conditioning`` is a ``(ctype, cvalue)`` pair with ctype in
        {'prompt', 'vision'}; any other ctype raises ValueError.
        """
        ctype, cvalue =conditioning
        if ctype == 'prompt':
            return self.sample_text(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        elif ctype == 'vision':
            return self.sample_vision(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        else:
            raise ValueError

    @torch.no_grad()
    def sample_text(
            self, net, sampler, prompt, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample ``n_samples`` image latents from a text prompt.

        ``scale`` != 1.0 enables classifier-free guidance against the
        empty-string embedding.
        """
        h, w = output_dim
        uc = None
        if scale != 1.0:
            uc = net.clip_encode_text(n_samples * [""])
        c = net.clip_encode_text(n_samples * [prompt])
        # Latent space is 8x downsampled with 4 channels.
        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_text(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    @torch.no_grad()
    def sample_vision(
            self, net, sampler, visual_hint, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample ``n_samples`` image latents from a vision hint.

        ``visual_hint`` must be a 3-dim (CHW) tensor; it is broadcast to a
        batch of ``n_samples``. ``scale`` != 1.0 enables classifier-free
        guidance against a zero-image embedding.
        """
        h, w = output_dim
        if len(visual_hint.shape) == 3:
            visual_hint=visual_hint[None].repeat(n_samples, 1, 1, 1)
        else:
            raise ValueError

        c = net.clip_encode_vision(visual_hint)
        uc = None
        if scale != 1.0:
            visual_hint_blank = torch.zeros_like(visual_hint)
            uc = net.clip_encode_vision(visual_hint_blank)

        # Latent space is 8x downsampled with 4 channels.
        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_vision(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    def __call__(self, **paras):
        """Run one evaluation pass over ``cfg.eval.conditioning``.

        Each entry is either an image file path (vision conditioning,
        output prefixed v2i_) or a prompt string (text conditioning,
        output prefixed t2i_); 'SKIP' entries are ignored. Work is sharded
        across local ranks. Returns an empty dict.
        """
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        # Unwrap (Distributed)DataParallel for direct attribute access.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        # Stride the (optionally replicated) conditioning list across local
        # ranks; seed_increment keeps per-item seeds globally consistent.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            if osp.isfile(ci):
                # is vision
                draw_filename = 'v2i_' + osp.splitext(osp.basename(ci))[0]
                # Map the image from [0, 1] to the [-1, 1] range expected
                # downstream, then tag it for the dual-context dispatcher.
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                ci = ci*2 - 1
                ci = ('vision', ci)
            else:
                # is prompt
                draw_filename = 't2i_' + ci.strip().replace(' ', '-')
                ci = ('prompt', ci)

            if fix_seed:
                # Deterministic per-item seeding; seed value is recorded in
                # the output file suffix.
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj and ci[0] == 'vision':
                # Shift output colors toward the conditioning image tensor.
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci[1])
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}
|
| 549 |
+
|
versatile_diffusion/lib/model_zoo/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .common.get_model import get_model
|
| 2 |
+
from .common.get_optimizer import get_optimizer
|
| 3 |
+
from .common.get_scheduler import get_scheduler
|
| 4 |
+
from .common.utils import get_unit
|
versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (400 Bytes). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (360 Bytes). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc
ADDED
|
Binary file (5.63 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc
ADDED
|
Binary file (7.26 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc
ADDED
|
Binary file (7.22 kB). View file
|
|
|