diff --git a/.gitattributes b/.gitattributes index a3f022eff87781e8ea8a27a031ceb983ea695449..434cb378bea80d1665ecf79708f93598a087ff7f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2984,3 +2984,4 @@ results/versatile_diffusion/subj01/97.png filter=lfs diff=lfs merge=lfs -text results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text +results/versatile_diffusion/subj01/roi/4.png filter=lfs diff=lfs merge=lfs -text diff --git a/results/versatile_diffusion/subj01/roi/4.png b/results/versatile_diffusion/subj01/roi/4.png new file mode 100644 index 0000000000000000000000000000000000000000..53318c688292863c076ea873287bcb55ab26b59d --- /dev/null +++ b/results/versatile_diffusion/subj01/roi/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec73857ae27b4acd80809ddda20ac07a2f986075639bb301e7a9fdfe4fa0367f +size 175575 diff --git a/versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc b/versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11560c899d910cedb2e6f051859fb8f3a5b815dc Binary files /dev/null and b/versatile_diffusion/lib/__pycache__/log_service.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc b/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73be1979ee7a87c3c833ecd4057256bb76770ee7 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc b/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..750ae5db72bb6a6993b180d5f169a13358f066e8 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/__pycache__/__init__.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__init__.py b/versatile_diffusion/lib/data_factory/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a1c21f2f700f446defaa8ded685b1cbc0a909c --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/__init__.py @@ -0,0 +1,6 @@ +from .ds_base import ds_base, collate, register as regdataset +from .ds_loader import pre_loader_checkings, register as regloader +from .ds_transform import TBase, have, register as regtrans +from .ds_estimator import register as regestmat +from .ds_formatter import register as regformat +from .ds_sampler import register as regsampler diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38aea3ecd972fd4129eca76ef270e9d98b6a71f3 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24ee76c292730f56478dd07dd3267e672c2f7bb Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/__init__.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3a7d25e74d2e71eec5fc7dd9bc1a91c89e896f9 Binary files /dev/null and 
b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6ba1ce89a5e326a03789c37e1eb74a8b9f8c772 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_base.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cd88389c08b9a9e681a55a3901302c68879847c Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38dc31cce63b93d65696b9241489fe70c2b9617 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_estimator.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2ca0ba07402f929c45f9a2111e8dade60a03d95 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3fea32ab1a9cd8ae8c94e8cb64b6a2f2139daeda Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_formatter.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa74d9a88382d4a0a337888102b86bd06c3eff95 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46674ae9e4b9a3ca756d565f17de4c4cbeb47970 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_loader.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e41a9061268267c9b114e4070b00f1097755e046 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..191829883886d6388cb1ec4e93f7b2ddceed095f Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_sampler.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc 
b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ccea4b390a7a20e58aefbc31736dca7a9a58bb2 Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bd5970214fd7ed34d98060227ab02ff127f896a Binary files /dev/null and b/versatile_diffusion/lib/data_factory/common/__pycache__/ds_transform.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/data_factory/common/ds_base.py b/versatile_diffusion/lib/data_factory/common/ds_base.py new file mode 100644 index 0000000000000000000000000000000000000000..cce61a1a906373bf6e10a4a5abe8336c670dd78c --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/ds_base.py @@ -0,0 +1,280 @@ +import os +import os.path as osp +import numpy as np +import numpy.random as npr +import torch +import torch.distributed as dist +import torchvision +import copy +import itertools + +from ... 
import sync +from ...cfg_holder import cfg_unique_holder as cfguh +from ...log_service import print_log + +import torch.distributed as dist +from multiprocessing import shared_memory + +# import multiprocessing +# if hasattr(multiprocessing, "shared_memory"): +# from multiprocessing import shared_memory +# else: +# # workaround for single gpu inference on colab +# shared_memory = None + +import pickle +import hashlib +import random + +class ds_base(torch.utils.data.Dataset): + def __init__(self, + cfg, + loader = None, + estimator = None, + transforms = None, + formatter = None): + + self.cfg = cfg + self.load_info = None + self.init_load_info() + self.loader = loader + self.transforms = transforms + self.formatter = formatter + + if self.load_info is not None: + load_info_order_by = getattr(self.cfg, 'load_info_order_by', 'default') + if load_info_order_by == 'default': + self.load_info = sorted(self.load_info, key=lambda x:x['unique_id']) + else: + try: + load_info_order_by, reverse = load_info_order_by.split('|') + reverse = reverse == 'reverse' + except: + reverse = False + self.load_info = sorted( + self.load_info, key=lambda x:x[load_info_order_by], reverse=reverse) + + load_info_add_idx = getattr(self.cfg, 'load_info_add_idx', True) + if (self.load_info is not None) and load_info_add_idx: + for idx, info in enumerate(self.load_info): + info['idx'] = idx + + if estimator is not None: + self.load_info = estimator(self.load_info) + + self.try_sample = getattr(self.cfg, 'try_sample', None) + if self.try_sample is not None: + try: + start, end = self.try_sample + except: + start, end = 0, self.try_sample + self.load_info = self.load_info[start:end] + + self.repeat = getattr(self.cfg, 'repeat', 1) + + pick = getattr(self.cfg, 'pick', None) + if pick is not None: + self.load_info = [i for i in self.load_info if i['filename'] in pick] + + ######### + # cache # + ######### + + self.cache_sm = getattr(self.cfg, 'cache_sm', False) + self.cache_cnt = 0 + if 
self.cache_sm: + self.cache_pct = getattr(self.cfg, 'cache_pct', 0) + cache_unique_id = sync.nodewise_sync().random_sync_id() + self.cache_unique_id = hashlib.sha256(pickle.dumps(cache_unique_id)).hexdigest() + self.__cache__(self.cache_pct) + + ####### + # log # + ####### + + if self.load_info is not None: + console_info = '{}: '.format(self.__class__.__name__) + console_info += 'total {} unique images, '.format(len(self.load_info)) + console_info += 'total {} unique sample. Cached {}. Repeat {} times.'.format( + len(self.load_info), self.cache_cnt, self.repeat) + else: + console_info = '{}: load_info not ready.'.format(self.__class__.__name__) + print_log(console_info) + + def init_load_info(self): + # implement by sub class + pass + + def __len__(self): + return len(self.load_info)*self.repeat + + def __cache__(self, pct): + if pct == 0: + self.cache_cnt = 0 + return + self.cache_cnt = int(len(self.load_info)*pct) + if not self.cache_sm: + for i in range(self.cache_cnt): + self.load_info[i] = self.loader(self.load_info[i]) + return + + for i in range(self.cache_cnt): + shm_name = str(self.load_info[i]['unique_id']) + '_' + self.cache_unique_id + if i % self.local_world_size == self.local_rank: + data = pickle.dumps(self.loader(self.load_info[i])) + datan = len(data) + # self.print_smname_to_file(shm_name) + shm = shared_memory.SharedMemory( + name=shm_name, create=True, size=datan) + shm.buf[0:datan] = data[0:datan] + shm.close() + self.load_info[i] = shm_name + else: + self.load_info[i] = shm_name + dist.barrier() + + def __getitem__(self, idx): + idx = idx%len(self.load_info) + # element = copy.deepcopy(self.load_info[idx]) + + # 0730 try shared memory + element = copy.deepcopy(self.load_info[idx]) + if isinstance(element, str): + shm = shared_memory.SharedMemory(name=element) + element = pickle.loads(shm.buf) + shm.close() + else: + element = copy.deepcopy(element) + element['load_info_ptr'] = self.load_info + + if idx >= self.cache_cnt: + element = 
self.loader(element) + if self.transforms is not None: + element = self.transforms(element) + if self.formatter is not None: + return self.formatter(element) + else: + return element + + # 0730 try shared memory + def __del__(self): + # Clean the shared memory + for infoi in self.load_info: + if isinstance(infoi, str) and (self.local_rank==0): + shm = shared_memory.SharedMemory(name=infoi) + shm.close() + shm.unlink() + + def print_smname_to_file(self, smname): + try: + log_file = cfguh().cfg.train.log_file + except: + try: + log_file = cfguh().cfg.eval.log_file + except: + raise ValueError + # a trick to use the log_file path + sm_file = log_file.replace('.log', '.smname') + with open(sm_file, 'a') as f: + f.write(smname + '\n') + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +from .ds_loader import get_loader +from .ds_transform import get_transform +from .ds_estimator import get_estimator +from .ds_formatter import get_formatter + +@singleton +class get_dataset(object): + def __init__(self): + self.dataset = {} + + def register(self, ds): + self.dataset[ds.__name__] = ds + + def __call__(self, cfg): + if cfg is None: + return None + t = cfg.type + if t is None: + return None + elif t in ['laion2b', 'laion2b_dummy', + 'laion2b_webdataset', + 'laion2b_webdataset_sdofficial', ]: + from .. import ds_laion2b + elif t in ['coyo', 'coyo_dummy', + 'coyo_webdataset', ]: + from .. import ds_coyo_webdataset + elif t in ['laionart', 'laionart_dummy', + 'laionart_webdataset', ]: + from .. import ds_laionart + elif t in ['celeba']: + from .. import ds_celeba + elif t in ['div2k']: + from .. import ds_div2k + elif t in ['pafc']: + from .. import ds_pafc + elif t in ['coco_caption']: + from .. 
import ds_coco + else: + raise ValueError + + loader = get_loader() (cfg.get('loader' , None)) + transform = get_transform()(cfg.get('transform', None)) + estimator = get_estimator()(cfg.get('estimator', None)) + formatter = get_formatter()(cfg.get('formatter', None)) + + return self.dataset[t]( + cfg, loader, estimator, + transform, formatter) + +def register(): + def wrapper(class_): + get_dataset().register(class_) + return class_ + return wrapper + +# some other helpers + +class collate(object): + """ + Modified from torch.utils.data._utils.collate + It handle list different from the default. + List collate just by append each other. + """ + def __init__(self): + self.default_collate = \ + torch.utils.data._utils.collate.default_collate + + def __call__(self, batch): + """ + Args: + batch: [data, data] -or- [(data1, data2, ...), (data1, data2, ...)] + This function will not be used as induction function + """ + elem = batch[0] + if not isinstance(elem, (tuple, list)): + return self.default_collate(batch) + + rv = [] + # transposed + for i in zip(*batch): + if isinstance(i[0], list): + if len(i[0]) != 1: + raise ValueError + try: + i = [[self.default_collate(ii).squeeze(0)] for ii in i] + except: + pass + rvi = list(itertools.chain.from_iterable(i)) + rv.append(rvi) # list concat + else: + rv.append(self.default_collate(i)) + return rv diff --git a/versatile_diffusion/lib/data_factory/common/ds_estimator.py b/versatile_diffusion/lib/data_factory/common/ds_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..c291546565b8cbed5b5a40923f2c0af1446a0896 --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/ds_estimator.py @@ -0,0 +1,85 @@ +import os.path as osp +import numpy as np +import numpy.random as npr +import PIL +import cv2 + +import torch +import torchvision +import xml.etree.ElementTree as ET +import json +import copy +import math + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in
instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_estimator(object): + def __init__(self): + self.estimator = {} + + def register(self, estimf): + self.estimator[estimf.__name__] = estimf + + def __call__(self, cfg): + if cfg is None: + return None + t = cfg.type + return self.estimator[t](**cfg.args) + +def register(): + def wrapper(class_): + get_estimator().register(class_) + return class_ + return wrapper + +@register() +class PickFileEstimator(object): + """ + This is an estimator that filter load_info + using the provided filelist + """ + def __init__(self, + filelist = None, + repeat_n = 1): + """ + Args: + filelist: a list of string gives the name of images + we would like to visualize, evaluate or train. + repeat_n: int, times these images will be repeated + """ + self.filelist = filelist + self.repeat_n = repeat_n + + def __call__(self, load_info): + load_info_new = [] + for info in load_info: + if osp.basename(info['image_path']).split('.')[0] in self.filelist: + load_info_new.append(info) + return load_info_new * self.repeat_n + +@register() +class PickIndexEstimator(object): + """ + This is an estimator that filter load_info + using the provided indices + """ + def __init__(self, + indexlist = None, + **kwargs): + """ + Args: + indexlist: [] of int. + the indices to be filtered out. 
+ """ + self.indexlist = indexlist + + def __call__(self, load_info): + load_info_new = [load_info[i] for i in self.indexlist] + return load_info_new diff --git a/versatile_diffusion/lib/data_factory/common/ds_formatter.py b/versatile_diffusion/lib/data_factory/common/ds_formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..4e5dd64e825d648a7cd9aaec7d3f81cb0b311ec6 --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/ds_formatter.py @@ -0,0 +1,39 @@ +import os +import os.path as osp +import numpy as np +import numpy.random as npr +import torch +import cv2 +import scipy.ndimage +from PIL import Image +import copy +import gc +import itertools + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_formatter(object): + def __init__(self): + self.formatter = {} + + def register(self, formatf): + self.formatter[formatf.__name__] = formatf + + def __call__(self, cfg): + if cfg is None: + return None + t = cfg.type + return self.formatter[t](**cfg.args) + +def register(): + def wrapper(class_): + get_formatter().register(class_) + return class_ + return wrapper diff --git a/versatile_diffusion/lib/data_factory/common/ds_loader.py b/versatile_diffusion/lib/data_factory/common/ds_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7d96341d0d315f3e28d58853de1e81fda99df8 --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/ds_loader.py @@ -0,0 +1,97 @@ +import os.path as osp +import numpy as np +import numpy.random as npr +import PIL +import cv2 + +import torch +import torchvision +import xml.etree.ElementTree as ET +import json +import copy + +from ...cfg_holder import cfg_unique_holder as cfguh + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = 
class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_loader(object): + def __init__(self): + self.loader = {} + + def register(self, loadf): + self.loader[loadf.__name__] = loadf + + def __call__(self, cfg): + if cfg is None: + return None + if isinstance(cfg, list): + loader = [] + for ci in cfg: + t = ci.type + loader.append(self.loader[t](**ci.args)) + return compose(loader) + t = cfg.type + return self.loader[t](**cfg.args) + +class compose(object): + def __init__(self, loaders): + self.loaders = loaders + + def __call__(self, element): + for l in self.loaders: + element = l(element) + return element + + def __getitem__(self, idx): + return self.loaders[idx] + +def register(): + def wrapper(class_): + get_loader().register(class_) + return class_ + return wrapper + +def pre_loader_checkings(ltype): + lpath = ltype+'_path' + # cache feature added on 20201021 + lcache = ltype+'_cache' + def wrapper(func): + def inner(self, element): + if lcache in element: + # cache feature added on 20201021 + data = element[lcache] + else: + if ltype in element: + raise ValueError + if lpath not in element: + raise ValueError + + if element[lpath] is None: + data = None + else: + data = func(self, element[lpath], element) + element[ltype] = data + + if ltype == 'image': + if isinstance(data, np.ndarray): + imsize = data.shape[-2:] + elif isinstance(data, PIL.Image.Image): + imsize = data.size[::-1] + elif isinstance(data, torch.Tensor): + imsize = [data.size(-2), data.size(-1)] + elif data is None: + imsize = None + else: + raise ValueError + element['imsize'] = imsize + element['imsize_current'] = copy.deepcopy(imsize) + return element + return inner + return wrapper diff --git a/versatile_diffusion/lib/data_factory/common/ds_sampler.py b/versatile_diffusion/lib/data_factory/common/ds_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..871c6e1b09a23058adb30aa3ca65d389b44cb9cb --- /dev/null +++ 
b/versatile_diffusion/lib/data_factory/common/ds_sampler.py @@ -0,0 +1,273 @@ +from tokenize import group +import torch +import numpy as np +import numpy.random as npr +import torch.distributed as dist +import math + +from ...log_service import print_log +from ... import sync + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_sampler(object): + def __init__(self): + self.sampler = {} + + def register(self, sampler): + self.sampler[sampler.__name__] = sampler + + def __call__(self, dataset, cfg): + if cfg == 'default_train': + return GlobalDistributedSampler(dataset, shuffle=True, extend=False) + elif cfg == 'default_eval': + return GlobalDistributedSampler(dataset, shuffle=False, extend=True) + else: + t = cfg.type + return self.sampler[t](dataset=dataset, **cfg.args) + +def register(): + def wrapper(class_): + get_sampler().register(class_) + return class_ + return wrapper + +###################### +# DistributedSampler # +###################### + +@register() +class GlobalDistributedSampler(torch.utils.data.Sampler): + """ + This is a distributed sampler that sync accross gpus and nodes. + """ + def __init__(self, + dataset, + shuffle=True, + extend=False,): + """ + Arguments: + dataset: Dataset used for sampling. + shuffle: If true, sampler will shuffle the indices + extend: If true, sampler will extend the indices that can be even distributed by ranks + otherwise sampler will truncate the indices to make it even. 
+ """ + self.ddp = sync.is_ddp() + self.rank = sync.get_rank('global') + self.world_size = sync.get_world_size('global') + self.dataset = dataset + self.shuffle = shuffle + self.extend = extend + + num_samples = len(dataset) // self.world_size + if extend and (len(dataset)%self.world_size != 0): + num_samples+=1 + self.num_samples = num_samples + self.total_size = num_samples * self.world_size + + def __iter__(self): + indices = self.get_sync_order() + if self.extend: + # extend using the front indices + indices = indices+indices[0:self.total_size-len(indices)] + else: + # truncate + indices = indices[0:self.total_size] + # subsample + indices = indices[self.rank : len(indices) : self.world_size] + return iter(indices) + + def __len__(self): + return self.num_samples + + def get_sync_order(self): + if self.shuffle: + indices = torch.randperm(len(self.dataset)).to(self.rank) + if self.ddp: + dist.broadcast(indices, src=0) + indices = indices.to('cpu').tolist() + else: + indices = list(range(len(self.dataset))) + print_log('Sampler : {}'.format(str(indices[0:5])) ) + return indices + +@register() +class LocalDistributedSampler(GlobalDistributedSampler): + """ + This is a distributed sampler that sync across gpus within the nodes. + But not sync across nodes. 
+ """ + def __init__(self, + dataset, + shuffle=True, + extend=False,): + super().__init__(dataset, shuffle, extend) + self.rank = sync.get_rank('local') + self.world_size = sync.get_world_size('local') + + def get_sync_order(self): + if self.shuffle: + if self.rank == 0: + indices = list(npr.permutation(len(self.dataset))) + sync.nodewise_sync().broadcast_r0(indices) + else: + indices = sync.nodewise_sync().broadcast_r0(None) + else: + indices = list(range(len(self.dataset))) + print_log('Sampler : {}'.format(str(indices[0:5])) ) + return indices + +############################ +# random sample with group # +############################ +# Deprecated + +@register() +class GroupSampler(torch.utils.data.Sampler): + """ + This is a new DistributedSampler that sample all index according to group. + i.e. + if group_size=3, num_replicas=2, train mode: + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10] + ==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10]) + process1: [0, 1, 2] + ==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10) + process1: [0, 1, 2] + ==> (distribute) process0: [3, 4, 5], [6, 7] (remove 10) + process1: [0, 1, 2], [8, 9] + + it will avoid_batchsize=1: + 0, 1, 2, 3, 4, 5, 6, 7, 8, + ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8] + ==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8]) + process1: [0, 1, 2] + ==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8]) + process1: [0, 1, 2] + ==> (distribute) process0: [3, 4, 5], (remove 6, 7, 8) (because distribute make batchsize 1) + process1: [0, 1, 2] + + if group_size=3, num_replicas=2, eval mode: + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + ==> (extend) 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10 + ==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10] + ==> (distribute) process0: [0, 1, 2], [6, 7, 8], + process1: [3, 4, 5], [9, 10, 10] + """ + + def __init__(self, + dataset, + group_size, + num_replicas=None, + rank=None, + 
mode='train',): + if num_replicas is None: + if not dist.is_available(): + raise ValueError + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise ValueError + rank = dist.get_rank() + + self.dataset = dataset + self.len_dataset = len(dataset) + self.group_size = group_size + self.num_replicas = num_replicas + self.rank = rank + self.mode = mode + len_dataset = self.len_dataset + + if (len_dataset % num_replicas != 0) and (mode == 'train'): + # drop the non_aligned + aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)] + aligned_len_dataset = aligned_indices.shape[0] + elif (len_dataset % num_replicas != 0) and (mode == 'eval'): + extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)]) + aligned_indices = np.concatenate([range(len_dataset), extend]) + aligned_len_dataset = aligned_indices.shape[0] + else: + aligned_indices = np.arange(len_dataset) + aligned_len_dataset = len_dataset + + num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas) + num_even = num_even_distributed_groups * group_size * num_replicas + + self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size) + self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1) + + if self.leftover_groups.size == 0: + self.leftover_groups = None + elif (self.leftover_groups.shape[-1]==1) and (mode == 'train'): + # avoid bs=1 + self.leftover_groups = None + + # a urly way to modify dataset.load_info according to the grouping + for groupi in self.regular_groups: + for idx in groupi: + idx_lowerbd = groupi[0] + idx_upperbd = groupi[-1] + idx_reference = (idx_lowerbd+idx_upperbd)//2 + dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size'] + if self.leftover_groups is not None: + for groupi in self.leftover_groups: + for idx in groupi: + idx_lowerbd = groupi[0] + idx_upperbd = groupi[-1] + idx_reference = (idx_lowerbd+idx_upperbd)//2 + 
dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size'] + + def concat(self, nparrays, axis=0): + # a helper for save concaternation + nparrays = [i for i in nparrays if i.size > 0] + return np.concatenate(nparrays, axis=axis) + + def __iter__(self): + indices = self.get_sync_order() + return iter(indices) + + def __len__(self): + return self.num_samples + + def get_sync_order(self): + # g = torch.Generator() + # g.manual_seed(self.epoch) + + mode = self.mode + rank = self.rank + num_replicas = self.num_replicas + group_size = self.group_size + num_groups = len(self.regular_groups) + + if mode == 'train': + g_indices = torch.randperm(num_groups).to(rank) + dist.broadcast(g_indices, src=0) + g_indices = g_indices.to('cpu').tolist() + num_groups_per_rank = num_groups // num_replicas + groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)] + indices = groups.flatten() + + if self.leftover_groups is not None: + leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank) + dist.broadcast(leftg_indices, src=0) + leftg_indices = leftg_indices.to('cpu').tolist() + last = self.leftover_groups[leftg_indices][rank] + indices = np.concatenate([indices, last], axis=0) + elif mode == 'eval': + groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :] + indices = groups.flatten() + if self.leftover_groups is not None: + last = self.leftover_groups[rank] + indices = np.concatenate([indices, last], axis=0) + else: + raise ValueError + + print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1]))) + return indices diff --git a/versatile_diffusion/lib/data_factory/common/ds_transform.py b/versatile_diffusion/lib/data_factory/common/ds_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..55b76e29ad6aedd902faa529a5ab3d81cbf7d2b1 --- /dev/null +++ b/versatile_diffusion/lib/data_factory/common/ds_transform.py @@ -0,0 +1,178 @@ +import os.path as 
osp +import numpy as np +import numpy.random as npr +import PIL +import cv2 + +import torch +import torchvision +import xml.etree.ElementTree as ET +import json +import copy +import math + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_transform(object): + def __init__(self): + self.transform = {} + + def register(self, transf): + self.transform[transf.__name__] = transf + + def __call__(self, cfg): + if cfg is None: + return None + if isinstance(cfg, list): + loader = [] + for ci in cfg: + t = ci.type + loader.append(self.transform[t](**ci.args)) + return compose(loader) + t = cfg.type + return self.transform[t](**cfg.args) + +def register(): + def wrapper(class_): + get_transform().register(class_) + return class_ + return wrapper + +def have(must=[], may=[]): + """ + The nextgen decorator that have two list of + input tells what category the transform + will operate on. + Args: + must: [] of str, + the names of the items that must be included + inside the element. + If element[name] exist: do the transform + If element[name] is None: raise Exception. + If element[name] not exist: raise Exception. + may: [] of str, + the names of the items that may be contained + inside the element for transform. + If element[name] exist: do the transform + If element[name] is None: ignore it. + If element[name] not exist: ignore it. + """ + def route(self, item, e, d): + """ + Route the element to a proper function + for calculation. + Args: + self: object, + the transform functor. + item: str, + the item name of the data. + e: {}, + the element + d: nparray, tensor or PIL.Image, + the data to transform. 
+ """ + if isinstance(d, np.ndarray): + dtype = 'nparray' + elif isinstance(d, torch.Tensor): + dtype = 'tensor' + elif isinstance(d, PIL.Image.Image): + dtype = 'pilimage' + else: + raise ValueError + + # find function by order + f = None + for attrname in [ + 'exec_{}_{}'.format(item, dtype), + 'exec_{}'.format(item), + 'exec_{}'.format(dtype), + 'exec']: + f = getattr(self, attrname, None) + if f is not None: + break + d, e = f(d, e) + e[item] = d + return e + + def wrapper(func): + def inner(self, e): + e['imsize_previous'] = e['imsize_current'] + imsize_tag_cnt = 0 + imsize_tag = 'imsize_before_' + self.__class__.__name__ + while True: + if imsize_tag_cnt != 0: + tag = imsize_tag + str(imsize_tag_cnt) + else: + tag = imsize_tag + if not tag in e: + e[tag] = e['imsize_current'] + break + imsize_tag_cnt += 1 + + e = func(self, e) + # must transform list + for item in must: + try: + d = e[item] + except: + raise ValueError + if d is None: + raise ValueError + e = route(self, item, e, d) + # may transform list + for item in may: + try: + d = e[item] + except: + d = None + if d is not None: + e = route(self, item, e, d) + return e + return inner + return wrapper + +class compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, element): + for t in self.transforms: + element = t(element) + return element + +class TBase(object): + def __init__(self): + pass + + def exec(self, data, element): + raise ValueError + + def rand(self, + uid, + tag, + rand_f, + *args, + **kwargs): + """ + Args: + uid: string element['unique_id'] + tag: string tells the tag uses when tracking the random number. + Or the tag to restore the tracked random number. + rand_f: the random function use to generate random number. + **kwargs: the argument for the given random function. 
+ """ + # if rnduh().hdata is not None: + # return rnduh().get_history(uid, self.__class__.__name__, tag) + # if rnduh().record_path is None: + # return rand_f(*args, **kwargs) + # the special mode to create the random file. + d = rand_f(*args, **kwargs) + # rnduh().record(uid, self.__class__.__name__, tag, d) + return d diff --git a/versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py b/versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..da689eba418068d3d9f389c607921b73f0b78ac0 --- /dev/null +++ b/versatile_diffusion/lib/data_factory/ds_laion2b_webdataset.py @@ -0,0 +1,221 @@ +import os +import os.path as osp +import numpy as np +import numpy.random as npr +import torch +import torch.distributed as dist +import torchvision.transforms as tvtrans +import PIL.Image +PIL.Image.MAX_IMAGE_PIXELS = None +import math +import json +import copy +import pickle +from multiprocessing import shared_memory +import time +from .common import * +from ..log_service import print_log + +from lib import visual_service as vis +from .. 
from .. import sync

import webdataset as wds

###################################################
# this is a special ds that use webdataset mainly #
###################################################

@regdataset()
class laion2b_dummy(ds_base):
    """Placeholder dataset with no samples (useful for dry runs)."""
    def init_load_info(self):
        self.load_info = []

@regdataset()
class laion2b_webdataset(ds_base):
    """LAION-2B dataset streamed from webdataset .tar shards."""

    def init_load_info(self):
        # Samples are streamed from tar shards, so there is no index to build.
        self.load_info = []

    def make_loader(self, batch_size, num_workers, train=True):
        """Build a WebLoader over the .tar shards under <root_dir>/data.

        Args:
            batch_size: samples per batch; batching happens inside the
                webdataset pipeline, so the WebLoader gets batch_size=None.
            num_workers: DataLoader worker count.
            train: RandomCrop when True, CenterCrop otherwise.
        """
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        crop_cls = tvtrans.RandomCrop if train else tvtrans.CenterCrop
        trans = tvtrans.Compose([
            tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
            crop_cls(cfg.scale),
            tvtrans.ToTensor(), ])

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess

        shuffle = cfg.get('shuffle', 10000)
        shardshuffle = shuffle > 0
        node_world_size = sync.get_world_size('node')
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size == 1 else wds.shardlists.single_node_only

        data_dir = osp.join(self.root_dir, 'data')
        tars = sorted(
            osp.join(data_dir, i) for i in os.listdir(data_dir)
            if osp.splitext(i)[1] == '.tar')

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print_log(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # BUGFIX: .batched() returns a new pipeline; the original discarded the
        # return value, so samples were never actually batched.
        dset = dset.batched(batch_size, partial=False)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader

    def filter_size(self, x):
        """Keep a sample only if it satisfies min_size / max_pwatermark.

        Missing or malformed json metadata rejects the sample.
        """
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and \
                        x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
        """Keep only samples that carry both an image and a caption."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    # NOTE(review): these wrappers pass a cfg group as `batch_size` and omit
    # `num_workers`, so calling any of them raises TypeError with the current
    # make_loader signature. Left unchanged because the intended call shape
    # (e.g. make_loader(**self.train)) cannot be confirmed from this file.
    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        return self.make_loader(self.test, train=False)

def customized_postprocess(element):
    """Map a decoded sample to (image in [-1, 1], caption, shard key)."""
    return element['jpg']*2-1, element['txt'], element['__key__']

def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of dict samples into a dict of batched values.

    Only keys present in every sample are kept. Scalars become np arrays,
    tensors are stacked, ndarrays are stacked; anything else stays a list.
    """
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [s[key] for s in samples] for key in keys}

    result = {}
    for key, values in batched.items():
        first = values[0]
        if isinstance(first, (int, float)):
            if combine_scalars:
                result[key] = np.array(values)
        elif isinstance(first, torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(values)
        elif isinstance(first, np.ndarray):
            if combine_tensors:
                result[key] = np.array(values)
        else:
            result[key] = values
    return result

###################
# for sd official #
###################

def customized_postprocess_sdofficial(element):
    """Keep the sample as a dict (sd-official convention), image in [-1, 1]."""
    return {
        'jpg': element['jpg']*2-1,
        'txt': element['txt'], }

@regdataset()
class laion2b_webdataset_sdofficial(laion2b_webdataset):
    """Variant matching the official SD loader: dict samples, single node."""

    def make_loader(self, batch_size, num_workers, train=True):
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        crop_cls = tvtrans.RandomCrop if train else tvtrans.CenterCrop
        trans = tvtrans.Compose([
            tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
            crop_cls(cfg.scale),
            tvtrans.ToTensor(), ])

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess_sdofficial

        # Hard-coded in the sd-official variant (no cfg lookups).
        shuffle = 10000
        shardshuffle = shuffle > 0
        node_world_size = 1
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size == 1 else wds.shardlists.single_node_only

        data_dir = osp.join(self.root_dir, 'data')
        tars = sorted(
            osp.join(data_dir, i) for i in os.listdir(data_dir)
            if osp.splitext(i)[1] == '.tar')

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # BUGFIX: assign the batched pipeline (return value was discarded).
        dset = dset.batched(batch_size, partial=False,
                            collation_fn=dict_collation_fn)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader

# ---------------------------------------------------------------------------
# versatile_diffusion/lib/evaluator/__init__.py
# ---------------------------------------------------------------------------
from .eva_base import get_evaluator
newline at end of file diff --git a/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc b/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f6d82fac3cedf00114085c1a100ace52f01a5d Binary files /dev/null and b/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc b/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79e70dbf8390923a3c0367cfc098255fde9ce52b Binary files /dev/null and b/versatile_diffusion/lib/evaluator/__pycache__/__init__.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc b/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ce80bec1d55a6bc637561a0b57093bf854661f Binary files /dev/null and b/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc b/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1efc3948b339c09df8638c8a6da9b979bfcb14f0 Binary files /dev/null and b/versatile_diffusion/lib/evaluator/__pycache__/eva_base.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/evaluator/eva_base.py b/versatile_diffusion/lib/evaluator/eva_base.py new file mode 100644 index 0000000000000000000000000000000000000000..978685b2082a6eaab978da578ff14e4255e51a5c --- /dev/null +++ b/versatile_diffusion/lib/evaluator/eva_base.py @@ -0,0 +1,293 @@ +import torch +import torch.distributed as dist + +import os +import os.path as osp +import numpy as np +import cv2 +import copy +import json + +from ..log_service import print_log + 
def singleton(class_):
    """Class decorator: every call returns one shared, lazily-built instance."""
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_evaluator(object):
    """Singleton registry mapping evaluator names to evaluator classes."""

    def __init__(self):
        self.evaluator = {}

    def register(self, evaf, name):
        self.evaluator[name] = evaf

    def _import_by_type(self, t):
        # Importing the submodule runs its @register decorator, which adds the
        # evaluator class to self.evaluator as a side effect.
        if t == 'miou':
            from . import eva_miou
        elif t == 'psnr':
            from . import eva_psnr
        elif t == 'ssim':
            from . import eva_ssim
        elif t == 'lpips':
            from . import eva_lpips
        elif t == 'fid':
            from . import eva_fid

    def __call__(self, pipeline_cfg=None):
        """Instantiate one evaluator, or a `compose` over several.

        Args:
            pipeline_cfg: None (-> null evaluator), a single cfg node with
                .type / .args, or a list of such nodes.
        """
        if pipeline_cfg is None:
            from . import eva_null
            return self.evaluator['null']()

        if not isinstance(pipeline_cfg, list):
            t = pipeline_cfg.type
            self._import_by_type(t)
            return self.evaluator[t](**pipeline_cfg.args)

        evaluator = []
        for ci in pipeline_cfg:
            t = ci.type
            self._import_by_type(t)
            evaluator.append(self.evaluator[t](**ci.args))
        if len(evaluator) == 0:
            return None
        return compose(evaluator)

def register(name):
    """Decorator factory: register the decorated evaluator class under `name`."""
    def wrapper(class_):
        get_evaluator().register(class_, name)
        return class_
    return wrapper

class base_evaluator(object):
    def __init__(self, **args):
        '''
        Args:
            sample_n, int,
                the total number of sample. used in distributed sync
        '''
        if not dist.is_available():
            raise ValueError
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.sample_n = None
        self.final = {}

    def sync(self, data):
        """
        Args:
            data: any,
                the data needs to be broadcasted

        Returns a per-rank list (lists/tuples are synced element-wise and
        transposed back to [rank][item] order).
        """
        if data is None:
            return None

        if isinstance(data, tuple):
            data = list(data)

        if isinstance(data, list):
            data_list = [self.sync(datai) for datai in data]
            # Transpose [item][rank] -> [rank][item].
            return [[*i] for i in zip(*data_list)]

        return [self.sync_(data, ranki) for ranki in range(self.world_size)]

    def sync_(self, data, rank):
        """Broadcast one np.ndarray or str from `rank` to every rank."""
        t = type(data)
        is_broadcast = rank == self.rank

        if t is np.ndarray:
            dtrans = data
            dt = data.dtype
            # BUGFIX: np.bool was removed in NumPy 1.24; use bool / np.bool_.
            if dt in [
                    int,
                    bool,
                    np.bool_,
                    np.uint8,
                    np.int8,
                    np.int16,
                    np.int32,
                    np.int64,]:
                dtt = torch.int64
            elif dt in [
                    float,
                    np.float16,
                    np.float32,
                    np.float64,]:
                dtt = torch.float64
            else:
                # BUGFIX: previously fell through with dtt unbound (NameError).
                raise ValueError('unsupported ndarray dtype: {}'.format(dt))
        elif t is str:
            # Strings travel as arrays of unicode code points.
            dtrans = np.array(
                [ord(c) for c in data],
                dtype = np.int64)
            dt = np.int64
            dtt = torch.int64
        else:
            raise ValueError

        if is_broadcast:
            # Sender: broadcast ndim, then shape, then the payload.
            ndim = torch.tensor(len(dtrans.shape)).long().to(self.rank)
            dist.broadcast(ndim, src=rank)

            shape = torch.tensor(list(dtrans.shape)).long().to(self.rank)
            dist.broadcast(shape, src=rank)

            payload = torch.tensor(dtrans, dtype=dtt).to(self.rank)
            dist.broadcast(payload, src=rank)
            return data

        # Receiver: mirror the three broadcasts.
        ndim = torch.tensor(0).long().to(self.rank)
        dist.broadcast(ndim, src=rank)
        ndim = ndim.item()

        shape = torch.zeros(ndim).long().to(self.rank)
        dist.broadcast(shape, src=rank)
        shape = list(shape.to('cpu').numpy())

        payload = torch.zeros(shape, dtype=dtt).to(self.rank)
        dist.broadcast(payload, src=rank)
        result = payload.to('cpu').numpy().astype(dt)

        if t is np.ndarray:
            return result
        # Reassemble the string from code points.
        return ''.join([chr(c) for c in result])

    def zipzap_arrange(self, data):
        '''
        Order the data so it range like this:
        input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...]
        '''
        if isinstance(data[0], list):
            # BUGFIX: the original indexed past the end of shorter sublists
            # (IndexError on uneven input); skip exhausted sublists instead.
            data_new = []
            maxlen = max(len(i) for i in data)
            for idx in range(maxlen):
                for datai in data:
                    if idx < len(datai):
                        data_new.append(datai[idx])
            return data_new

        elif isinstance(data[0], np.ndarray):
            maxlen = max(i.shape[0] for i in data)
            totlen = sum(i.shape[0] for i in data)
            datai_shape = data[0].shape[1:]
            # BUGFIX: np.concatenate takes a *sequence* of arrays and np.zeros
            # takes a single shape tuple; the original passed them as separate
            # positional arguments and crashed on uneven input.
            data = [
                np.concatenate(
                    [datai,
                     np.zeros((maxlen-datai.shape[0], *datai_shape),
                              dtype=datai.dtype)],
                    axis=0)
                if datai.shape[0] < maxlen else datai
                for datai in data
            ] # even the array
            data = np.stack(data, axis=1).reshape(-1, *datai_shape)
            data = data[:totlen]
            return data

        else:
            raise NotImplementedError

    def add_batch(self, **args):
        raise NotImplementedError

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n

    def compute(self):
        raise NotImplementedError

    # Function needed in training to judge which
    # evaluated number is better
    def isbetter(self, old, new):
        return new>old

    def one_line_summary(self):
        print_log('Evaluator display')

    def save(self, path):
        """Dump accumulated results (self.final) to <path>/result.json."""
        if not osp.exists(path):
            os.makedirs(path)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        raise NotImplementedError

class compose(object):
    """Run several evaluators as one; `isbetter` is a strict majority vote."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        for pi in self.pipeline:
            pi.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n
        for pi in self.pipeline:
            pi.set_sample_n(sample_n)

    def compute(self):
        # NOTE(review): assumes each sub-evaluator exposes .symbol and .final
        # after compute() — set by the concrete evaluator modules; confirm.
        rv = {}
        for pi in self.pipeline:
            rv[pi.symbol] = pi.compute()
            self.final[pi.symbol] = pi.final
        return rv

    def isbetter(self, old, new):
        check = 0
        for pi in self.pipeline:
            if pi.isbetter(old, new):
                check += 1
        # Strictly more than half of the sub-evaluators must agree.
        return check/len(self.pipeline) > 0.5

    def one_line_summary(self):
        for pi in self.pipeline:
            pi.one_line_summary()

    def save(self, path):
        if not osp.exists(path):
            os.makedirs(path)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        for pi in self.pipeline:
            pi.clear_data()

# ---------------------------------------------------------------------------
# versatile_diffusion/lib/evaluator/eva_null.py
# ---------------------------------------------------------------------------
import torch
import numpy as np
import lpips

from .. import nputils
from ..log_service import print_log

from .eva_base import base_evaluator, register

@register('null')
class null_evaluator(base_evaluator):
    """No-op evaluator used when no evaluation pipeline is configured."""

    def __init__(self, **dummy):
        super().__init__()

    def add_batch(self, **dummy):
        pass

    def compute(self):
        return None

    def one_line_summary(self):
        print_log('Evaluator null')

    def clear_data(self):
        pass
0000000000000000000000000000000000000000..0cff7ecc4ded669389e55b4c635576439b4dd7c8 Binary files /dev/null and b/versatile_diffusion/lib/experiments/__pycache__/__init__.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc b/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99f048ad83788225788679a4a5f3762a1dcae284 Binary files /dev/null and b/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc b/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f69749c39ddfb3b2702a09c1366d54fcabcba60 Binary files /dev/null and b/versatile_diffusion/lib/experiments/__pycache__/sd_default.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/experiments/sd_default.py b/versatile_diffusion/lib/experiments/sd_default.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3e0b1c737f7cd6d4cbf585a92a62c1bebe8a73 --- /dev/null +++ b/versatile_diffusion/lib/experiments/sd_default.py @@ -0,0 +1,441 @@ +import torch +import torch.distributed as dist +from torchvision import transforms as tvtrans +import os +import os.path as osp +import time +import timeit +import copy +import json +import pickle +import PIL.Image +import numpy as np +from datetime import datetime +from easydict import EasyDict as edict +from collections import OrderedDict + +from lib.cfg_holder import cfg_unique_holder as cfguh +from lib.data_factory import get_dataset, get_sampler, collate +from lib.model_zoo import \ + get_model, get_optimizer, get_scheduler +from lib.log_service import print_log + +from ..utils import train as train_base +from ..utils import eval as eval_base +from ..utils import train_stage as tsbase +from ..utils import 
from ..utils import eval_stage as esbase
from .. import sync

###############
# some helper #
###############

def atomic_save(cfg, net, opt, step, path):
    """Serialize a training checkpoint to `path` via a single write.

    The frozen first-stage (VAE) and cond-stage (text encoder) weights are
    stripped from the state_dict since they are restored from their own
    pretrained files.
    """
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
        if ki.find('first_stage_model')!=0 and ki.find('cond_stage_model')!=0]

    checkpoint = {
        "config" : cfg,
        "state_dict" : OrderedDict(slimmed_sd),
        "step" : step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    # Stage the whole checkpoint in memory so the target file is written in
    # one shot (atomic on most fsspec backends).
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())

def load_state_dict(net, cfg):
    """Load pretrained weights into `net` according to `cfg`.

    Exactly one of these mutually-exclusive config groups may be set:
      * pretrained_pth_full / pretrained_ckpt_full : full state_dict
      * pretrained_pth / pretrained_ckpt : slimmed state_dict (first/cond
        stage weights are taken from the freshly built `net` instead)
      * pretrained_pth_dm and/or pretrained_pth_ema : UNet / EMA weights only
    """
    pretrained_pth_full  = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth       = cfg.get('pretrained_pth' , None)
    pretrained_ckpt      = cfg.get('pretrained_ckpt', None)
    pretrained_pth_dm    = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema   = cfg.get('pretrained_pth_ema', None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth is not None or pretrained_ckpt is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # Fill the slimmed checkpoint with the first/cond stage weights that
        # `net` already holds, so load_state_dict sees a complete dict.
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('first_stage_model')==0 or ki.find('cond_stage_model')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_dm is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_ema is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)

def auto_merge_imlist(imlist, max=64):
    """Tile up to `max` equally-sized HxWx3 uint8 images into one grid.

    NOTE(review): the parameter name `max` shadows the builtin; kept for
    backward compatibility with keyword callers.
    """
    imlist = imlist[0:max]
    h, w = imlist[0].shape[0:2]
    num_images = len(imlist)
    num_row = int(np.sqrt(num_images))
    # ceil(num_images / num_row) columns.
    num_col = num_images//num_row + 1 if num_images%num_row!=0 else num_images//num_row
    canvas = np.zeros([num_row*h, num_col*w, 3], dtype=np.uint8)
    for idx, im in enumerate(imlist):
        hi = (idx // num_col) * h
        wi = (idx % num_col) * w
        canvas[hi:hi+h, wi:wi+w, :] = im
    return canvas

def latent2im(net, latent):
    """Decode latent(s) to PIL image(s); single input -> single image."""
    single_input = len(latent.shape) == 3
    if single_input:
        latent = latent[None]
    im = net.decode_image(latent.to(net.device))
    # Map [-1, 1] model output to [0, 1] for ToPILImage.
    im = torch.clamp((im+1.0)/2.0, min=0.0, max=1.0)
    im = [tvtrans.ToPILImage()(i) for i in im]
    if single_input:
        im = im[0]
    return im

def im2latent(net, im):
    """Encode PIL image(s) to latent(s); single input -> single latent."""
    single_input = not isinstance(im, list)
    if single_input:
        im = [im]
    im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
    # Model expects inputs in [-1, 1].
    im = (im*2-1).to(net.device)
    z = net.encode_image(im)
    if single_input:
        z = z[0]
    return z

class color_adjust(object):
    """Transfer the color distribution of `ref_to` onto images that share the
    statistics of `ref_from` (per-channel histogram/stat matching)."""

    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        """Return (HxWx3 float array, per-channel mean, per-channel std).

        Accepts a file path, PIL image, [0,1] tensor, or ndarray.
        """
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        # Standardize with the source stats, re-express with the target stats.
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x-m0)/s0)*s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        """Color-match `xin`; `keep` blends back a fraction of the original.

        simple=True uses mean/std matching only; otherwise a per-channel
        histogram (pdf) transfer is applied on top.
        """
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = (x*(1-keep) + xin*keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)

        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        """Histogram-match `arr_in` from the `arr_fo` to `arr_to` distribution
        using n-bin CDF interpolation."""
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]
        # compute probability distribution
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
        # transfer
        t_d = np.interp(d_fo, d_to, xs)
        # Clamp values outside the target CDF support.
        t_d[d_fo <= d_to[ 0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out

########
# main #
########

class eval(eval_base):
    def prepare_model(self):
        """Build the net, load eval weights, and wrap it in DDP (eval mode)."""
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
        load_state_dict(net, cfg.eval) #<--- added
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[self.local_rank],
            find_unused_parameters=True)
        net.eval()
        return {'net' : net,}

class eval_stage(esbase):
    """
    This is eval stage that can check comprehensive results
    (renders text-to-image demo grids for each configured prompt).
    """
    def __init__(self):
        from ..model_zoo.ddim import DDIMSampler
        self.sampler = DDIMSampler

    def get_net(self, paras):
        return paras['net']

    def get_image_path(self):
        # Prefer the training log dir when evaluating during training.
        if 'train' in cfguh().cfg:
            log_dir = cfguh().cfg.train.log_dir
        else:
            log_dir = cfguh().cfg.eval.log_dir
        return os.path.join(log_dir, "udemo")

    @torch.no_grad()
    def sample(self, net, sampler, prompt, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample `n_samples` latents conditioned on a text prompt."""
        h, w = output_dim
        uc = None
        if scale != 1.0:
            # Classifier-free guidance needs an unconditional embedding.
            uc = net.get_learned_conditioning(n_samples * [""])
        c = net.get_learned_conditioning(n_samples * [prompt])
        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def save_images(self, pil_list, name, path, suffix=''):
        canvas = auto_merge_imlist([np.array(i) for i in pil_list])
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def __call__(self, **paras):
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        # Cleaned up: eval_cnt was fetched twice in the original.
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        suffix = '' if eval_cnt is None else '_itern'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        # Round-robin the prompt list across local ranks; seedi keeps seeds
        # unique per prompt regardless of which rank renders it.
        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = list(range(LRANK, len(conditioning), LWSIZE))

        for prompti, seedi in zip(conditioning_local, seed_increment):
            if prompti == 'SKIP':
                continue
            draw_filename = prompti.strip().replace(' ', '-')
            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)

            demo_image = latent2im(netm, x)
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}

##################
# eval variation #
##################

class eval_stage_variation(eval_stage):
    """Image-variation demo: conditions on a reference image instead of text."""

    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample `n_samples` latents conditioned on image `visual_hint`
        (a file path)."""
        h, w = output_dim
        vh = tvtrans.ToTensor()(PIL.Image.open(visual_hint))[None].to(net.device)
        c = net.get_learned_conditioning(vh)
        c = c.repeat(n_samples, 1, 1)
        uc = None
        if scale != 1.0:
            # Unconditional branch conditions on an all-zero image.
            dummy = torch.zeros_like(vh)
            uc = net.get_learned_conditioning(dummy)
            uc = uc.repeat(n_samples, 1, 1)

        shape = [4, h//8, w//8]
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        # Cleaned up: eval_cnt was fetched twice in the original.
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        suffix = '' if eval_cnt is None else '_'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfgv.get('color_adj', False)
        color_adj_keep_ratio = cfgv.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfgv.get('color_adj_simple', True)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = list(range(LRANK, len(conditioning), LWSIZE))

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            draw_filename = osp.splitext(osp.basename(ci))[0]

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                # Re-match each output's colors to the reference image.
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}

# ---------------------------------------------------------------------------
# versatile_diffusion/lib/experiments/vd_default.py
# ---------------------------------------------------------------------------
import torch
import torch.distributed as dist
from torchvision import transforms as tvtrans
import os
import os.path as osp
import time
# --- versatile_diffusion/lib/experiments/vd_default.py ---------------------
# NOTE(review): reconstructed from a line-mangled diff chunk; code tokens are
# unchanged, only formatting was normalized and comments/docstrings added.
# The module's first imports (torch, torch.distributed, torchvision
# transforms as tvtrans, os, os.path as osp, time) sit in the preceding chunk.
import timeit
import copy
import json
import pickle
import PIL.Image
import numpy as np
from datetime import datetime
from easydict import EasyDict as edict
from collections import OrderedDict

from lib.cfg_holder import cfg_unique_holder as cfguh
from lib.data_factory import get_dataset, get_sampler, collate
from lib.model_zoo import \
    get_model, get_optimizer, get_scheduler
from lib.log_service import print_log

from ..utils import train as train_base
from ..utils import eval as eval_base
from ..utils import train_stage as tsbase
from ..utils import eval_stage as esbase
from .. import sync

from .sd_default import auto_merge_imlist, latent2im, color_adjust
# NOTE(review): these two imports deliberately re-bind `eval_base` /
# introduce `eval_stage_base`, shadowing the ..utils `eval_base` imported
# above -- the sd_default classes are the ones actually subclassed below.
from .sd_default import eval as eval_base
from .sd_default import eval_stage as eval_stage_base

###############
# some helper #
###############

def atomic_save(cfg, net, opt, step, path):
    """Write a checkpoint for `net` (plus `cfg`, `step`, optional `opt`) to `path`.

    The state_dict is slimmed first: keys starting with 'autokl', 'optimus'
    or 'clip' (the frozen encoder/decoder sub-modules) are dropped.  The
    checkpoint is serialized into an in-memory buffer and then written in a
    single fsspec write, so an interrupted save does not leave a partially
    written file on disk.
    """
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    # Drop the frozen sub-module weights; they are restored from their own
    # pretrained files at load time (see load_state_dict below).
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
        if ki.find('autokl')!=0 and ki.find('optimus')!=0 and ki.find('clip')!=0]

    checkpoint = {
        "config" : cfg,
        "state_dict" : OrderedDict(slimmed_sd),
        "step" : step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())

def load_state_dict(net, cfg):
    """Load pretrained weights into `net` from exactly one configured source.

    Source groups (mutually exclusive, enforced by the asserts):
      * pretrained_pth_full / pretrained_ckpt_full : full state_dict
      * pretrained_pth / pretrained_ckpt : slimmed state_dict; the
        autokl/optimus/clip weights already present in `net` are merged back
        in before loading
      * pretrained_pth_dm  : diffusion unet weights only
      * pretrained_pth_ema : unet EMA weights only
    """
    pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
    pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
    pretrained_pth = cfg.get('pretrained_pth' , None)
    pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
    pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
    pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
    strict_sd = cfg.get('strict_sd', False)
    errmsg = "Overlapped model state_dict! This is undesired behavior!"

    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt_full
            # .ckpt files wrap the weights in a 'state_dict' entry
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth is not None or pretrained_ckpt is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt is None, errmsg
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print_log('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # The checkpoint was slimmed by atomic_save; re-insert the frozen
        # autokl/optimus/clip weights from the freshly constructed net so
        # load_state_dict sees a complete dict.
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
            if ki.find('autokl')==0 or ki.find('optimus')==0 or ki.find('clip')==0]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_dm is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)

    if pretrained_pth_ema is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print_log('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)

###################
# official stages #
###################

class eval(eval_base):
    # Thin alias of sd_default.eval so configs can reference vd_default.eval.
    pass

class eval_stage(eval_stage_base):
    """
    Evaluation of both prompt and vision

    Drives the versatile-diffusion DDIM sampler over a list of conditionings
    (file path -> vision conditioning, anything else -> text prompt), decodes
    the sampled latents to images or text, and saves them.
    """
    def __init__(self):
        from ..model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD

    @torch.no_grad()
    def sample(
            self, net, sampler, context, otype, ctype, image_output_dim, text_latent_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Run DDIM sampling for output type `otype` ('image'|'text')
        conditioned via `ctype` ('prompt'|'vision').  Returns the sampler's
        (latent, intermediates) pair."""
        if ctype == 'prompt':
            c = net.clip_encode_text(n_samples * [context])
            uc = None
            if scale != 1.0:
                # classifier-free guidance: empty prompt as the unconditional
                uc = net.clip_encode_text(n_samples * [""])
        elif ctype == 'vision':
            # context: single image tensor -> batch of n_samples copies
            context = context[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(context)
            uc = None
            if scale != 1.0:
                # all-zero image as the unconditional reference
                dummy = torch.zeros_like(context)
                uc = net.clip_encode_vision(dummy)

        if otype == 'image':
            h, w = image_output_dim
            # 4-channel VAE latent at 1/8 spatial resolution
            shape = [n_samples, 4, h//8, w//8]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
        elif otype == 'text':
            n = text_latent_dim
            shape = [n_samples, n]
            rv = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                xtype=otype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)

        return rv

    def decode_and_save(
            self, netm, z, xtype, ctype, path, name, suffix,
            color_adj=False, color_adj_to=None):
        """Decode latents `z` to images (autokl) or text (optimus) and save.

        Output filename is prefixed t2i/v2i/t2t/v2t according to the
        (ctype, xtype) combination.  Optional color adjustment (vision
        conditioning only) matches the decoded image statistics to
        `color_adj_to`; optional merging of repeated adjacent words cleans
        up decoded prompts.
        """
        if xtype == 'image':
            x = netm.autokl_decode(z)
            name = 't2i_'+name if ctype == 'prompt' else 'v2i_'+name
            if color_adj and (ctype=='vision'):
                keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
                simple = cfguh().cfg.eval.get('color_adj_simple', True)
                x_adj = []
                for xi in x:
                    # (xi+1)/2 maps the decoder's [-1,1] range to [0,1]
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=keep_ratio, simple=simple)
                    x_adj.append(xi_adj)
                x = x_adj
            self.save_images(x, name, path, suffix=suffix)
        elif xtype == 'text':
            prompt_temperature = cfguh().cfg.eval.get('prompt_temperature', 1.0)
            x = netm.optimus_decode(z, temperature=prompt_temperature)
            name = 't2t_'+name if ctype == 'prompt' else 'v2t_'+name
            prompt_merge_same_adj_word = cfguh().cfg.eval.get('prompt_merge_same_adj_word', False)
            if prompt_merge_same_adj_word:
                # collapse immediate word repetitions ("a a cat" -> "a cat")
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            self.save_text(x, name, path, suffix=suffix)

    def save_images(self, x, name, path, suffix=''):
        """Tile a batch of images (tensor in [-1,1] or list of arrays) into
        one canvas and save it as '<name><suffix>.png' under `path`."""
        if isinstance(x, torch.Tensor):
            single_input = len(x.shape) == 3
            if single_input:
                x = x[None]
            x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
            x = [tvtrans.ToPILImage()(xi) for xi in x]
            xlist = [np.array(xi) for xi in x]
        elif isinstance(x, list):
            xlist = x
        canvas = auto_merge_imlist(xlist)
        image_name = '{}{}.png'.format(name, suffix)
        PIL.Image.fromarray(canvas).save(osp.join(path, image_name))

    def save_text(self, x, name, path, suffix=''):
        """Write the list of strings `x`, one per line, to '<name><suffix>.txt'."""
        file_name = '{}{}.txt'.format(name, suffix)
        with open(osp.join(path, file_name) ,'w') as f:
            for xi in x:
                f.write(xi+'\n')

    def __call__(self, **paras):
        """Evaluation entry point.  Shards cfg.eval.conditioning round-robin
        across local ranks, samples each entry (optionally seeded and under
        EMA weights), decodes and saves the outputs.  Returns {}."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        output_path = self.get_image_path()
        self.create_dir(output_path)
        eval_cnt = paras.get('eval_cnt', None)  # (duplicate read; harmless)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        # round-robin shard: rank r takes entries r, r+W, r+2W, ...
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for conditioningi, seedi in zip(conditioning_local, seed_increment):
            if conditioningi == 'SKIP':
                continue

            # each entry is a (conditioning, output_type) pair
            ci, otypei = conditioningi

            if osp.isfile(ci):
                # is vision
                output_name = osp.splitext(osp.basename(ci))[0]
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                ci = ci*2 - 1  # [0,1] -> [-1,1]
                ctypei = 'vision'
            else:
                # is prompt
                output_name = ci.strip().replace(' ', '-')
                ctypei = 'prompt'

            if fix_seed:
                # per-entry deterministic seeds; suffix records the torch seed
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)
            else:
                z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)

            # NOTE(review): color_adj_to receives conditioningi[0], i.e. the
            # original entry (file path or prompt string), not the tensor
            # `ci` -- color_adjust apparently accepts that form; confirm
            # against sd_default.color_adjust.
            self.decode_and_save(
                netm, z, otypei, ctypei, output_path, output_name, suffixi,
                color_adj=color_adj, color_adj_to=conditioningi[0],)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}

################
# basic stages #
################

class eval_stage_basic(eval_stage_base):
    """Image-variation-only evaluation: every conditioning entry is an image
    file path.  Uses the sampler class set up by eval_stage_base."""
    @torch.no_grad()
    def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample n_samples image latents conditioned on the image at
        path `visual_hint`."""
        h, w = output_dim
        vh = PIL.Image.open(visual_hint)
        c = net.clip_encode_vision(n_samples * [vh])
        uc = None
        if scale != 1.0:
            # NOTE(review): the unconditional hint is a zero *tensor* while
            # the conditional one is a PIL image -- clip_encode_vision is
            # apparently tolerant of both input kinds; confirm.
            dummy = torch.zeros_like(tvtrans.ToTensor()(vh))
            uc = net.clip_encode_vision(n_samples * [dummy])

        shape = [4, h//8, w//8]
        # NOTE(review): this sampler uses the S=/batch_size= signature,
        # unlike the steps=/shape= signature used by the stages above.
        rv = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta)
        return rv

    def __call__(self, **paras):
        """Evaluation entry point; see eval_stage.__call__ for the sharding
        and seeding scheme.  Conditioning entries here are image paths."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = paras['net']
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        eval_cnt = paras.get('eval_cnt', None)  # (duplicate read; harmless)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue
            draw_filename = osp.splitext(osp.basename(ci))[0]
            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj:
                # match each sampled image's colors to the reference image
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}

#######################
# dual context stages #
#######################

class eval_stage_dc(eval_stage_base):
    """Dual-context evaluation: image output conditioned on either a prompt
    or a visual hint, dispatched through DDIMSampler_DualContext."""
    def __init__(self):
        from ..model_zoo.ddim_dualcontext import DDIMSampler_DualContext
        self.sampler = DDIMSampler_DualContext

    @torch.no_grad()
    def sample(
            self, net, sampler, conditioning, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """Dispatch on conditioning = (ctype, cvalue) to the text or vision
        sampler; raises ValueError on an unknown ctype."""
        ctype, cvalue =conditioning
        if ctype == 'prompt':
            return self.sample_text(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        elif ctype == 'vision':
            return self.sample_vision(
                net, sampler, cvalue, output_dim,
                scale, n_samples, ddim_steps, ddim_eta)
        else:
            raise ValueError

    @torch.no_grad()
    def sample_text(
            self, net, sampler, prompt, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample n_samples image latents from a text prompt."""
        h, w = output_dim
        uc = None
        if scale != 1.0:
            uc = net.clip_encode_text(n_samples * [""])
        c = net.clip_encode_text(n_samples * [prompt])
        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_text(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    @torch.no_grad()
    def sample_vision(
            self, net, sampler, visual_hint, output_dim,
            scale, n_samples, ddim_steps, ddim_eta):
        """DDIM-sample n_samples image latents from a CHW visual-hint tensor."""
        h, w = output_dim
        if len(visual_hint.shape) == 3:
            visual_hint=visual_hint[None].repeat(n_samples, 1, 1, 1)
        else:
            raise ValueError

        c = net.clip_encode_vision(visual_hint)
        uc = None
        if scale != 1.0:
            visual_hint_blank = torch.zeros_like(visual_hint)
            uc = net.clip_encode_vision(visual_hint_blank)

        shape = [n_samples, 4, h//8, w//8]
        rv = sampler.sample_vision(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            verbose=False,)
        return rv

    def __call__(self, **paras):
        """Evaluation entry point; see eval_stage.__call__ for the sharding
        and seeding scheme.  Entries are tagged ('prompt'|'vision', value)
        before being passed to self.sample."""
        cfg = cfguh().cfg
        cfgv = cfg.eval

        net = self.get_net(paras)
        eval_cnt = paras.get('eval_cnt', None)
        fix_seed = cfgv.get('fix_seed', False)

        LRANK = sync.get_rank('local')
        LWSIZE = sync.get_world_size('local')

        image_path = self.get_image_path()
        self.create_dir(image_path)
        suffix='' if eval_cnt is None else '_'+str(eval_cnt)

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        with_ema = getattr(netm, 'model_ema', None) is not None
        sampler = self.sampler(netm)
        setattr(netm, 'device', LRANK) # Trick

        color_adj = cfguh().cfg.eval.get('color_adj', False)
        color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
        color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)

        replicate = cfgv.get('replicate', 1)
        conditioning = cfgv.conditioning * replicate
        conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
        seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]

        for ci, seedi in zip(conditioning_local, seed_increment):
            if ci == 'SKIP':
                continue

            if osp.isfile(ci):
                # is vision
                draw_filename = 'v2i_' + osp.splitext(osp.basename(ci))[0]
                ci = tvtrans.ToTensor()(PIL.Image.open(ci))
                ci = ci*2 - 1  # [0,1] -> [-1,1]
                ci = ('vision', ci)
            else:
                # is prompt
                draw_filename = 't2i_' + ci.strip().replace(' ', '-')
                ci = ('prompt', ci)

            if fix_seed:
                np.random.seed(cfg.env.rnd_seed + seedi)
                torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
                suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
            else:
                suffixi = suffix

            if with_ema:
                with netm.ema_scope():
                    x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
            else:
                x, _ = self.sample(netm, sampler, ci, **cfgv.sample)

            demo_image = latent2im(netm, x)
            if color_adj and ci[0] == 'vision':
                x_adj = []
                for demoi in demo_image:
                    color_adj_f = color_adjust(ref_from=demoi, ref_to=ci[1])
                    xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                demo_image = x_adj
            self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)

        if eval_cnt is not None:
            print_log('Demo printed for {}'.format(eval_cnt))
        return {}

# --- versatile_diffusion/lib/model_zoo/__init__.py (next file in this
# mangled diff chunk; reproduced verbatim) ----------------------------------
from .common.get_model import get_model
from .common.get_optimizer import get_optimizer
from .common.get_scheduler import get_scheduler
from .common.utils import get_unit
b/versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9286c7eec6bfc42585396739017d617c2f0fa323 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc b/versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..783d666b1c72bbb17a8c1abf129106ebdfd6b505 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/__pycache__/attention.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc b/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89cef1f47e9cda018fc4482bf66bd8adec7d111a Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc b/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffecf60247d523feb181426add0c2c6fd2bf3e17 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc b/versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61327291dc5e50b4bd0af27ccb5469aec9483845 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-310.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc b/versatile_diffusion/lib/model_zoo/__pycache__/clip.cpython-38.pyc new file mode 100644 index 
# --- versatile_diffusion/lib/model_zoo/attention.py ------------------------
# NOTE(review): reconstructed from a line-mangled diff chunk; code tokens are
# unchanged, only formatting was normalized and comments/docstrings added.
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from .diffusion_utils import checkpoint


def exists(val):
    """True when `val` is not None."""
    return val is not None


def uniq(arr):
    # Order-preserving de-duplication (dict keys keep insertion order).
    return{el: True for el in arr}.keys()


def default(val, d):
    """Return `val` when set; otherwise `d` (calling it when it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    # Most negative finite value representable in t's dtype (attention mask fill).
    return -torch.finfo(t.dtype).max


def init_(tensor):
    # In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dim size.
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated GELU: project to 2*dim_out, one half gates the other."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer MLP block: (Linear+GELU | GEGLU) -> Dropout -> Linear."""
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    """32-group GroupNorm used throughout this module."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    """Linear-complexity attention over 2D feature maps (softmax over keys)."""
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # context is d x d, independent of sequence length -> O(n) overall
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over spatial positions with a residual add."""
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))  # 1/sqrt(d) scaling
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x+h_


class CrossAttention(nn.Module):
    """Multi-head cross-attention; self-attention when context is None."""
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)  # fall back to self-attention
        k = self.to_k(context)
        v = self.to_v(context)

        # fold heads into the batch dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # mask is True where attention is allowed; masked-out positions
            # get the most negative finite value before softmax
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Pre-norm block: self-attn (or cross when disable_self_attn),
    cross-attn, feed-forward -- each with a residual connection."""
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # gradient checkpointing via the project's checkpoint() helper
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
                for d in range(depth)]
        )

        # zero-initialized so the block starts as an identity mapping
        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in


##########################
# transformer no context #
##########################

class BasicTransformerBlockNoContext(nn.Module):
    """BasicTransformerBlock variant whose both attentions are pure
    self-attention (no context path)."""
    def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
                                    dropout=dropout, context_dim=None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
                                    dropout=dropout, context_dim=None)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)

    def _forward(self, x):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x)) + x
        x = self.ff(self.norm3(x)) + x
        return x

class SpatialTransformerNoContext(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0.,):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
                for d in range(depth)]
        )

        # zero-initialized so the block starts as an identity mapping
        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in


#######################################
# Spatial Transformer with Two Branch #
#######################################

class DualSpatialTransformer(nn.Module):
    """Two parallel SpatialTransformer branches (0 and 1) over the same
    input; `which` selects branch 0, branch 1, or (any other value) a blend
    of both -- see forward."""
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        # First crossattn
        self.norm_0 = Normalize(in_channels)
        self.proj_in_0 = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        self.transformer_blocks_0 = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
                for d in range(depth)]
        )
        self.proj_out_0 = zero_module(nn.Conv2d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

        # Second crossattn
        self.norm_1 = Normalize(in_channels)
        self.proj_in_1 = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        self.transformer_blocks_1 = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
                for d in range(depth)]
        )
        self.proj_out_1 = zero_module(nn.Conv2d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None, which=None):
        # note: if no context is given, cross-attention defaults to self-attention
        # which==0 / which==1 select one branch (context is a single tensor);
        # any other `which` runs BOTH branches (context is then a pair) and
        # returns x0*which + x1*(1-which) + x_in inside the else path below.
        b, c, h, w = x.shape
        x_in = x
        if which==0:
            norm, proj_in, blocks, proj_out = \
                self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
        elif which==1:
            norm, proj_in, blocks, proj_out = \
                self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
        else:
            # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
            # import numpy.random as npr
            # rwhich = 0 if npr.rand() < which else 1
            # context = context[rwhich]
            # if rwhich==0:
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            # elif rwhich==1:
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1

            # import numpy.random as npr
            # rwhich = 0 if npr.rand() < 0.33 else 1
            # if rwhich==0:
            #     context = context[rwhich]
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            # else:

            # blend path: run branch 0 with context[0] ...
            norm, proj_in, blocks, proj_out = \
                self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            x0 = norm(x)
            x0 = proj_in(x0)
            x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
            for block in blocks:
                x0 = block(x0, context=context[0])
            x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            x0 = proj_out(x0)

            # ... and branch 1 with context[1], then mix by `which`.
            # NOTE(review): branch-0 output is weighted by `which` and
            # branch-1 by (1-which), as written.
            norm, proj_in, blocks, proj_out = \
                self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
            x1 = norm(x)
            x1 = proj_in(x1)
            x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
            for block in blocks:
                x1 = block(x1, context=context[1])
            x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            x1 = proj_out(x1)
            return x0*which + x1*(1-which) + x_in

        # single-branch path (which==0 or which==1 fell through to here)
        x = norm(x)
        x = proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = proj_out(x)
        return x + x_in

# --- versatile_diffusion/lib/model_zoo/autoencoder.py (next file in this
# mangled diff chunk; its body continues past this chunk) -------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
+from contextlib import contextmanager +from lib.model_zoo.common.get_model import get_model, register + +# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from .diffusion_modules import Encoder, Decoder +from .distributions import DiagonalGaussianDistribution + + +# class VQModel(nn.Module): +# def __init__(self, +# ddconfig, +# lossconfig, +# n_embed, +# embed_dim, +# ckpt_path=None, +# ignore_keys=[], +# image_key="image", +# colorize_nlabels=None, +# monitor=None, +# batch_resize_range=None, +# scheduler_config=None, +# lr_g_factor=1.0, +# remap=None, +# sane_index_shape=False, # tell vector quantizer to return indices as bhw +# use_ema=False +# ): +# super().__init__() +# self.embed_dim = embed_dim +# self.n_embed = n_embed +# self.image_key = image_key +# self.encoder = Encoder(**ddconfig) +# self.decoder = Decoder(**ddconfig) +# self.loss = instantiate_from_config(lossconfig) +# self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, +# remap=remap, +# sane_index_shape=sane_index_shape) +# self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) +# self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) +# if colorize_nlabels is not None: +# assert type(colorize_nlabels)==int +# self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) +# if monitor is not None: +# self.monitor = monitor +# self.batch_resize_range = batch_resize_range +# if self.batch_resize_range is not None: +# print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + +# self.use_ema = use_ema +# if self.use_ema: +# self.model_ema = LitEma(self) +# print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + +# if ckpt_path is not None: +# self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) +# self.scheduler_config = scheduler_config +# self.lr_g_factor = lr_g_factor + +# @contextmanager +# def ema_scope(self, context=None): +# if self.use_ema: +# 
self.model_ema.store(self.parameters()) +# self.model_ema.copy_to(self) +# if context is not None: +# print(f"{context}: Switched to EMA weights") +# try: +# yield None +# finally: +# if self.use_ema: +# self.model_ema.restore(self.parameters()) +# if context is not None: +# print(f"{context}: Restored training weights") + +# def init_from_ckpt(self, path, ignore_keys=list()): +# sd = torch.load(path, map_location="cpu")["state_dict"] +# keys = list(sd.keys()) +# for k in keys: +# for ik in ignore_keys: +# if k.startswith(ik): +# print("Deleting key {} from state_dict.".format(k)) +# del sd[k] +# missing, unexpected = self.load_state_dict(sd, strict=False) +# print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") +# if len(missing) > 0: +# print(f"Missing Keys: {missing}") +# print(f"Unexpected Keys: {unexpected}") + +# def on_train_batch_end(self, *args, **kwargs): +# if self.use_ema: +# self.model_ema(self) + +# def encode(self, x): +# h = self.encoder(x) +# h = self.quant_conv(h) +# quant, emb_loss, info = self.quantize(h) +# return quant, emb_loss, info + +# def encode_to_prequant(self, x): +# h = self.encoder(x) +# h = self.quant_conv(h) +# return h + +# def decode(self, quant): +# quant = self.post_quant_conv(quant) +# dec = self.decoder(quant) +# return dec + +# def decode_code(self, code_b): +# quant_b = self.quantize.embed_code(code_b) +# dec = self.decode(quant_b) +# return dec + +# def forward(self, input, return_pred_indices=False): +# quant, diff, (_,_,ind) = self.encode(input) +# dec = self.decode(quant) +# if return_pred_indices: +# return dec, diff, ind +# return dec, diff + +# def get_input(self, batch, k): +# x = batch[k] +# if len(x.shape) == 3: +# x = x[..., None] +# x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() +# if self.batch_resize_range is not None: +# lower_size = self.batch_resize_range[0] +# upper_size = self.batch_resize_range[1] +# if self.global_step <= 4: +# # do 
the first few batches with max size to avoid later oom +# new_resize = upper_size +# else: +# new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) +# if new_resize != x.shape[2]: +# x = F.interpolate(x, size=new_resize, mode="bicubic") +# x = x.detach() +# return x + +# def training_step(self, batch, batch_idx, optimizer_idx): +# # https://github.com/pytorch/pytorch/issues/37142 +# # try not to fool the heuristics +# x = self.get_input(batch, self.image_key) +# xrec, qloss, ind = self(x, return_pred_indices=True) + +# if optimizer_idx == 0: +# # autoencode +# aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, +# last_layer=self.get_last_layer(), split="train", +# predicted_indices=ind) + +# self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) +# return aeloss + +# if optimizer_idx == 1: +# # discriminator +# discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, +# last_layer=self.get_last_layer(), split="train") +# self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) +# return discloss + +# def validation_step(self, batch, batch_idx): +# log_dict = self._validation_step(batch, batch_idx) +# with self.ema_scope(): +# log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") +# return log_dict + +# def _validation_step(self, batch, batch_idx, suffix=""): +# x = self.get_input(batch, self.image_key) +# xrec, qloss, ind = self(x, return_pred_indices=True) +# aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, +# self.global_step, +# last_layer=self.get_last_layer(), +# split="val"+suffix, +# predicted_indices=ind +# ) + +# discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, +# self.global_step, +# last_layer=self.get_last_layer(), +# split="val"+suffix, +# predicted_indices=ind +# ) +# rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] +# self.log(f"val{suffix}/rec_loss", rec_loss, +# prog_bar=True, logger=True, 
on_step=False, on_epoch=True, sync_dist=True) +# self.log(f"val{suffix}/aeloss", aeloss, +# prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) +# if version.parse(pl.__version__) >= version.parse('1.4.0'): +# del log_dict_ae[f"val{suffix}/rec_loss"] +# self.log_dict(log_dict_ae) +# self.log_dict(log_dict_disc) +# return self.log_dict + +# def configure_optimizers(self): +# lr_d = self.learning_rate +# lr_g = self.lr_g_factor*self.learning_rate +# print("lr_d", lr_d) +# print("lr_g", lr_g) +# opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ +# list(self.decoder.parameters())+ +# list(self.quantize.parameters())+ +# list(self.quant_conv.parameters())+ +# list(self.post_quant_conv.parameters()), +# lr=lr_g, betas=(0.5, 0.9)) +# opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), +# lr=lr_d, betas=(0.5, 0.9)) + +# if self.scheduler_config is not None: +# scheduler = instantiate_from_config(self.scheduler_config) + +# print("Setting up LambdaLR scheduler...") +# scheduler = [ +# { +# 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), +# 'interval': 'step', +# 'frequency': 1 +# }, +# { +# 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), +# 'interval': 'step', +# 'frequency': 1 +# }, +# ] +# return [opt_ae, opt_disc], scheduler +# return [opt_ae, opt_disc], [] + +# def get_last_layer(self): +# return self.decoder.conv_out.weight + +# def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): +# log = dict() +# x = self.get_input(batch, self.image_key) +# x = x.to(self.device) +# if only_inputs: +# log["inputs"] = x +# return log +# xrec, _ = self(x) +# if x.shape[1] > 3: +# # colorize with random projection +# assert xrec.shape[1] > 3 +# x = self.to_rgb(x) +# xrec = self.to_rgb(xrec) +# log["inputs"] = x +# log["reconstructions"] = xrec +# if plot_ema: +# with self.ema_scope(): +# xrec_ema, _ = self(x) +# if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) +# log["reconstructions_ema"] = 
@register('autoencoderkl')
class AutoencoderKL(nn.Module):
    """KL-regularized convolutional image autoencoder (the VAE "first stage").

    ``encode`` maps an image batch to a ``DiagonalGaussianDistribution``
    posterior; ``decode`` maps a latent back to image space.

    NOTE(review): ``lossconfig``, ``ckpt_path`` and ``ignore_keys`` are
    accepted but never used in this trimmed version, and the training /
    validation helpers reference ``self.loss``, ``self.learning_rate``,
    ``self.global_step``, ``self.device`` and ``self.log``/``self.log_dict``,
    none of which are defined in this class (pytorch-lightning leftovers).
    Confirm an external trainer supplies them before calling those methods.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # double_z: the encoder emits mean and logvar stacked along channels,
        # hence the factor of 2 on both sides of quant_conv.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            # Fixed random projection used by to_rgb() for segmentation inputs.
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def encode(self, x):
        """Return the diagonal-Gaussian posterior q(z|x) for image batch ``x``."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode, sample (or take the mode of) the posterior, and decode.

        Returns ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch image ``k`` from ``batch`` as float NCHW (adds a trailing
        channel dim to 3-D inputs before permuting)."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One GAN-style training step.

        ``optimizer_idx`` 0 trains encoder+decoder+logvar, 1 trains the
        discriminator.  NOTE(review): relies on Lightning-provided members
        (``self.loss``, ``self.global_step``, ``self.log*``) — see class note.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Compute both autoencoder and discriminator losses on a val batch."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers: one for the autoencoder, one for the discriminator.

        NOTE(review): reads ``self.learning_rate`` and ``self.loss.discriminator``,
        neither of which is defined in this class — see class note.
        """
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the decoder's final conv — used for adaptive loss weighting."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict of input / reconstruction / prior-sample images."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to RGB in [-1, 1] using a
        fixed random projection (created lazily if missing)."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(nn.Module):
    """No-op first stage: passes inputs through unchanged.

    Optionally mimics the VQ interface so callers that unpack
    ``(quant, emb_loss, info)`` keep working.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: initialize nn.Module internals BEFORE assigning attributes;
        # the original assigned self.vq_interface first, which is fragile
        # against nn.Module.__setattr__'s bookkeeping.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # (quant, emb_loss, (perplexity, min_encodings, indices)) placeholder
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
class AbstractEncoder(nn.Module):
    """Interface for conditioning encoders: subclasses implement ``encode``."""
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    """Embed integer class labels for cross-attention conditioning."""
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        """Look up ``batch[key]`` (shape [B]) and return embeddings of shape
        [B, 1, embed_dim] — the length-1 axis is for use in crossattn."""
        if key is None:
            key = self.key
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers.

    NOTE(review): depends on ``TransformerWrapper`` and ``Encoder`` whose
    import is commented out at the top of this file — constructing this class
    raises NameError until that import is restored.
    """
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77):
        super().__init__()
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        """Tokenize ``text`` to fixed-length (``max_length``) input ids (CPU)."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length,
                                        return_length=True, return_overflowing_tokens=False,
                                        padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        # Mimic the VQ interface (quant, emb_loss, info) with tokens as indices.
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer and adds some transformer encoder layers.

    NOTE(review): like TransformerEmbedder, requires the commented-out
    ``TransformerWrapper``/``Encoder`` import to be restored.
    """
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 ckpt_path=None, ignore_keys=None, device="cuda", use_tokenizer=True,
                 embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)
        if ckpt_path is not None:
            # Fix: avoid mutable default argument; None means "ignore nothing".
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load a raw state dict from ``path``, dropping keys with prefixes in
        ``ignore_keys``; non-strict so partial checkpoints are accepted."""
        ignore_keys = ignore_keys or []
        sd = torch.load(path, map_location="cpu")
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        device = self.transformer.token_emb.weight.device  # a trick to get device
        tokens = tokens.to(device)
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)


class SpatialRescaler(nn.Module):
    """Rescale a feature map ``n_stages`` times by ``multiplier`` with the given
    interpolation method, optionally remapping channels with a 1x1 conv."""
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)
@register('clip_text_frozen', version)
class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface).

    Weights are frozen at construction.  ``forward``/``encode`` return the
    unprojected ``last_hidden_state`` for ``max_length``-padded tokens.
    """
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        # Put the encoder in eval mode and stop gradients for every parameter.
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # Tokenize to a fixed max_length, run the text encoder, and return the
        # per-token hidden states (no projection head applied).
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


from transformers import CLIPProcessor, CLIPModel

@register('clip_frozen', version)
class FrozenCLIP(AbstractEncoder):
    """Frozen full CLIP model exposing several encodings of text and images.

    ``encode`` dispatches to the method named by ``encode_type``
    (default ``'encode_text'``).  All parameters are frozen and the
    ``train`` method is disabled so eval mode is permanent.
    """
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 encode_type='encode_text',
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length   # TODO: typical value?
        self.encode_type = encode_type
        self.fp16 = fp16  # if True, image pixel inputs are cast to half precision
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        # Freeze all weights and override .train so mode can't be flipped back.
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode_text_pooled(self, text):
        """Pooled, projected text features (CLIPModel.get_text_features)."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.get_text_features(input_ids=tokens)
        return outputs

    def encode_vision_pooled(self, images):
        """Pooled, projected image features (CLIPModel.get_image_features)."""
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        return self.model.get_image_features(pixel_values=pixels)

    def encode_text_noproj(self, text):
        """Per-token text hidden states without the projection head."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        return outputs.last_hidden_state

    def encode_vision_noproj(self, images):
        """Per-patch vision hidden states without post-layernorm or projection."""
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        return outputs.last_hidden_state

    def encode_text(self, text):
        """Projected per-token text features, scaled by the norm of the
        projected pooled token.

        NOTE(review): the whole sequence is divided by the pooled embedding's
        norm (not each token's own norm) — presumably so the pooled token ends
        up unit-norm while relative token magnitudes are preserved; confirm
        against downstream consumers.
        """
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        #tokens = tokens.half() if self.fp16 else tokens ## Furkan added
        outputs = self.model.text_model(input_ids=tokens)
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z

    def encode_vision(self, images):
        """Projected per-patch vision features, scaled by the norm of the
        projected CLS token (index 0) — same pooled-norm scheme as encode_text."""
        z = self.encode_vision_noproj(images)
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        # z_pooled_normed = z_pooled / z_pooled.norm(dim=-1, keepdim=True)
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z

    def encode(self, *args, **kwargs):
        # Dispatch to the configured encoder method.
        return getattr(self, self.encode_type)(*args, **kwargs)
+ # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)).float() + + def encode(self, im): + return self(im).unsqueeze(1) diff --git a/versatile_diffusion/lib/model_zoo/clip_justin/__init__.py b/versatile_diffusion/lib/model_zoo/clip_justin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6bdd43be8089411574b2a17fd08be2d089dcfd --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/clip_justin/__init__.py @@ -0,0 +1 @@ +from .clip import load diff --git a/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/__init__.cpython-38.pyc b/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b2c33faa06f089965ad1b4d658ddb9982716be7 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/__init__.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/clip.cpython-38.pyc b/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b3a2277c6c6fbcd4011a82367d89bb281a76464 Binary files /dev/null and b/versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/clip.cpython-38.pyc differ diff --git a/versatile_diffusion/lib/model_zoo/clip_justin/clip.py b/versatile_diffusion/lib/model_zoo/clip_justin/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..afbae930f45bf12beac35eec50a17c426066dc8d --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/clip_justin/clip.py @@ -0,0 +1,237 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, 
CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +# from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +# _tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = 
os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). 
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + 
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +# def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: +# """ +# Returns the tokenized representation of given input string(s) + +# Parameters +# ---------- +# texts : Union[str, List[str]] +# An input string or a list of input strings to tokenize + +# context_length : int +# The context length to use; all CLIP models use 77 as the context length + +# truncate: bool +# Whether to truncate the text in case its encoding is longer than the context length + +# Returns +# ------- +# A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. +# We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+# """ +# if isinstance(texts, str): +# texts = [texts] + +# sot_token = _tokenizer.encoder["<|startoftext|>"] +# eot_token = _tokenizer.encoder["<|endoftext|>"] +# all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] +# if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): +# result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) +# else: +# result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + +# for i, tokens in enumerate(all_tokens): +# if len(tokens) > context_length: +# if truncate: +# tokens = tokens[:context_length] +# tokens[-1] = eot_token +# else: +# raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") +# result[i, :len(tokens)] = torch.tensor(tokens) + +# return result diff --git a/versatile_diffusion/lib/model_zoo/clip_justin/simple_tokenizer.py b/versatile_diffusion/lib/model_zoo/clip_justin/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/clip_justin/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
+ And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + 
while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/versatile_diffusion/lib/model_zoo/common/get_model.py b/versatile_diffusion/lib/model_zoo/common/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b114a179b82cb29a5433064adc192b8adb6460fc --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/common/get_model.py @@ -0,0 +1,120 @@ +from email.policy import strict +import torch +import torchvision.models +import os.path as osp +import copy +from ...log_service import print_log +from .utils import \ + get_total_param, get_total_param_sum, \ + get_unit + +# def load_state_dict(net, model_path): +# if isinstance(net, dict): +# for ni, neti in net.items(): +# paras = torch.load(model_path[ni], map_location=torch.device('cpu')) +# new_paras = neti.state_dict() +# new_paras.update(paras) +# 
neti.load_state_dict(new_paras) +# else: +# paras = torch.load(model_path, map_location=torch.device('cpu')) +# new_paras = net.state_dict() +# new_paras.update(paras) +# net.load_state_dict(new_paras) +# return + +# def save_state_dict(net, path): +# if isinstance(net, (torch.nn.DataParallel, +# torch.nn.parallel.DistributedDataParallel)): +# torch.save(net.module.state_dict(), path) +# else: +# torch.save(net.state_dict(), path) + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +def preprocess_model_args(args): + # If args has layer_units, get the corresponding + # units. + # If args get backbone, get the backbone model. + args = copy.deepcopy(args) + if 'layer_units' in args: + layer_units = [ + get_unit()(i) for i in args.layer_units + ] + args.layer_units = layer_units + if 'backbone' in args: + args.backbone = get_model()(args.backbone) + return args + +@singleton +class get_model(object): + def __init__(self): + self.model = {} + self.version = {} + + def register(self, model, name, version='x'): + self.model[name] = model + self.version[name] = version + + def __call__(self, cfg, verbose=True): + """ + Construct model based on the config. + """ + t = cfg.type + + # the register is in each file + if t.find('ldm')==0: + from .. import ldm + elif t=='autoencoderkl': + from .. import autoencoder + elif t.find('clip')==0: + from .. import clip + elif t.find('sd')==0: + from .. import sd + elif t.find('vd')==0: + from .. import vd + elif t.find('openai_unet')==0: + from .. import openaimodel + elif t.find('optimus')==0: + from .. 
import optimus + + args = preprocess_model_args(cfg.args) + net = self.model[t](**args) + + map_location = cfg.get('map_location', 'cpu') + strict_sd = cfg.get('strict_sd', True) + if 'ckpt' in cfg: + checkpoint = torch.load(cfg.ckpt, map_location=map_location) + net.load_state_dict(checkpoint['state_dict'], strict=strict_sd) + if verbose: + print_log('Load ckpt from {}'.format(cfg.ckpt)) + elif 'pth' in cfg: + sd = torch.load(cfg.pth, map_location=map_location) + net.load_state_dict(sd, strict=strict_sd) + if verbose: + print_log('Load pth from {}'.format(cfg.pth)) + + # display param_num & param_sum + if verbose: + print_log( + 'Load {} with total {} parameters,' + '{:.3f} parameter sum.'.format( + t, + get_total_param(net), + get_total_param_sum(net) )) + + return net + + def get_version(self, name): + return self.version[name] + +def register(name, version='x'): + def wrapper(class_): + get_model().register(class_, name, version) + return class_ + return wrapper diff --git a/versatile_diffusion/lib/model_zoo/common/get_optimizer.py b/versatile_diffusion/lib/model_zoo/common/get_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2820ce6734fe0929963e5ba92c8fd4c4fd6ddd --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/common/get_optimizer.py @@ -0,0 +1,47 @@ +import torch +import torch.optim as optim +import numpy as np +import itertools + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +class get_optimizer(object): + def __init__(self): + self.optimizer = {} + self.register(optim.SGD, 'sgd') + self.register(optim.Adam, 'adam') + self.register(optim.AdamW, 'adamw') + + def register(self, optim, name): + self.optimizer[name] = optim + + def __call__(self, net, cfg): + if cfg is None: + return None + t = cfg.type + if isinstance(net, (torch.nn.DataParallel, + 
torch.nn.parallel.DistributedDataParallel)): + netm = net.module + else: + netm = net + pg = getattr(netm, 'parameter_group', None) + + if pg is not None: + params = [] + for group_name, module_or_para in pg.items(): + if not isinstance(module_or_para, list): + module_or_para = [module_or_para] + + grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para] + grouped_params = itertools.chain(*grouped_params) + pg_dict = {'params':grouped_params, 'name':group_name} + params.append(pg_dict) + else: + params = net.parameters() + return self.optimizer[t](params, lr=0, **cfg.args) diff --git a/versatile_diffusion/lib/model_zoo/common/get_scheduler.py b/versatile_diffusion/lib/model_zoo/common/get_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7c86e89dd9fcd092836546555b14cb68c7771d --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/common/get_scheduler.py @@ -0,0 +1,262 @@ +import torch +import torch.optim as optim +import numpy as np +import copy +from ... 
import sync +from ...cfg_holder import cfg_unique_holder as cfguh + +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class get_scheduler(object): + def __init__(self): + self.lr_scheduler = {} + + def register(self, lrsf, name): + self.lr_scheduler[name] = lrsf + + def __call__(self, cfg): + if cfg is None: + return None + if isinstance(cfg, list): + schedulers = [] + for ci in cfg: + t = ci.type + schedulers.append( + self.lr_scheduler[t](**ci.args)) + if len(schedulers) == 0: + raise ValueError + else: + return compose_scheduler(schedulers) + t = cfg.type + return self.lr_scheduler[t](**cfg.args) + + +def register(name): + def wrapper(class_): + get_scheduler().register(class_, name) + return class_ + return wrapper + +class template_scheduler(object): + def __init__(self, step): + self.step = step + + def __getitem__(self, idx): + raise ValueError + + def set_lr(self, optim, new_lr, pg_lrscale=None): + """ + Set Each parameter_groups in optim with new_lr + New_lr can be find according to the idx. + pg_lrscale tells how to scale each pg. 
+ """ + # new_lr = self.__getitem__(idx) + pg_lrscale = copy.deepcopy(pg_lrscale) + for pg in optim.param_groups: + if pg_lrscale is None: + pg['lr'] = new_lr + else: + pg['lr'] = new_lr * pg_lrscale.pop(pg['name']) + assert (pg_lrscale is None) or (len(pg_lrscale)==0), \ + "pg_lrscale doesn't match pg" + +@register('constant') +class constant_scheduler(template_scheduler): + def __init__(self, lr, step): + super().__init__(step) + self.lr = lr + + def __getitem__(self, idx): + if idx >= self.step: + raise ValueError + return self.lr + +@register('poly') +class poly_scheduler(template_scheduler): + def __init__(self, start_lr, end_lr, power, step): + super().__init__(step) + self.start_lr = start_lr + self.end_lr = end_lr + self.power = power + + def __getitem__(self, idx): + if idx >= self.step: + raise ValueError + a, b = self.start_lr, self.end_lr + p, n = self.power, self.step + return b + (a-b)*((1-idx/n)**p) + +@register('linear') +class linear_scheduler(template_scheduler): + def __init__(self, start_lr, end_lr, step): + super().__init__(step) + self.start_lr = start_lr + self.end_lr = end_lr + + def __getitem__(self, idx): + if idx >= self.step: + raise ValueError + a, b, n = self.start_lr, self.end_lr, self.step + return b + (a-b)*(1-idx/n) + +@register('multistage') +class constant_scheduler(template_scheduler): + def __init__(self, start_lr, milestones, gamma, step): + super().__init__(step) + self.start_lr = start_lr + m = [0] + milestones + [step] + lr_iter = start_lr + self.lr = [] + for ms, me in zip(m[0:-1], m[1:]): + for _ in range(ms, me): + self.lr.append(lr_iter) + lr_iter *= gamma + + def __getitem__(self, idx): + if idx >= self.step: + raise ValueError + return self.lr[idx] + +class compose_scheduler(template_scheduler): + def __init__(self, schedulers): + self.schedulers = schedulers + self.step = [si.step for si in schedulers] + self.step_milestone = [] + acc = 0 + for i in self.step: + acc += i + self.step_milestone.append(acc) + self.step 
= sum(self.step) + + def __getitem__(self, idx): + if idx >= self.step: + raise ValueError + ms = self.step_milestone + for idx, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])): + if mi <= idx < mj: + return self.schedulers[idx-mi] + raise ValueError + +#################### +# lambda schedular # +#################### + +class LambdaWarmUpCosineScheduler(template_scheduler): + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, + base_lr, + warm_up_steps, + lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + cfgt = cfguh().cfg.train + bs = cfgt.batch_size + if 'gradacc_every' not in cfgt: + print('Warning, gradacc_every is not found in xml, use 1 as default.') + acc = cfgt.get('gradacc_every', 1) + self.lr_multi = base_lr * bs * acc + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __getitem__(self, idx): + return self.schedule(idx) * self.lr_multi + +class LambdaWarmUpCosineScheduler2(template_scheduler): + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. 
+ """ + def __init__(self, + base_lr, + warm_up_steps, + f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + cfgt = cfguh().cfg.train + # bs = cfgt.batch_size + # if 'gradacc_every' not in cfgt: + # print('Warning, gradacc_every is not found in xml, use 1 as default.') + # acc = cfgt.get('gradacc_every', 1) + # self.lr_multi = base_lr * bs * acc + self.lr_multi = base_lr + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __getitem__(self, idx): + return self.schedule(idx) * self.lr_multi + +@register('stable_diffusion_linear') +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current 
cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f \ No newline at end of file diff --git a/versatile_diffusion/lib/model_zoo/ddim.py b/versatile_diffusion/lib/model_zoo/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..b610e943046c5c1f81341ed18b5c3a94c62b9613 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/ddim.py @@ -0,0 +1,216 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', 
to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + cond, shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, x0=None, + img_callback=None, log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 
'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if 
score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/versatile_diffusion/lib/model_zoo/ddim_dualcontext.py b/versatile_diffusion/lib/model_zoo/ddim_dualcontext.py new file mode 100644 index 0000000000000000000000000000000000000000..cebb8967716735e19f1e87a3eed669af5a0988cc --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/ddim_dualcontext.py @@ -0,0 +1,144 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + +from .ddim import DDIMSampler + +class DDIMSampler_DualContext(DDIMSampler): + @torch.no_grad() + def sample_text(self, *args, **kwargs): + self.cond_type = 'prompt' + return self.sample(*args, **kwargs) + + @torch.no_grad() + def sample_vision(self, *args, **kwargs): + self.cond_type = 'vision' + return self.sample(*args, **kwargs) + + @torch.no_grad() + def sample_mixed(self, *args, **kwargs): + self.cond_type = kwargs.pop('cond_mixed_p') + return self.sample(*args, **kwargs) + + @torch.no_grad() + def sample(self, + steps, + shape, + xt=None, + conditioning=None, + eta=0., + temperature=1., + noise_dropout=0., + verbose=True, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + shape, + xt=xt, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning,) + return samples, 
intermediates + + @torch.no_grad() + def ddim_sampling(self, + conditioning, + shape, + xt=None, + ddim_use_original_steps=False, + timesteps=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + unconditional_guidance_scale=1., + unconditional_conditioning=None,): + device = self.model.betas.device + bs = shape[0] + if xt is None: + img = torch.randn(shape, device=device) + else: + img = xt + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps, + temperature=temperature, + noise_dropout=noise_dropout, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, use_original_steps=False, + temperature=1., noise_dropout=0., + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if 
unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, conditioning, cond_type=self.cond_type) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + # c_in = torch.cat([unconditional_conditioning, conditioning]) + + # Added for vd-dc dual guidance + if isinstance(unconditional_conditioning, list): + c_in = [torch.cat([ui, ci]) for ui, ci in zip(unconditional_conditioning, conditioning)] + else: + c_in = torch.cat([unconditional_conditioning, conditioning]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, cond_type=self.cond_type).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/versatile_diffusion/lib/model_zoo/ddim_dualmodel.py b/versatile_diffusion/lib/model_zoo/ddim_dualmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..6641ce3d900c54fb103328ce44d168dd996e9ad1 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/ddim_dualmodel.py @@ -0,0 +1,244 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + +from .ddim import DDIMSampler + +class DDIMSampler_DualModel(DDIMSampler): + def __init__(self, model_t2i, model_v2i, schedule="linear", **kwargs): + self.model = model_t2i + self.model_t2i = model_t2i + self.model_v2i = model_v2i + self.device = self.model_t2i.device + self.ddpm_num_timesteps = model_t2i.num_timesteps + self.schedule = schedule + + @torch.no_grad() + def sample_text(self, *args, **kwargs): + self.cond_type = 'prompt' + self.p_sample_model_type = 't2i' + return self.sample(*args, **kwargs) + + @torch.no_grad() + def sample_vision(self, *args, **kwargs): + self.cond_type = 'vision' + self.p_sample_model_type = 'v2i' + return self.sample(*args, **kwargs) + + @torch.no_grad() + def sample(self, + steps, + shape, + xt=None, + conditioning=None, + eta=0., + temperature=1., + noise_dropout=0., + verbose=True, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + shape, + xt=xt, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + 
temperature=temperature, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning,) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + conditioning, + shape, + xt=None, + ddim_use_original_steps=False, + timesteps=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + unconditional_guidance_scale=1., + unconditional_conditioning=None,): + device = self.model.betas.device + bs = shape[0] + if xt is None: + img = torch.randn(shape, device=device) + else: + img = xt + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps, + temperature=temperature, + noise_dropout=noise_dropout, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, 
use_original_steps=False, + temperature=1., noise_dropout=0., + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if self.p_sample_model_type == 't2i': + apply_model = self.model_t2i.apply_model + elif self.p_sample_model_type == 'v2i': + apply_model = self.model_v2i.apply_model + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = apply_model(x, t, conditioning) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, conditioning]) + e_t_uncond, e_t = apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def sample_mixed(self, + steps, + steps_t2i, + steps_v2i, + shape, + xt=None, + c_prompt=None, + c_vision=None, + eta=0., + temperature=1., + noise_dropout=0., + verbose=True, + log_every_t=100, + uc_scale=1., + uc_prompt=None, + uc_vision=None,): + + print(f'DDIM mixed sampling with shape {shape}, eta {eta}') + print(f'steps_t2i {steps_t2i}') + print(f'steps_v2i {steps_v2i}') + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + self.ddim_timesteps_t2i = self.ddim_timesteps[steps_t2i] + self.ddim_timesteps_v2i = self.ddim_timesteps[steps_v2i] + + samples, intermediates = self.ddim_sampling_mixed( + c_prompt, + c_vision, + shape, + xt=xt, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t, + uc_scale=uc_scale, + uc_prompt=uc_prompt, + uc_vision=uc_vision, ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling_mixed(self, + c_prompt, + c_vision, + shape, + xt=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + uc_scale=1., + uc_prompt=None, + uc_vision=None, ): + device = self.device + bs = shape[0] + if xt is None: + img = torch.randn(shape, device=device) + else: + img = xt + + timesteps = self.ddim_timesteps + intermediates = {'x_inter': [], 'pred_x0': []} + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + if step in self.ddim_timesteps_t2i: + self.p_sample_model_type = 't2i' + conditioning = c_prompt + unconditional_conditioning = uc_prompt + elif step in self.ddim_timesteps_v2i: + 
self.p_sample_model_type = 'v2i' + conditioning = c_vision + unconditional_conditioning = uc_vision + else: + raise ValueError # shouldn't reached + + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + outs = self.p_sample_ddim( + img, conditioning, ts, + index=index, + temperature=temperature, + noise_dropout=noise_dropout, + unconditional_guidance_scale=uc_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + diff --git a/versatile_diffusion/lib/model_zoo/ddim_vd.py b/versatile_diffusion/lib/model_zoo/ddim_vd.py new file mode 100644 index 0000000000000000000000000000000000000000..764cb9ba5f78361b33e76da439b8c2211c5df3f1 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/ddim_vd.py @@ -0,0 +1,419 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + +from .ddim import DDIMSampler + +class DDIMSampler_VD(DDIMSampler): + @torch.no_grad() + def sample(self, + steps, + shape, + xt=None, + conditioning=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + xtype='image', + ctype='prompt', + eta=0., + temperature=1., + noise_dropout=0., + verbose=True, + log_every_t=100,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + samples, intermediates = self.ddim_sampling( + shape, + xt=xt, + conditioning=conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + xtype=xtype, + ctype=ctype, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t,) + 
return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + shape, + xt=None, + conditioning=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + xtype='image', + ctype='prompt', + ddim_use_original_steps=False, + timesteps=None, + noise_dropout=0., + temperature=1., + log_every_t=100,): + + device = self.model.model.diffusion_model.device + bs = shape[0] + if xt is None: + xt = torch.randn(shape, device=device, dtype=conditioning.dtype) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'pred_xt': [], 'pred_x0': []} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + pred_xt = xt + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + pred_xt, conditioning, ts, index, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + xtype=xtype, + ctype=ctype, + use_original_steps=ddim_use_original_steps, + noise_dropout=noise_dropout, + temperature=temperature,) + pred_xt, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['pred_xt'].append(pred_xt) + intermediates['pred_x0'].append(pred_x0) + + return pred_xt, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, conditioning, t, index, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + 
xtype='image', + ctype='prompt', + repeat_noise=False, + use_original_steps=False, + noise_dropout=0., + temperature=1.,): + + b, *_, device = *x.shape, self.model.model.diffusion_model.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, conditioning, xtype=xtype, ctype=ctype) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, conditioning]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, xtype=xtype, ctype=ctype).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if xtype == 'image': + extended_shape = (b, 1, 1, 1) + elif xtype == 'text': + extended_shape = (b, 1) + + a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype) + a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype) + sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype) + sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def sample_dc(self, + steps, + shape, + xt=None, + first_conditioning=None, + second_conditioning=None, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + eta=0., + temperature=1., + mixed_ratio=0.5, + noise_dropout=0., + verbose=True, + log_every_t=100,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + samples, intermediates = self.ddim_sampling_dc( + shape, + xt=xt, + first_conditioning=first_conditioning, + second_conditioning=second_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + xtype=xtype, + first_ctype=first_ctype, + second_ctype=second_ctype, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t, + mixed_ratio=mixed_ratio, ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling_dc(self, + shape, + xt=None, + first_conditioning=None, + second_conditioning=None, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + ddim_use_original_steps=False, + timesteps=None, + noise_dropout=0., + temperature=1., + mixed_ratio=0.5, + log_every_t=100,): + + device = self.model.model.diffusion_model.device + bs = shape[0] + if xt is None: + xt = torch.randn(shape, device=device, dtype=first_conditioning[1].dtype) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + 
timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'pred_xt': [], 'pred_x0': []} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + pred_xt = xt + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim_dc( + pred_xt, + first_conditioning, + second_conditioning, + ts, index, + unconditional_guidance_scale=unconditional_guidance_scale, + xtype=xtype, + first_ctype=first_ctype, + second_ctype=second_ctype, + use_original_steps=ddim_use_original_steps, + noise_dropout=noise_dropout, + temperature=temperature, + mixed_ratio=mixed_ratio,) + pred_xt, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['pred_xt'].append(pred_xt) + intermediates['pred_x0'].append(pred_x0) + + return pred_xt, intermediates + + @torch.no_grad() + def p_sample_ddim_dc(self, x, + first_conditioning, + second_conditioning, + t, index, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + repeat_noise=False, + use_original_steps=False, + noise_dropout=0., + temperature=1., + mixed_ratio=0.5,): + + b, *_, device = *x.shape, self.model.model.diffusion_model.device + + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + first_c = torch.cat(first_conditioning) + second_c = torch.cat(second_conditioning) + + e_t_uncond, e_t = self.model.apply_model_dc( + x_in, t_in, first_c, second_c, xtype=xtype, first_ctype=first_ctype, second_ctype=second_ctype, mixed_ratio=mixed_ratio).chunk(2) + + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + 
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if xtype == 'image': + extended_shape = (b, 1, 1, 1) + elif xtype == 'text': + extended_shape = (b, 1) + + a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype) + a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype) + sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype) + sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + alphas_next = alphas_next.to(x0.device) + alphas = alphas.to(x0.device) + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding 
Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(t.device) + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(t.device) + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + 
@torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, xtype='image', ctype='vision', + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, xtype=xtype, ctype=ctype, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec + + @torch.no_grad() + def decode_dc(self, x_latent, first_conditioning, second_conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, xtype='image', first_ctype='vision', second_ctype='prompt', + use_original_steps=False, mixed_ratio=0.5, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim_dc( + x_dec, + first_conditioning, + second_conditioning, + ts, index, + unconditional_guidance_scale=unconditional_guidance_scale, + xtype=xtype, + 
first_ctype=first_ctype, + second_ctype=second_ctype, + use_original_steps=use_original_steps, + noise_dropout=0, + temperature=1, + mixed_ratio=mixed_ratio,) + if callback: callback(i) + return x_dec + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) \ No newline at end of file diff --git a/versatile_diffusion/lib/model_zoo/ddim_vd_old.py b/versatile_diffusion/lib/model_zoo/ddim_vd_old.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebd6de45a547c984bd99597925e5696cd2d523b --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/ddim_vd_old.py @@ -0,0 +1,293 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + +from .ddim import DDIMSampler + +class DDIMSampler_VD(DDIMSampler): + @torch.no_grad() + def sample(self, + steps, + shape, + xt=None, + conditioning=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + xtype='image', + ctype='prompt', + eta=0., + temperature=1., + noise_dropout=0., + verbose=True, + log_every_t=100,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + samples, intermediates = self.ddim_sampling( + shape, + xt=xt, + conditioning=conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + xtype=xtype, + ctype=ctype, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t,) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + shape, + xt=None, + conditioning=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + xtype='image', + ctype='prompt', + ddim_use_original_steps=False, + timesteps=None, + 
noise_dropout=0., + temperature=1., + log_every_t=100,): + + device = 1 + bs = shape[0] + if xt is None: + xt = torch.randn(shape, device=device) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'pred_xt': [], 'pred_x0': []} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + pred_xt = xt + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + pred_xt, conditioning, ts, index, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + xtype=xtype, + ctype=ctype, + use_original_steps=ddim_use_original_steps, + noise_dropout=noise_dropout, + temperature=temperature,) + pred_xt, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['pred_xt'].append(pred_xt) + intermediates['pred_x0'].append(pred_x0) + + return pred_xt, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, conditioning, t, index, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + xtype='image', + ctype='prompt', + repeat_noise=False, + use_original_steps=False, + noise_dropout=0., + temperature=1.,): + + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, conditioning, xtype=xtype, ctype=ctype) + else: + x_in = 
torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, conditioning]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, xtype=xtype, ctype=ctype).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if xtype == 'image': + extended_shape = (b, 1, 1, 1) + elif xtype == 'text': + extended_shape = (b, 1) + + a_t = torch.full(extended_shape, alphas[index], device=device) + a_prev = torch.full(extended_shape, alphas_prev[index], device=device) + sigma_t = torch.full(extended_shape, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + +class DDIMSampler_VD_DualContext(DDIMSampler_VD): + @torch.no_grad() + def sample_dc(self, + steps, + shape, + xt=None, + first_conditioning=None, + second_conditioning=None, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + eta=0., + temperature=1., + mixed_ratio=0.5, + noise_dropout=0., + verbose=True, + log_every_t=100,): + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + print(f'Data shape for DDIM sampling is {shape}, eta {eta}') + samples, intermediates = self.ddim_sampling_dc( + shape, + xt=xt, + first_conditioning=first_conditioning, + second_conditioning=second_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + xtype=xtype, + first_ctype=first_ctype, + second_ctype=second_ctype, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + log_every_t=log_every_t, + mixed_ratio=mixed_ratio, ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling_dc(self, + shape, + xt=None, + first_conditioning=None, + second_conditioning=None, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + ddim_use_original_steps=False, + timesteps=None, + noise_dropout=0., + temperature=1., + mixed_ratio=0.5, + log_every_t=100,): + + device = self.model.device + bs = shape[0] + if xt is None: + xt = torch.randn(shape, device=device) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) 
- 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'pred_xt': [], 'pred_x0': []} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + pred_xt = xt + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((bs,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim_dc( + pred_xt, + first_conditioning, + second_conditioning, + ts, index, + unconditional_guidance_scale=unconditional_guidance_scale, + xtype=xtype, + first_ctype=first_ctype, + second_ctype=second_ctype, + use_original_steps=ddim_use_original_steps, + noise_dropout=noise_dropout, + temperature=temperature, + mixed_ratio=mixed_ratio,) + pred_xt, pred_x0 = outs + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['pred_xt'].append(pred_xt) + intermediates['pred_x0'].append(pred_x0) + + return pred_xt, intermediates + + @torch.no_grad() + def p_sample_ddim_dc(self, x, + first_conditioning, + second_conditioning, + t, index, + unconditional_guidance_scale=1., + xtype='image', + first_ctype='prompt', + second_ctype='prompt', + repeat_noise=False, + use_original_steps=False, + noise_dropout=0., + temperature=1., + mixed_ratio=0.5,): + + b, *_, device = *x.shape, x.device + + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + first_c = torch.cat(first_conditioning) + second_c = torch.cat(second_conditioning) + + e_t_uncond, e_t = self.model.apply_model_dc( + x_in, t_in, first_c, second_c, xtype=xtype, first_ctype=first_ctype, second_ctype=second_ctype, mixed_ratio=mixed_ratio).chunk(2) + + # e_t_uncond, e_t = self.model.apply_model(x_in, t_in, first_c, xtype='image', ctype='vision').chunk(2) + + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + 
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + + if xtype == 'image': + extended_shape = (b, 1, 1, 1) + elif xtype == 'text': + extended_shape = (b, 1) + + a_t = torch.full(extended_shape, alphas[index], device=device) + a_prev = torch.full(extended_shape, alphas_prev[index], device=device) + sigma_t = torch.full(extended_shape, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/versatile_diffusion/lib/model_zoo/diffusion_modules.py b/versatile_diffusion/lib/model_zoo/diffusion_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..146c5b241feb8b6f46946b29534f6212fab1ad85 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/diffusion_modules.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +# from .diffusion_utils import instantiate_from_config +from .attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. 
+ Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if 
out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + 
def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in 
range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert 
x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + 
curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def 
__init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = 
curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, 
out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) 
for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = 
self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert 
pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/versatile_diffusion/lib/model_zoo/diffusion_utils.py b/versatile_diffusion/lib/model_zoo/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b28b42dc6d2933d4a6159e973f70dc721f19701d --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/diffusion_utils.py @@ -0,0 +1,250 @@ +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, 
cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 
- alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + # return super().forward(x.float()).type(x.dtype) + return super().forward(x) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + +def noise_like(x, repeat=False): + noise = torch.randn_like(x) + if repeat: + bs = x.shape[0] + noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1))) + return noise + +########################## +# inherit from ldm.utils # +########################## + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params diff --git a/versatile_diffusion/lib/model_zoo/distributions.py b/versatile_diffusion/lib/model_zoo/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, 
class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian parameterized by a tensor holding the mean and
    log-variance concatenated along channel dim 1.

    :param parameters: [N, 2*C, ...] tensor; chunked into mean and logvar.
    :param deterministic: if True, variance is forced to zero so sample()
        and mode() both return the mean, and kl()/nll() are zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp so exp() stays finite and gradients remain stable.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        """Reparameterized sample: mean + std * eps."""
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """KL divergence to `other` (standard normal if None), summed over
        non-batch dims [1, 2, 3]."""
        if self.deterministic:
            # Fix: return a zero matching the parameters' device and dtype
            # instead of a hard-coded CPU float32 torch.Tensor([0.]).
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of `sample`, summed over `dims`."""
        if self.deterministic:
            # Same device/dtype fix as kl().
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """The distribution mode, i.e. the mean."""
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases. At least one argument must be a Tensor.
    """
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        lv if isinstance(lv, torch.Tensor) else torch.tensor(lv).to(tensor)
        for lv in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
class LitEma(nn.Module):
    """Exponential moving average of a model's trainable parameters.
    Shadow copies are stored as buffers; '.' is stripped from parameter
    names because buffer names may not contain dots."""

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up ramp entirely.
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Move each shadow parameter toward the model's current value."""
        decay = self.decay
        if self.num_updates >= 0:
            # Warm-up: effective decay ramps up over the first updates.
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())
            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the model's trainable parameters with the EMA values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

version = '0'
symbol = 'openai'

# dummy replace
def convert_module_to_f16(x):
    pass

def convert_module_to_f32(x):
    pass
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    Attention pooling over a spatial grid with a learned positional
    embedding and a mean-pooled query token.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        tokens = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the spatial mean as an extra token, then add positions.
        tokens = th.cat([tokens.mean(dim=-1, keepdim=True), tokens], dim=-1)  # NC(HW+1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)  # NC(HW+1)
        tokens = self.qkv_proj(tokens)
        tokens = self.attention(tokens)
        tokens = self.c_proj(tokens)
        # The attended mean token is the pooled output.
        return tokens[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention +
    input/output heads shaping (heads are split before qkv).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by ch^-1/4 each: more stable with f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        attended = th.einsum("bts,bcs->bct", weight, v)
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +@register('openai_unet', version) +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. 
@register('openai_unet', version)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        # Normalize num_res_blocks to a per-level list.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        #self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # Encoder: a stem conv followed by per-level res/attention stacks.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel counts of every encoder output, popped for decoder skips.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current cumulative downsample rate
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last one).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Decoder: mirrors the encoder, consuming the saved skip channels.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    # Upsample at the end of each level except the outermost.
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Concatenate the matching encoder skip before each decoder stage.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + 
use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + 
def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +####################### +# Unet with self-attn # +####################### + +from .attention import SpatialTransformerNoContext + +@register('openai_unet_nocontext', version) +class UNetModelNoContext(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + num_attention_blocks=None, ): + + super().__init__() + 
if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + 
num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformerNoContext( + ch, num_heads, dim_head, depth=transformer_depth + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformerNoContext( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * 
mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformerNoContext( + ch, num_heads, dim_head, depth=transformer_depth, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, x, timesteps): + assert self.num_classes is None, \ + "not supported" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h 
= th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + +@register('openai_unet_nocontext_noatt', version) +class UNetModelNoContextNoAtt(nn.Module): + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_scale_shift_norm=False, + resblock_updown=False, + n_embed=None,): + + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + 
ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + 
self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, x, timesteps): + assert self.num_classes is None, \ + "not supported" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + +@register('openai_unet_nocontext_noatt_decoderonly', version) +class UNetModelNoContextNoAttDecoderOnly(nn.Module): + def __init__( + self, + in_channels, + out_channels, + model_channels, + num_res_blocks, + dropout=0, + channel_mult=(4, 2, 1), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_scale_shift_norm=False, + resblock_updown=False, + n_embed=None,): + + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = 
num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self._feature_size = model_channels + + ch = model_channels * self.channel_mult[0] + self.output_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, ch, 3, padding=1) + ) + ] + ) + + for level, mult in enumerate(channel_mult): + for i in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if level != len(channel_mult)-1 and (i == self.num_res_blocks[level]-1): + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, x, timesteps): + assert self.num_classes is None, \ + "not supported" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = 
x.type(self.dtype) + for module in self.output_blocks: + h = module(h, emb) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + +######################### +# Double Attention Unet # +######################### + +from .attention import DualSpatialTransformer + +class TimestepEmbedSequentialExtended(nn.Sequential, TimestepBlock): + def forward(self, x, emb, context=None, which_attn=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + elif isinstance(layer, DualSpatialTransformer): + x = layer(x, context, which=which_attn) + else: + x = layer(x) + return x + +@register('openai_unet_dual_context', version) +class UNetModelDualContext(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequentialExtended( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if num_attention_blocks is None or nr < 
num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else DualSpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + self.input_blocks.append(TimestepEmbedSequentialExtended(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequentialExtended( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequentialExtended( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else DualSpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, 
mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if num_attention_blocks is None or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else DualSpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequentialExtended(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce 
###########
# VD Unet #
###########

from functools import partial

@register('openai_unet_2d', version)
class UNetModel2D(nn.Module):
    """2D (image-stream) UNet for Versatile Diffusion.

    Standard diffusion UNet layout: timestep embedding -> downsampling
    input blocks -> middle block -> upsampling output blocks joined by
    skip connections, with optional cross-attention per resolution level.

    :param input_channels: channels of the input tensor.
    :param model_channels: base channel count; scaled by ``channel_mult``.
    :param output_channels: channels of the predicted output.
    :param context_dim: dim of the cross-attention conditioning context.
    :param num_noattn_blocks: ResBlocks per resolution level.
    :param channel_mult: per-level channel multiplier.
    :param with_attn: per-level flag enabling a SpatialTransformer.
    :param num_heads: attention heads for every SpatialTransformer.
    :param use_checkpoint: enable gradient checkpointing in ResBlocks.
    """

    def __init__(self,
                 input_channels,
                 model_channels,
                 output_channels,
                 context_dim=768,
                 num_noattn_blocks=(2, 2, 2, 2),
                 channel_mult=(1, 2, 4, 8),
                 with_attn=[True, True, True, False],
                 num_heads=8,
                 use_checkpoint=True, ):

        super().__init__()

        # All ResBlocks share these settings; only channels vary below.
        ResBlockPreset = partial(
            ResBlock, dropout=0, dims=2, use_checkpoint=use_checkpoint,
            use_scale_shift_norm=False)

        self.input_channels = input_channels
        self.model_channels = model_channels
        self.num_noattn_blocks = num_noattn_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads

        ##################
        # Time embedding #
        ##################

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),)

        ################
        # input_blocks #
        ################
        current_channel = model_channels
        input_blocks = [
            TimestepEmbedSequential(
                nn.Conv2d(input_channels, model_channels, 3, padding=1, bias=True))]
        # Records the channel count of every input block so output blocks
        # can size their skip-connection concatenations.
        input_block_channels = [current_channel]

        for level_idx, mult in enumerate(channel_mult):
            for _ in range(self.num_noattn_blocks[level_idx]):
                layers = [
                    ResBlockPreset(
                        current_channel, time_embed_dim,
                        out_channels=mult * model_channels,)]

                current_channel = mult * model_channels
                dim_head = current_channel // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel, num_heads, dim_head,
                            depth=1, context_dim=context_dim, )]

                input_blocks += [TimestepEmbedSequential(*layers)]
                input_block_channels.append(current_channel)

            # Downsample between levels (not after the deepest one).
            if level_idx != len(channel_mult) - 1:
                input_blocks += [
                    TimestepEmbedSequential(
                        Downsample(
                            current_channel, use_conv=True,
                            dims=2, out_channels=current_channel,))]
                input_block_channels.append(current_channel)

        self.input_blocks = nn.ModuleList(input_blocks)

        #################
        # middle_blocks #
        #################
        # dim_head intentionally carries over from the last (deepest) level.
        middle_block = [
            ResBlockPreset(
                current_channel, time_embed_dim,),
            SpatialTransformer(
                current_channel, num_heads, dim_head,
                depth=1, context_dim=context_dim, ),
            ResBlockPreset(
                current_channel, time_embed_dim,),]
        self.middle_block = TimestepEmbedSequential(*middle_block)

        #################
        # output_blocks #
        #################
        output_blocks = []
        for level_idx, mult in list(enumerate(channel_mult))[::-1]:
            # One extra block per level consumes the Downsample skip.
            for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
                extra_channel = input_block_channels.pop()
                layers = [
                    ResBlockPreset(
                        current_channel + extra_channel,
                        time_embed_dim,
                        out_channels=model_channels * mult,)]

                current_channel = model_channels * mult
                dim_head = current_channel // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel, num_heads, dim_head,
                            depth=1, context_dim=context_dim,)]

                if level_idx != 0 and block_idx == self.num_noattn_blocks[level_idx]:
                    layers += [
                        Upsample(
                            current_channel, use_conv=True,
                            dims=2, out_channels=current_channel)]

                output_blocks += [TimestepEmbedSequential(*layers)]

        self.output_blocks = nn.ModuleList(output_blocks)

        # NOTE(review): the final conv takes model_channels inputs while the
        # norm takes current_channel; these only agree when channel_mult[0]==1
        # (true for the shipped configs) — confirm before changing configs.
        self.out = nn.Sequential(
            normalization(current_channel),
            nn.SiLU(),
            zero_module(nn.Conv2d(model_channels, output_channels, 3, padding=1)),)

    def forward(self, x, timesteps=None, context=None):
        """Apply the UNet; returns a tensor shaped like ``x``
        (with ``output_channels`` channels)."""
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        return self.out(h)

class FCBlock(TimestepBlock):
    """Fully-connected analogue of ResBlock: 1x1 convs on [N, C, 1, 1]
    (or [N, C]) features, modulated by the timestep embedding, with a
    learned or identity skip connection.

    :param channels: input channel count.
    :param emb_channels: timestep-embedding dimension.
    :param dropout: dropout probability inside out_layers.
    :param out_channels: output channels (defaults to ``channels``).
    :param use_checkpoint: enable gradient checkpointing.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 1, padding=0),)

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, self.out_channels,),)
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero-initialized so the block starts as (near) identity
            zero_module(nn.Conv2d(self.out_channels, self.out_channels, 1, padding=0)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1, padding=0)

    def forward(self, x, emb):
        """Accepts [N, C] or [N, C, 1, 1]; returns the same rank it was given.

        BUGFIX: the original re-checked ``len(x.shape) == 2`` *after* x had
        already been reshaped to 4-D, so the squeeze-back branch was dead and
        2-D inputs were returned as 4-D. We now remember the input rank.
        """
        was_2d = len(x.shape) == 2
        if was_2d:
            x = x[:, :, None, None]
        elif len(x.shape) == 4:
            pass
        else:
            raise ValueError
        y = checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint)
        if was_2d:
            return y[:, :, 0, 0]
        return y

    def _forward(self, x, emb):
        # in_layers -> add broadcast timestep embedding -> out_layers -> skip
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        h = h + emb_out
        h = self.out_layers(h)
        return self.skip_connection(x) + h
@register('openai_unet_0d', version)
class UNetModel0D(nn.Module):
    """0D (vector-stream) UNet: same topology as UNetModel2D but built from
    1x1 convolutions / FCBlocks, operating on [N, C, 1, 1] features.

    :param input_channels: channels of the input tensor.
    :param model_channels: base channel count; scaled by ``channel_mult``.
    :param output_channels: channels of the predicted output.
    :param context_dim: dim of the cross-attention conditioning context.
    :param num_noattn_blocks: FCBlocks per level.
    :param channel_mult: per-level channel multiplier.
    :param with_attn: per-level flag enabling a SpatialTransformer.
    :param num_heads: attention heads.
    :param use_checkpoint: enable gradient checkpointing in FCBlocks.
    """

    def __init__(self,
                 input_channels,
                 model_channels,
                 output_channels,
                 context_dim=768,
                 num_noattn_blocks=(2, 2, 2, 2),
                 channel_mult=(1, 2, 4, 8),
                 with_attn=[True, True, True, False],
                 num_heads=8,
                 use_checkpoint=True, ):

        super().__init__()

        # All FCBlocks share dropout/checkpoint settings.
        FCBlockPreset = partial(FCBlock, dropout=0, use_checkpoint=use_checkpoint)

        self.input_channels = input_channels
        self.model_channels = model_channels
        self.num_noattn_blocks = num_noattn_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads

        ##################
        # Time embedding #
        ##################

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),)

        ################
        # input_blocks #
        ################
        current_channel = model_channels
        input_blocks = [
            TimestepEmbedSequential(
                nn.Conv2d(input_channels, model_channels, 1, padding=0, bias=True))]
        # Channel count per input block, consumed by the skip connections.
        input_block_channels = [current_channel]

        for level_idx, mult in enumerate(channel_mult):
            for _ in range(self.num_noattn_blocks[level_idx]):
                layers = [
                    FCBlockPreset(
                        current_channel, time_embed_dim,
                        out_channels = mult * model_channels,)]

                current_channel = mult * model_channels
                dim_head = current_channel // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel, num_heads, dim_head,
                            depth=1, context_dim=context_dim, )]

                input_blocks += [TimestepEmbedSequential(*layers)]
                input_block_channels.append(current_channel)

            if level_idx != len(channel_mult) - 1:
                # NOTE(review): a spatial Downsample on [N, C, 1, 1] features —
                # presumably acts as a 1x1 conv here; confirm against Downsample.
                input_blocks += [
                    TimestepEmbedSequential(
                        Downsample(
                            current_channel, use_conv=True,
                            dims=2, out_channels=current_channel,))]
                input_block_channels.append(current_channel)

        self.input_blocks = nn.ModuleList(input_blocks)

        #################
        # middle_blocks #
        #################
        # dim_head intentionally carries over from the deepest level above.
        middle_block = [
            FCBlockPreset(
                current_channel, time_embed_dim,),
            SpatialTransformer(
                current_channel, num_heads, dim_head,
                depth=1, context_dim=context_dim, ),
            FCBlockPreset(
                current_channel, time_embed_dim,),]
        self.middle_block = TimestepEmbedSequential(*middle_block)

        #################
        # output_blocks #
        #################
        output_blocks = []
        for level_idx, mult in list(enumerate(channel_mult))[::-1]:
            # One extra block per level consumes the extra downsample skip.
            for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
                extra_channel = input_block_channels.pop()
                layers = [
                    FCBlockPreset(
                        current_channel + extra_channel,
                        time_embed_dim,
                        out_channels = model_channels * mult,) ]

                current_channel = model_channels * mult
                dim_head = current_channel // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel, num_heads, dim_head,
                            depth=1, context_dim=context_dim,)]

                # In place of Upsample, a plain 1x1 conv between levels.
                if level_idx!=0 and block_idx==self.num_noattn_blocks[level_idx]:
                    layers += [
                        nn.Conv2d(current_channel, current_channel, 1, padding=0)]

                output_blocks += [TimestepEmbedSequential(*layers)]

        self.output_blocks = nn.ModuleList(output_blocks)

        # NOTE(review): final conv takes model_channels while the norm takes
        # current_channel; these agree only when channel_mult[0] == 1.
        self.out = nn.Sequential(
            normalization(current_channel),
            nn.SiLU(),
            zero_module(nn.Conv2d(model_channels, output_channels, 1, padding=0)),)

    def forward(self, x, timesteps=None, context=None):
        """Apply the UNet (same flow as UNetModel2D.forward)."""
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        return self.out(h)

class Linear_MultiDim(nn.Linear):
    """nn.Linear whose in/out features may be multi-dimensional shapes.

    The true linear layer is built over the product of each shape; forward
    flattens the trailing dims, applies the linear map, and restores the
    multi-dim output shape.
    """

    def __init__(self, in_features, out_features, *args, **kwargs):
        # Normalize scalar features to one-element shape lists.
        in_features = [in_features] if isinstance(in_features, int) else list(in_features)
        out_features = [out_features] if isinstance(out_features, int) else list(out_features)
        self.in_features_multidim = in_features
        self.out_features_multidim = out_features
        super().__init__(
            np.array(in_features).prod(),
            np.array(out_features).prod(),
            *args, **kwargs)

    def forward(self, x):
        # self.in_features is the flattened size set by nn.Linear.__init__.
        shape = x.shape
        n = len(self.in_features_multidim)
        x = x.view(*shape[0:-n], self.in_features)
        y = super().forward(x)
        y = y.view(*shape[0:-n], *self.out_features_multidim)
        return y

class FCBlock_MultiDim(FCBlock):
    """FCBlock accepting multi-dimensional channel shapes (see Linear_MultiDim);
    internally runs the flat FCBlock over the product of the shape."""

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_checkpoint=False,):
        channels = [channels] if isinstance(channels, int) else list(channels)
        channels_all = np.array(channels).prod()
        self.channels_multidim = channels

        if out_channels is not None:
            out_channels = [out_channels] if isinstance(out_channels, int) else list(out_channels)
            out_channels_all = np.array(out_channels).prod()
            self.out_channels_multidim = out_channels
        else:
            out_channels_all = channels_all
            self.out_channels_multidim = self.channels_multidim

        # Overwritten by FCBlock.__init__ below, which sets self.channels to
        # the flattened channels_all used in forward().
        self.channels = channels
        super().__init__(
            channels = channels_all,
            emb_channels = emb_channels,
            dropout = dropout,
            out_channels = out_channels_all,
            use_checkpoint = use_checkpoint,)

    def forward(self, x, emb):
        # Flatten multi-dim channels to [N', C, 1, 1], run the flat block,
        # then restore the multi-dim output shape.
        shape = x.shape
        n = len(self.channels_multidim)
        x = x.view(*shape[0:-n], self.channels, 1, 1)
        x = x.view(-1, self.channels, 1, 1)
        y = checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint)
        y = y.view(*shape[0:-n], -1)
        y = y.view(*shape[0:-n], *self.out_channels_multidim)
        return y
@register('openai_unet_0dmd', version)
class UNetModel0D_MultiDim(nn.Module):
    """0D UNet over multi-dimensional feature shapes [C, second_dim, 1],
    built from FCBlock_MultiDim / Linear_MultiDim layers (text stream of
    Versatile Diffusion).

    :param second_dim: per-level second feature dimension.
    Other parameters mirror UNetModel0D.
    """

    def __init__(self,
                 input_channels,
                 model_channels,
                 output_channels,
                 context_dim=768,
                 num_noattn_blocks=(2, 2, 2, 2),
                 channel_mult=(1, 2, 4, 8),
                 second_dim=(4, 4, 4, 4),
                 with_attn=[True, True, True, False],
                 num_heads=8,
                 use_checkpoint=True, ):

        super().__init__()

        FCBlockPreset = partial(FCBlock_MultiDim, dropout=0, use_checkpoint=use_checkpoint)

        self.input_channels = input_channels
        self.model_channels = model_channels
        self.num_noattn_blocks = num_noattn_blocks
        self.channel_mult = channel_mult
        self.second_dim = second_dim
        self.num_heads = num_heads

        ##################
        # Time embedding #
        ##################

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),)

        ################
        # input_blocks #
        ################
        sdim = second_dim[0]
        # Channel "shape" is a 3-element list [C, sdim, 1] throughout.
        current_channel = [model_channels, sdim, 1]
        input_blocks = [
            TimestepEmbedSequential(
                Linear_MultiDim([input_channels, 1, 1], current_channel, bias=True))]
        input_block_channels = [current_channel]

        for level_idx, (mult, sdim) in enumerate(zip(channel_mult, second_dim)):
            for _ in range(self.num_noattn_blocks[level_idx]):
                layers = [
                    FCBlockPreset(
                        current_channel,
                        time_embed_dim,
                        out_channels = [mult*model_channels, sdim, 1],)]

                current_channel = [mult*model_channels, sdim, 1]
                dim_head = current_channel[0] // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel[0], num_heads, dim_head,
                            depth=1, context_dim=context_dim, )]

                input_blocks += [TimestepEmbedSequential(*layers)]
                input_block_channels.append(current_channel)

            # "Downsample" here is a shape-preserving linear map between levels.
            if level_idx != len(channel_mult) - 1:
                input_blocks += [
                    TimestepEmbedSequential(
                        Linear_MultiDim(current_channel, current_channel, bias=True, ))]
                input_block_channels.append(current_channel)

        self.input_blocks = nn.ModuleList(input_blocks)

        #################
        # middle_blocks #
        #################
        # dim_head intentionally carries over from the deepest level above.
        middle_block = [
            FCBlockPreset(
                current_channel, time_embed_dim, ),
            SpatialTransformer(
                current_channel[0], num_heads, dim_head,
                depth=1, context_dim=context_dim, ),
            FCBlockPreset(
                current_channel, time_embed_dim, ),]
        self.middle_block = TimestepEmbedSequential(*middle_block)

        #################
        # output_blocks #
        #################
        output_blocks = []
        for level_idx, (mult, sdim) in list(enumerate(zip(channel_mult, second_dim)))[::-1]:
            # One extra block per level consumes the extra level-transition skip.
            for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
                extra_channel = input_block_channels.pop()
                # Skip concat only grows the leading channel dim.
                layers = [
                    FCBlockPreset(
                        [current_channel[0] + extra_channel[0]] + current_channel[1:],
                        time_embed_dim,
                        out_channels = [mult*model_channels, sdim, 1], )]

                current_channel = [mult*model_channels, sdim, 1]
                dim_head = current_channel[0] // num_heads
                if with_attn[level_idx]:
                    layers += [
                        SpatialTransformer(
                            current_channel[0], num_heads, dim_head,
                            depth=1, context_dim=context_dim,)]

                if level_idx!=0 and block_idx==self.num_noattn_blocks[level_idx]:
                    layers += [
                        Linear_MultiDim(current_channel, current_channel, bias=True, )]

                output_blocks += [TimestepEmbedSequential(*layers)]

        self.output_blocks = nn.ModuleList(output_blocks)

        self.out = nn.Sequential(
            normalization(current_channel[0]),
            nn.SiLU(),
            zero_module(Linear_MultiDim(current_channel, [output_channels, 1, 1], bias=True, )),)

    def forward(self, x, timesteps=None, context=None):
        """Apply the UNet (same flow as UNetModel0D.forward)."""
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        return self.out(h)
@register('openai_unet_vd', version)
class UNetModelVD(nn.Module):
    """Dual-stream Versatile-Diffusion UNet: an image UNet and a text UNet
    run layer-by-layer in lockstep, selecting per-layer which stream's
    module handles the data (xtype) and which handles the conditioning
    cross-attention (ctype). The image UNet's time embedding is shared.
    """

    def __init__(self,
                 unet_image_cfg,
                 unet_text_cfg, ):

        super().__init__()
        self.unet_image = get_model()(unet_image_cfg)
        self.unet_text = get_model()(unet_text_cfg)
        # Share one time embedding (taken from the image branch); drop the
        # per-branch copies so they are not duplicated in the state dict.
        self.time_embed = self.unet_image.time_embed
        del self.unet_image.time_embed
        del self.unet_text.time_embed

        self.model_channels = self.unet_image.model_channels

    def forward(self, x, timesteps, context, xtype='image', ctype='prompt'):
        """Run the mixed UNet.

        :param xtype: 'image' or 'text' — which stream the data x belongs to.
        :param ctype: 'vision' or 'prompt' — which stream conditions attention.
        """
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        # NOTE(review): hard-codes fp16 and reads self.device, which nn.Module
        # does not define by itself — assumes it is provided elsewhere; confirm.
        x=x.to(self.device).half()#.float()#.half()
        emb = self.time_embed(t_emb.to(self.device).half())#.float()#.half())

        # Text latents arrive as [N, C]; lift to [N, C, 1, 1] for conv layers.
        if xtype == 'text':
            x = x[:, :, None, None]

        h = x
        for i_module, t_module in zip(self.unet_image.input_blocks, self.unet_text.input_blocks):
            h = self.mixed_run(i_module, t_module, h, emb, context, xtype, ctype)
            hs.append(h)
        h = self.mixed_run(
            self.unet_image.middle_block, self.unet_text.middle_block,
            h, emb, context, xtype, ctype)
        for i_module, t_module in zip(self.unet_image.output_blocks, self.unet_text.output_blocks):
            h = th.cat([h, hs.pop()], dim=1)
            h = self.mixed_run(i_module, t_module, h, emb, context, xtype, ctype)
        if xtype == 'image':
            return self.unet_image.out(h)
        elif xtype == 'text':
            # Undo the [N, C, 1, 1] lift applied above.
            return self.unet_text.out(h).squeeze(-1).squeeze(-1)

    def mixed_run(self, inet, tnet, x, emb, context, xtype, ctype):
        """Run paired image/text blocks over x, dispatching each sub-layer:
        timestep blocks follow xtype, cross-attention follows ctype, and all
        remaining layers follow xtype. Branch order matters — attention
        checks must come after the TimestepBlock checks."""
        h = x
        for ilayer, tlayer in zip(inet, tnet):
            if isinstance(ilayer, TimestepBlock) and xtype=='image':
                h = ilayer(h, emb)
            elif isinstance(tlayer, TimestepBlock) and xtype=='text':
                h = tlayer(h, emb)
            elif isinstance(ilayer, SpatialTransformer) and ctype=='vision':
                h = ilayer(h, context)
            elif isinstance(ilayer, SpatialTransformer) and ctype=='prompt':
                h = tlayer(h, context)
            elif xtype=='image':
                h = ilayer(h)
            elif xtype == 'text':
                h = tlayer(h)
            else:
                raise ValueError
        return h

    def forward_dc(self, x, timesteps, c0, c1, xtype, c0_type, c1_type, mixed_ratio):
        """Dual-conditioning variant of forward: blends two contexts c0/c1
        (each 'vision' or 'prompt') with weight mixed_ratio."""
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        # NOTE(review): same fp16 / self.device assumptions as forward().
        x=x.to(self.device).half()
        emb = self.time_embed(t_emb.to(self.device).half())

        if xtype == 'text':
            x = x[:, :, None, None]
        h = x
        for i_module, t_module in zip(self.unet_image.input_blocks, self.unet_text.input_blocks):
            h = self.mixed_run_dc(i_module, t_module, h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
            hs.append(h)
        h = self.mixed_run_dc(
            self.unet_image.middle_block, self.unet_text.middle_block,
            h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
        for i_module, t_module in zip(self.unet_image.output_blocks, self.unet_text.output_blocks):
            h = th.cat([h, hs.pop()], dim=1)
            h = self.mixed_run_dc(i_module, t_module, h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
        if xtype == 'image':
            return self.unet_image.out(h)
        elif xtype == 'text':
            return self.unet_text.out(h).squeeze(-1).squeeze(-1)

    def mixed_run_dc(self, inet, tnet, x, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio):
        """Like mixed_run, but each cross-attention layer computes residuals
        for both contexts and adds their mixed_ratio-weighted blend to h."""
        h = x
        for ilayer, tlayer in zip(inet, tnet):
            if isinstance(ilayer, TimestepBlock) and xtype=='image':
                h = ilayer(h, emb)
            elif isinstance(tlayer, TimestepBlock) and xtype=='text':
                h = tlayer(h, emb)
            elif isinstance(ilayer, SpatialTransformer):
                # Attention residual per context, then a convex-style blend.
                h0 = ilayer(h, c0)-h if c0_type=='vision' else tlayer(h, c0)-h
                h1 = ilayer(h, c1)-h if c1_type=='vision' else tlayer(h, c1)-h
                h = h0*mixed_ratio + h1*(1-mixed_ratio) + h
                # h = ilayer(h, c0)
            elif xtype=='image':
                h = ilayer(h)
            elif xtype == 'text':
                h = tlayer(h)
            else:
                raise ValueError
        return h
import torch
import torch.nn as nn
import torch.nn.functional as F

# math is used by loss_iw/nll_iw/eval_inference_dist below; the original
# file only imported it locally inside calc_mi, which made those methods
# raise NameError when called.
import math

import numpy as np
import numpy.random as npr
import copy

from lib.model_zoo.common.get_model import get_model, register
from lib.model_zoo.common import utils

from .optimus_models.tokenization_gpt2 import GPT2Tokenizer

version = '0'
symbol = 'optimus'

@register('optimus_vae', version)
class optimus_vae(nn.Module):
    """VAE with normal prior (Optimus): BERT-style encoder to a Gaussian
    latent, GPT2-style decoder conditioned on the latent via `past`."""

    def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
        super().__init__()
        # Each component may be passed pre-built or as a config for get_model.
        self.encoder = encoder if isinstance(encoder, nn.Module) else get_model()(encoder)
        self.decoder = decoder if isinstance(decoder, nn.Module) else get_model()(decoder)
        self.tokenizer_encoder = tokenizer_encoder \
            if isinstance(tokenizer_encoder, nn.Module) \
            else get_model()(tokenizer_encoder, verbose=False)
        self.tokenizer_decoder = tokenizer_decoder \
            if isinstance(tokenizer_decoder, nn.Module) \
            else get_model()(tokenizer_decoder, verbose=False)

        # NOTE(review): the special-token strings are empty here — upstream
        # Optimus uses '<PAD>'/'<BOS>'/'<EOS>'; the markup may have been lost
        # in transit. Left unchanged; confirm against the checkpoint's vocab.
        gpt2_special_tokens_dict = {'pad_token': '', 'bos_token': '', 'eos_token': ''}
        if isinstance(self.tokenizer_encoder, GPT2Tokenizer):
            self.tokenizer_encoder.add_special_tokens(gpt2_special_tokens_dict)
        if isinstance(self.tokenizer_decoder, GPT2Tokenizer):
            self.tokenizer_decoder.add_special_tokens(gpt2_special_tokens_dict)

        self.args = args
        self.nz = args.latent_size

        self.eos_token_id = self.tokenizer_decoder.convert_tokens_to_ids(
            [self.tokenizer_decoder.eos_token])[0]
        self.pad_token_id = self.tokenizer_decoder.convert_tokens_to_ids(
            [self.tokenizer_decoder.pad_token])[0]

        # connector: from Bert hidden units to the latent space
        # self.linear = nn.Linear(args.nz, 2 * args.nz, bias=False)

        # Standard Normal prior
        loc = torch.zeros(self.nz)
        scale = torch.ones(self.nz)
        self.prior = torch.distributions.normal.Normal(loc, scale)

    def connect(self, bert_fea, nsamples=1):
        """
        Returns: Tensor1, Tensor2
        Tensor1: the tensor latent z with shape [batch, nsamples, nz]
        Tensor2: the tenor of KL for each x with shape [batch]
        """
        # (batch_size, nz)
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

        # (batch, nsamples, nz)
        z = self.reparameterize(mean, logvar, nsamples)
        # Closed-form KL(q(z|x) || N(0, I)).
        KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL

    def connect_deterministic(self, bert_fea, nsamples=1):
        """
        Same as connect() but with logvar forced to 0 (deterministic spread).
        Returns: Tensor1, Tensor2
        Tensor1: the tensor latent z with shape [batch, nsamples, nz]
        Tensor2: the tenor of KL for each x with shape [batch]
        """
        # (batch_size, nz)
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

        logvar.fill_(.0)
        # (batch, nsamples, nz)
        z = self.reparameterize(mean, logvar, nsamples)
        KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL

    def reparameterize(self, mu, logvar, nsamples=1):
        """sample from posterior Gaussian family
        Args:
            mu: Tensor — mean of gaussian distribution with shape (batch, nz)
            logvar: Tensor — logvar of gaussian distibution with shape (batch, nz)
        Returns: Tensor — sampled z with shape (batch, nsamples, nz)
        """
        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()

        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

        eps = torch.zeros_like(std_expd).normal_()

        return mu_expd + torch.mul(eps, std_expd)

    def forward(self, inputs, labels):
        """Compute (reconstruction loss, KL loss, total loss) for a batch.
        fb_mode 0: plain VAE KL; 1: free-bits thresholded KL; 2: deterministic.
        """
        attention_mask = (inputs > 0).float()
        # 50257 is the padding token for GPT2
        reconstruction_mask = (labels != 50257).float()
        sent_length = torch.sum(reconstruction_mask, dim=1)

        outputs = self.encoder(inputs, attention_mask)
        pooled_hidden_fea = outputs[1]  # model outputs are always tuple in pytorch-transformers (see doc)

        if self.args.fb_mode == 0:
            # Connect hidden feature to the latent space
            latent_z, loss_kl = self.connect(pooled_hidden_fea)
            latent_z = latent_z.squeeze(1)

            # Decoding
            outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
            loss_rec = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

        elif self.args.fb_mode == 1:
            # Free-bits: only penalize dimensions whose KL exceeds dim_target_kl.
            mu, logvar = self.encoder.linear(pooled_hidden_fea).chunk(2, -1)
            latent_z = self.reparameterize(mu, logvar, nsamples=1)
            latent_z = latent_z.squeeze(1)
            loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
            kl_mask = (loss_kl > self.args.dim_target_kl).float()
            loss_kl = (kl_mask * loss_kl).sum(dim=1)

            outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
            loss_rec = outputs[0]

        elif self.args.fb_mode == 2:
            latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
            latent_z = latent_z.squeeze(1)

            outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
            loss_rec = outputs[0]

        if self.args.length_weighted_loss:
            loss = loss_rec / sent_length + self.args.beta * loss_kl
        else:
            loss = loss_rec + self.args.beta * loss_kl

        return loss_rec, loss_kl, loss

    def encoder_sample(self, bert_fea, nsamples):
        """sampling from the encoder
        Returns: z with shape [batch, nsamples, nz] and the (mu, logvar) params.
        """
        # (batch_size, nz)
        mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        mu, logvar = mu.squeeze(0), logvar.squeeze(0)

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)

        return z, (mu, logvar)

    def encode_stats(self, x):
        """Returns the mean and logvar of latent z, each [batch, nz]."""
        return self.encoder.encode_stats(x)

    def decode(self, z, strategy, K=10):
        """generate samples from z given strategy
        Args:
            z: [batch, nsamples, nz]
            strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter
        Returns: a list of decoded word sequences
        """
        if strategy == "beam":
            return self.decoder.beam_search_decode(z, K)
        elif strategy == "greedy":
            return self.decoder.greedy_decode(z)
        elif strategy == "sample":
            return self.decoder.sample_decode(z)
        else:
            raise ValueError("the decoding strategy is not supported")

    def reconstruct(self, x, decoding_strategy="greedy", K=5):
        """reconstruct from input x; returns a list of decoded word sequences."""
        z = self.sample_from_inference(x).squeeze(1)

        return self.decode(z, decoding_strategy, K)

    def log_probability(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample) — log p(x|z) across different x and z
        """
        outputs = self.decoder(input_ids=x, past=z, labels=x, label_ignore=self.pad_token_id)
        loss_rec = outputs[0]
        return -loss_rec

    def loss_iw(self, x0, x1, nsamples=50, ns=1):
        """Importance-weighted bound pieces.
        Returns: (log_prob_iw, log_gen_iw, KL), each shaped [batch].
        """
        # encoding into bert features
        bert_fea = self.encoder(x0)[1]

        # (batch_size, nz)
        mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

        ##################
        # compute KL
        ##################
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        ll_tmp, rc_tmp = [], []
        # Draw nsamples in chunks of ns to bound memory.
        for _ in range(int(nsamples / ns)):
            # (batch, nsamples, nz)
            z = self.reparameterize(mu, logvar, ns)
            past = z

            # [batch, nsamples]
            log_prior = self.eval_prior_dist(z)
            log_gen = self.eval_cond_ll(x1, past)
            log_infer = self.eval_inference_dist(z, (mu, logvar))

            log_gen = log_gen.unsqueeze(0).contiguous().view(z.shape[0], -1)

            rc_tmp.append(log_gen)
            ll_tmp.append(log_gen + log_prior - log_infer)

        # NOTE(review): log_sum_exp is not defined at module scope in this
        # file — presumably provided by a utils import; confirm its source.
        log_prob_iw = log_sum_exp(torch.cat(ll_tmp, dim=-1), dim=-1) - math.log(nsamples)
        log_gen_iw = torch.mean(torch.cat(rc_tmp, dim=-1), dim=-1)

        return log_prob_iw, log_gen_iw, KL

    def nll_iw(self, x0, x1, nsamples, ns=1):
        """compute the importance weighting estimate of the log-likelihood
        Args:
            x0, x1: two different tokenization results of x
            nsamples: number of samples to estimate the marginal likelihood
        Returns: the estimate of log p(x), shape [batch]
        """
        # compute iw every ns samples to address the memory issue
        # TODO: x is forwarded twice in self.encoder.sample(x, ns) and
        # self.eval_inference_dist(x, z, param); fix to speed up.
        tmp = []
        for _ in range(int(nsamples / ns)):
            # [batch, ns, nz]
            # encoding into bert features
            pooled_hidden_fea = self.encoder(x0)[1]

            # param is the parameters required to evaluate q(z|x)
            z, param = self.encoder_sample(pooled_hidden_fea, ns)

            # [batch, ns]
            log_comp_ll = self.eval_complete_ll(x1, z)
            log_infer_ll = self.eval_inference_dist(z, param)

            tmp.append(log_comp_ll - log_infer_ll)

        ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)

        return ll_iw

    def KL(self, x):
        # NOTE(review): relies on self.encode, which is not defined in this
        # class as visible here — confirm it exists on a subclass/mixin.
        _, KL = self.encode(x, 1)

        return KL

    def eval_prior_dist(self, zrange):
        """Evaluate the log prior at zrange, shape (k^2, nz) -> (k^2)."""
        return self.prior.log_prob(zrange).sum(dim=-1)

    def eval_complete_ll(self, x, z):
        """compute log p(z,x)
        Args:
            x: [batch, seq_len]; z: [batch, nsamples, nz]
        Returns: log p(z,x) with shape [batch, nsamples]
        """
        # [batch, nsamples]
        log_prior = self.eval_prior_dist(z)
        log_gen = self.eval_cond_ll(x, z)

        return log_prior + log_gen

    def eval_cond_ll(self, x, z):
        """compute log p(x|z), tiling x across the nsamples dim of z."""
        x_shape = list(x.size())
        z_shape = list(z.size())
        if len(z_shape) == 3:
            x = x.unsqueeze(1).repeat(1, z_shape[1], 1).contiguous().view(x_shape[0]*z_shape[1], x_shape[-1])
            z = z.contiguous().view(x_shape[0]*z_shape[1], z_shape[-1])

        return self.log_probability(x, z)

    def eval_log_model_posterior(self, x, grid_z):
        """grid search for the true posterior log p(z|x)
        Args:
            grid_z: (k^2, nz) grid of z points
        Returns: log p(z|x) with shape [batch_size, K^2]
        """
        try:
            batch_size = x.size(0)
        except:
            batch_size = x[0].size(0)

        # (batch_size, k^2, nz)
        grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()

        # (batch_size, k^2)
        log_comp = self.eval_complete_ll(x, grid_z)

        # normalize to posterior
        log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)

        return log_posterior

    def sample_from_inference(self, x, nsamples=1):
        """Sample from the inference net: (batch_size, nsamples, nz)."""
        z, _ = self.encoder.sample(x, nsamples)

        return z

    def sample_from_posterior(self, x, nsamples):
        """Metropolis-Hastings sampling from the model posterior.
        Returns samples with shape (batch_size, nsamples, nz).
        """
        # use the samples from inference net as initial points for MCMC
        cur = self.encoder.sample_from_inference(x, 1)
        cur_ll = self.eval_complete_ll(x, cur)
        total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
        samples = []
        for iter_ in range(total_iter):
            # Gaussian random-walk proposal (renamed from `next`, which
            # shadowed the builtin).
            proposal = torch.normal(mean=cur,
                                    std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
            # [batch_size, 1]
            proposal_ll = self.eval_complete_ll(x, proposal)
            ratio = proposal_ll - cur_ll

            accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))

            uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()

            # [batch_size, 1]
            mask = (uniform_t < accept_prob).float()
            mask_ = mask.unsqueeze(2)

            cur = mask_ * proposal + (1 - mask_) * cur
            cur_ll = mask * proposal_ll + (1 - mask) * cur_ll

            if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
                samples.append(cur.unsqueeze(1))

        return torch.cat(samples, dim=1)

    def calc_model_posterior_mean(self, x, grid_z):
        """compute E_{z ~ p(z|x)}[z] over a grid of z points
        Returns: the mean value tensor with shape [batch, nz]
        """
        # [batch, K^2]
        log_posterior = self.eval_log_model_posterior(x, grid_z)
        posterior = log_posterior.exp()

        # [batch, nz]
        return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)

    def calc_infer_mean(self, x):
        """Returns the mean of the inference distribution, shape [batch, nz]."""
        mean, logvar = self.encoder.forward(x)

        return mean

    def eval_inference_dist(self, z, param):
        """this function computes log q(z | x)
        Args:
            z: [batch, nsamples, nz]; param: (mu, logvar) each [batch, nz]
        Returns: log q(z|x) with shape [batch, nsamples]
        """
        nz = z.size(2)
        mu, logvar = param

        # (batch_size, 1, nz)
        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
        var = logvar.exp()

        # (batch_size, nsamples, nz)
        dev = z - mu

        # (batch_size, nsamples)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        return log_density

    def calc_mi(self, test_data_batch, args):
        """Estimate mutual information I(x; z) over the given batches
        (calc_mi_v3)."""
        import math
        from modules.utils import log_sum_exp

        mi = 0
        num_examples = 0

        mu_batch_list, logvar_batch_list = [], []
        neg_entropy = 0.
        for batch_data in test_data_batch:

            x0, _, _ = batch_data
            x0 = x0.to(args.device)

            # encoding into bert features
            bert_fea = self.encoder(x0)[1]

            # (batch_size, nz)
            # BUGFIX: this line was a bare `(batch_size, nz)` expression,
            # which raised NameError (batch_size unbound); upstream Optimus
            # has it as a shape comment.
            mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

            x_batch, nz = mu.size()

            num_examples += x_batch

            # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
            neg_entropy += (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).sum().item()
            mu_batch_list += [mu.cpu()]
            logvar_batch_list += [logvar.cpu()]

        # BUGFIX: removed a live pdb.set_trace() left here (pdb was not even
        # imported, so it raised NameError at runtime).

        neg_entropy = neg_entropy / num_examples

        num_examples = 0
        log_qz = 0.
        for i in range(len(mu_batch_list)):
            ###############
            # get z_samples
            ###############
            mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()

            # [z_batch, 1, nz]
            z_samples = self.reparameterize(mu, logvar, 1)

            z_samples = z_samples.view(-1, 1, nz)
            num_examples += z_samples.size(0)

            ###############
            # compute density
            ###############
            # [1, x_batch, nz]
            indices = np.arange(len(mu_batch_list))
            mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda()
            logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda()
            x_batch, nz = mu.size()

            mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
            var = logvar.exp()

            # (z_batch, x_batch, nz)
            dev = z_samples - mu

            # (z_batch, x_batch)
            log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
                0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

            # log q(z): aggregate posterior
            # [z_batch]
            log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)

        log_qz /= num_examples
        mi = neg_entropy - log_qz

        return mi

    def calc_au(self, eval_dataloader, args, delta=0.01):
        """compute the number of active units (latent dims whose posterior
        mean varies by more than delta across the data)."""
        cnt = 0
        for batch_data in eval_dataloader:

            x0, _, _ = batch_data
            x0 = x0.to(args.device)

            # encoding into bert features
            bert_fea = self.encoder(x0)[1]

            # (batch_size, nz)
            mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

            if cnt == 0:
                means_sum = mean.sum(dim=0, keepdim=True)
            else:
                means_sum = means_sum + mean.sum(dim=0, keepdim=True)
            cnt += mean.size(0)

        # (1, nz)
        mean_mean = means_sum / cnt

        cnt = 0
        for batch_data in eval_dataloader:

            x0, _, _ = batch_data
            x0 = x0.to(args.device)

            # encoding into bert features
            bert_fea = self.encoder(x0)[1]

            # (batch_size, nz)
            mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)

            if cnt == 0:
                var_sum = ((mean - mean_mean) ** 2).sum(dim=0)
            else:
                var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0)
            cnt += mean.size(0)

        # (nz)
        au_var = var_sum / (cnt - 1)

        return (au_var >= delta).sum().item(), au_var

from .optimus_models.optimus_bert import BertForLatentConnector_XX

@register('optimus_bert_connector', version)
class optimus_bert_connector(BertForLatentConnector_XX):
    pass

from .optimus_models.tokenization_bert import BertTokenizer

@register('optimus_bert_tokenizer', version)
class optimus_bert_tokenizer(BertTokenizer):
    pass

from .optimus_models.optimus_gpt2 import GPT2ForLatentConnector_XX

@register('optimus_gpt2_connector', version)
class optimus_gpt2_connector(GPT2ForLatentConnector_XX):
    pass

from .optimus_models.tokenization_gpt2 import GPT2Tokenizer

@register('optimus_gpt2_tokenizer', version)
class optimus_gpt2_tokenizer(GPT2Tokenizer):
    pass

##############################
# some helpers for inference #
##############################
inputs = {'input_ids': generated, 'past': past} + outputs = model(**inputs) + next_token_logits = outputs[0][0, -1, :] / temperature + filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) + if next_token[0].item() == eos_token: + break + if generated.shape[1] >= max_length: + generated[0, -1] = eos_token + break + return generated.squeeze(0) + +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_k > 0: keep only top k tokens with highest probability (top-k filtering). + top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = 
sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + return logits \ No newline at end of file diff --git a/versatile_diffusion/lib/model_zoo/sd.py b/versatile_diffusion/lib/model_zoo/sd.py new file mode 100644 index 0000000000000000000000000000000000000000..096d531036b97ca5540dbd166db654af09776f8d --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/sd.py @@ -0,0 +1,706 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import numpy.random as npr +import copy +from functools import partial +from contextlib import contextmanager +from lib.model_zoo.common.get_model import get_model, register +from lib.log_service import print_log + +version = '0' +symbol = 'sd' + +from .diffusion_utils import \ + count_params, extract_into_tensor, make_beta_schedule +from .distributions import normal_kl, DiagonalGaussianDistribution +from .ema import LitEma + +def highlight_print(info): + print_log('') + print_log(''.join(['#']*(len(info)+4))) + print_log('# '+info+' #') + print_log(''.join(['#']*(len(info)+4))) + print_log('') + +class DDPM(nn.Module): + def __init__(self, + unet_config, + timesteps=1000, + use_ema=True, + + beta_schedule="linear", + beta_linear_start=1e-4, + beta_linear_end=2e-2, + loss_type="l2", + + clip_denoised=True, + cosine_s=8e-3, + given_betas=None, + + l_simple_weight=1., + original_elbo_weight=0., + + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + parameterization="eps", + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0, ): + + super().__init__() + assert parameterization in ["eps", "x0"], \ + 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + highlight_print("Running in {} mode".format(self.parameterization)) + + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.use_positional_encodings = use_positional_encodings + + from collections import 
OrderedDict + self.model = nn.Sequential(OrderedDict([('diffusion_model', get_model()(unet_config))])) + # TODO: Remove this ugly trick to match SD with deprecated version, after no bug with the module. + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print_log(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.v_posterior = v_posterior + self.l_simple_weight = l_simple_weight + self.original_elbo_weight = original_elbo_weight + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=beta_linear_start, + linear_end=beta_linear_end, + cosine_s=cosine_s) + + self.loss_type = loss_type + self.learn_logvar = learn_logvar + self.logvar = torch.full( + fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, \ + 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print_log(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print_log(f"{context}: Restored training weights") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + value1 = extract_into_tensor( + self.sqrt_recip_alphas_cumprod, t, x_t.shape) + value2 = extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return value1*x_t -value2*noise + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = torch.randn_like(x_start) if noise is None else noise + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = 
torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + +@register('sd_t2i', version) +class SD_T2I(DDPM): + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_trainable=False, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs): + self.num_timesteps_cond = num_timesteps_cond \ + if num_timesteps_cond is not None else 1 + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + + super().__init__(*args, **kwargs) + + self.first_stage_model 
= get_model()(first_stage_config) + self.cond_stage_model = get_model()(cond_stage_config) + + self.concat_mode = 'crossattn' + self.cond_stage_trainable = cond_stage_trainable + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.device = 'cpu' + + def to(self, device): + self.device = device + super().to(device) + + @torch.no_grad() + def on_train_batch_start(self, x): + # only for very first batch + if self.scale_by_std: + assert self.scale_factor == 1., \ + 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + highlight_print("setting self.scale_factor to {}".format(self.scale_factor)) + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @torch.no_grad() + def encode_image(self, im): + encoder_posterior = self.first_stage_model.encode(im) + z = self.get_first_stage_encoding(encoder_posterior).detach() + return z + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z 
= encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + @torch.no_grad() + def decode_image(self, z, predict_cids=False, force_not_quantize=False): + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_text(self, text): + return self.get_learned_conditioning(text) + + def get_learned_conditioning(self, c): + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + return c + + def forward(self, x, c, noise=None): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long() + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + return self.p_losses(x, c, t, noise) + + def apply_model(self, x_noisy, t, cond): + return self.model.diffusion_model(x_noisy, t, cond) + + def p_losses(self, x_start, cond, t, noise=None): + noise = torch.randn_like(x_start) if noise is None else noise + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict['loss_simple'] = loss_simple.mean() + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + + if self.learn_logvar: + loss_dict['loss_gamma'] = loss.mean() + loss_dict['logvar' ] = self.logvar.data.mean() + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = 
(self.lvlb_weights[t] * loss_vlb).mean() + loss_dict['loss_vlb'] = loss_vlb + + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({'Loss': loss}) + + return loss, loss_dict + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., 
noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. 
- mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. 
- mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + +@register('sd_variation', version) +class SD_Variation(SD_T2I): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def is_part_of_trans(name): + if name.find('.1.norm')!=-1: + return True + if name.find('.1.proj_in')!=-1: + return True + if name.find('.1.transformer_blocks')!=-1: + return True + if name.find('.1.proj_out')!=-1: + return True + return False + + self.parameter_group = { + 'transformers' : [v for n, v in self.model.named_parameters() if is_part_of_trans(n)], + 'other' :[v for n, v in self.model.named_parameters() if not is_part_of_trans(n)], + } + + self.encode_image = None + self.encode_text = None + self._predict_eps_from_xstart = None + self._prior_bpd = None + self.p_mean_variance = None + self.p_sample = None + self.progressive_denoising = None + self.p_sample_loop = None + self.sample = None + + @torch.no_grad() + def encode_input(self, im): + encoder_posterior = self.first_stage_model.encode(im) + if 
isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError("Encoder_posterior of type '{}' not yet implemented".format(type(encoder_posterior))) + return z * self.scale_factor + + @torch.no_grad() + def decode_latent(self, z): + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def clip_encode_vision(self, vision): + if isinstance(vision, list): + if not isinstance(vision[0], torch.Tensor): + import torchvision.transforms as tvtrans + vision = [tvtrans.ToTensor()(i) for i in vision] + vh = torch.stack(vision) + elif isinstance(vision, torch.Tensor): + vh = vision.unsqueeze(0) if (vision.shape==3) else vision + assert len(vh.shape) == 4 + else: + raise ValueError + vh = vh.to(self.device) + return self.encode_conditioning(vh) + + def encode_conditioning(self, c): + return self.cond_stage_model.encode(c) + + def forward(self, x, c, noise=None): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long() + if self.cond_stage_trainable: + c = self.encode_conditioning(c) + return self.p_losses(x, c, t, noise) diff --git a/versatile_diffusion/lib/model_zoo/vd.py b/versatile_diffusion/lib/model_zoo/vd.py new file mode 100644 index 0000000000000000000000000000000000000000..00d29523d0dd341f132d7b079ae0a948554bfe21 --- /dev/null +++ b/versatile_diffusion/lib/model_zoo/vd.py @@ -0,0 +1,442 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import numpy.random as npr +import copy +from functools import partial +from contextlib import contextmanager +from lib.model_zoo.common.get_model import get_model, register +from lib.log_service import print_log + +version = '0' +symbol = 'vd' + +from .diffusion_utils import \ + count_params, extract_into_tensor, make_beta_schedule +from .distributions import normal_kl, 
DiagonalGaussianDistribution + +from .autoencoder import AutoencoderKL +from .ema import LitEma + +from .sd import highlight_print, DDPM, SD_T2I + +@register('vd_basic', version) +class VD_Basic(SD_T2I): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def is_part_of_crossattn(name): + if name.find('.1.norm')!=-1: + return True + if name.find('.1.proj_in')!=-1: + return True + if name.find('.1.transformer_blocks')!=-1: + return True + if name.find('.1.proj_out')!=-1: + return True + return False + + self.parameter_group = { + 'context' :[v for n, v in self.model.named_parameters() if is_part_of_crossattn(n)], + 'data' :[v for n, v in self.model.named_parameters() if not is_part_of_crossattn(n)], + } + + self.encode_image = None + self.encode_text = None + self._predict_eps_from_xstart = None + self._prior_bpd = None + self.p_mean_variance = None + self.p_sample = None + self.progressive_denoising = None + self.p_sample_loop = None + self.sample = None + + @torch.no_grad() + def encode_input(self, im): + encoder_posterior = self.first_stage_model.encode(im) + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError("Encoder_posterior of type '{}' not yet implemented".format(type(encoder_posterior))) + return z * self.scale_factor + + @torch.no_grad() + def decode_latent(self, z): + z = 1. 
/ self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def clip_encode_vision(self, vision, encode_type='encode_vision'): + clip_encode_type = self.cond_stage_model.encode_type + self.cond_stage_model.encode_type = encode_type + if isinstance(vision, torch.Tensor): + vision = ((vision+1)/2).to('cpu').numpy() + vision = np.transpose(vision, (0, 2, 3, 1)) + vision = [vi for vi in vision] + + embedding = self.encode_conditioning(vision) + self.cond_stage_model.encode_type = clip_encode_type + return embedding + + def encode_conditioning(self, c): + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + return c + + # legacy + def get_learned_conditioning(self, c): + return self.encode_conditioning(c) + + def forward(self, x, c, noise=None): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long() + if self.cond_stage_trainable: + c = self.encode_conditioning(c) + return self.p_losses(x, c, t, noise) + +@register('vd_dc', version) +class VD_DualContext(SD_T2I): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def is_part_of_trans(name): + if name.find('.1.norm')!=-1: + return True + if name.find('.1.proj_in')!=-1: + return True + if name.find('.1.transformer_blocks')!=-1: + return True + if name.find('.1.proj_out')!=-1: + return True + return False + + self.parameter_group = { + 'transformers' : [v for n, v in self.model.named_parameters() if is_part_of_trans(n)], + 'other' :[v for n, v in self.model.named_parameters() if not is_part_of_trans(n)], + } + + def apply_model(self, x_noisy, t, cond, cond_type): + if cond_type in ['prompt', 'text']: + which_attn = 0 + elif cond_type in ['vision', 'visual', 'image']: + which_attn = 1 + elif isinstance(cond_type, float): + assert 0 < cond_type < 1, \ + 'A special 
cond_type that will doing a random mix between two input condition, '\ + 'rand() < cond_type is text, else visual' + which_attn = cond_type + else: + assert False + return self.model.diffusion_model(x_noisy, t, cond, which_attn=which_attn) + + def p_losses(self, x_start, cond, t, noise=None, cond_type=None): + noise = torch.randn_like(x_start) if noise is None else noise + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond, cond_type=cond_type) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict['loss_simple'] = loss_simple.mean() + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + + if self.learn_logvar: + loss_dict['loss_gamma'] = loss.mean() + loss_dict['logvar' ] = self.logvar.data.mean() + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict['loss_vlb'] = loss_vlb + + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({'Loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def clip_encode_text(self, text): + clip_encode_type = self.cond_stage_model.encode_type + self.cond_stage_model.encode_type = 'encode_text' + embedding = self.get_learned_conditioning(text) + self.cond_stage_model.encode_type = clip_encode_type + return embedding + + @torch.no_grad() + def clip_encode_vision(self, vision, encode_type='encode_vision'): + clip_encode_type = self.cond_stage_model.encode_type + self.cond_stage_model.encode_type = encode_type + if isinstance(vision, torch.Tensor): + vision = ((vision+1)/2).to('cpu').numpy() + vision = np.transpose(vision, 
(0, 2, 3, 1)) + vision = [vi for vi in vision] + embedding = self.get_learned_conditioning(vision) + self.cond_stage_model.encode_type = clip_encode_type + return embedding + + def get_learned_conditioning(self, c): + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + return c + + def forward(self, x, c, noise=None, cond_type=None): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long() + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + return self.p_losses(x, c, t, noise, cond_type=cond_type) + +@register('vd', version) +class VD(DDPM): + def __init__(self, + autokl_cfg, + optimus_cfg, + clip_cfg, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs): + self.scale_by_std = scale_by_std + super().__init__(*args, **kwargs) + + self.autokl = get_model()(autokl_cfg) + self.optimus = get_model()(optimus_cfg) + self.clip = get_model()(clip_cfg) + + self.concat_mode = 'crossattn' + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.device = 'cpu' + self.parameter_group = self.create_parameter_group() + + def create_parameter_group(self): + def is_part_of_unet_image(name): + if name.find('.unet_image.')!=-1: + return True + return False + def is_part_of_unet_text(name): + if name.find('.unet_text.')!=-1: + return True + return False + def is_part_of_trans(name): + if name.find('.1.norm')!=-1: + return True + if name.find('.1.proj_in')!=-1: + return True + if name.find('.1.transformer_blocks')!=-1: + return True + if name.find('.1.proj_out')!=-1: + return True + return False + parameter_group = { + 'image_trans' : [], + 'image_rest' : [], + 'text_trans' : [], + 'text_rest' : [], + 'rest' : [],} + for pname, para in self.model.named_parameters(): + if 
is_part_of_unet_image(pname): + if is_part_of_trans(pname): + parameter_group['image_trans'].append(para) + else: + parameter_group['image_rest'].append(para) + elif is_part_of_unet_text(pname): + if is_part_of_trans(pname): + parameter_group['text_trans'].append(para) + else: + parameter_group['text_rest'].append(para) + else: + parameter_group['rest'].append(para) + + return parameter_group + + def to(self, device): + self.device = device + super().to(device) + + @torch.no_grad() + def on_train_batch_start(self, x): + # only for very first batch + if self.scale_by_std: + assert self.scale_factor == 1., \ + 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + highlight_print("setting self.scale_factor to {}".format(self.scale_factor)) + + @torch.no_grad() + def autokl_encode(self, image): + encoder_posterior = self.autokl.encode(image) + z = encoder_posterior.sample() + return self.scale_factor * z + + @torch.no_grad() + def autokl_decode(self, z): + z = 1. 
/ self.scale_factor * z + return self.autokl.decode(z) + + def mask_tokens(inputs, tokenizer, args): + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + + masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8) + labels[masked_indices==1] = -1 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices + inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced + indices_random = indices_random + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + @torch.no_grad() + def optimus_encode(self, text): + tokenizer = self.optimus.tokenizer_encoder + token = [tokenizer.tokenize(sentence.lower()) for sentence in text] + token_id = [] + for tokeni in token: + token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni] + token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence) + token_id.append(torch.LongTensor(token_sentence)) + token_id = torch._C._nn.pad_sequence(token_id, batch_first=True, padding_value=0.0) + token_id = token_id.to(self.device) + z = self.optimus.encoder(token_id, attention_mask=(token_id > 0).float())[1] + z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1) + # z_sampled = self.optimus.reparameterize(z_mu, z_logvar, 1) + return z_mu.squeeze(1) + + @torch.no_grad() + def optimus_decode(self, 
z, temperature=1.0): + bos_token = self.optimus.tokenizer_decoder.encode('') + eos_token = self.optimus.tokenizer_decoder.encode('') + context_tokens = torch.LongTensor(bos_token).to(z.device) + + from .optimus import sample_single_sequence_conditional + sentenses = [] + for zi in z: + out = sample_single_sequence_conditional( + model=self.optimus.decoder, + context=context_tokens, + past=zi, temperature=temperature, + top_k=0, top_p=1.0, + max_length=30, + eos_token = eos_token[0],) + text = self.optimus.tokenizer_decoder.decode(out.tolist(), clean_up_tokenization_spaces=True) + text = text.split()[1:-1] + text = ' '.join(text) + sentenses.append(text) + return sentenses + + @torch.no_grad() + def clip_encode_text(self, text, encode_type='encode_text'): + swap_type = self.clip.encode_type + self.clip.encode_type = encode_type + embedding = self.clip.encode(text) + self.clip.encode_type = swap_type + return embedding + + @torch.no_grad() + def clip_encode_vision(self, vision, encode_type='encode_vision'): + swap_type = self.clip.encode_type + self.clip.encode_type = encode_type + if isinstance(vision, torch.Tensor): + vision = ((vision+1)/2).to('cpu').numpy() + vision = np.transpose(vision, (0, 2, 3, 1)) + vision = [vi for vi in vision] + embedding = self.clip.encode(vision) + self.clip.encode_type = swap_type + return embedding + + def forward(self, x, c, noise=None, xtype='image', ctype='prompt'): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long() + return self.p_losses(x, c, t, noise, xtype, ctype) + + def apply_model(self, x_noisy, t, cond, xtype='image', ctype='prompt'): + return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype) + + def get_image_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, 
pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + return loss + + def get_text_loss(self, pred, target): + if self.loss_type == 'l1': + loss = (target - pred).abs() + elif self.loss_type == 'l2': + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + return loss + + def p_losses(self, x_start, cond, t, noise=None, xtype='image', ctype='prompt'): + noise = torch.randn_like(x_start) if noise is None else noise + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond, xtype, ctype) + + loss_dict = {} + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + if xtype == 'image': + loss_simple = self.get_image_loss(model_output, target, mean=False).mean([1, 2, 3]) + elif xtype == 'text': + loss_simple = self.get_text_loss(model_output, target).mean([1]) + + logvar_t = self.logvar[t].to(self.device) + if logvar_t.sum().item() != 0: + assert False, "Default SD training has logvar fixed at 0" + if self.learn_logvar: + assert False, "Default SD training don't learn logvar" + if self.l_simple_weight != 1: + assert False, "Default SD training always set l_simple_weight==1" + + loss = loss_simple.mean() + loss_dict['loss_simple'] = loss_simple.mean().item() + loss_dict['Loss'] = loss.item() + return loss, loss_dict + + def apply_model_dc(self, x_noisy, t, first_c, second_c, xtype='image', first_ctype='vision', second_ctype='prompt', mixed_ratio=0.5): + return self.model.diffusion_model.forward_dc(x_noisy, t, first_c, second_c, xtype, first_ctype, second_ctype, mixed_ratio) \ No newline at end of file