"""LAION-2B dataset loaders built on top of webdataset tar shards."""

import os
import os.path as osp

import numpy as np
import torch
import torchvision.transforms as tvtrans
import webdataset as wds

import PIL.Image
# LAION images can be very large; disable PIL's decompression-bomb limit.
PIL.Image.MAX_IMAGE_PIXELS = None

from .common import *
from .. import sync
from ..log_service import print_log


@regdataset()
class laion2b_dummy(ds_base):
    def init_load_info(self):
        self.load_info = []


@regdataset()
class laion2b_webdataset(ds_base):
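    """Streams LAION-2B image/text pairs from local webdataset tar shards.

    Expects cfg.root_dir to contain a data/ directory of .tar shards in
    which each sample provides jpg and txt entries (plus a json entry
    when size/watermark filtering is enabled).
    """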
    def init_load_info(self):
        self.load_info = []

    def make_loader(self, batch_size, num_workers, train=True):
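        """Build a filtered, batched webdataset pipeline wrapped in a WebLoader.

        Batching happens inside the pipeline itself, so the WebLoader is
        created with batch_size=None.
        """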
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        # Resize the short side to cfg.scale, then crop a square patch
        # (random crop for training, center crop for evaluation).
        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(), ]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(), ]

        trans = tvtrans.Compose(trans)

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess

        shuffle = cfg.get('shuffle', 10000)
        shardshuffle = shuffle > 0
        # On a single node, split shards across ranks; on multi-node runs,
        # single_node_only makes webdataset fail fast and ask for an explicit
        # nodesplitter instead of silently duplicating shards.
        node_world_size = sync.get_world_size('node')
        nodesplitter = wds.shardlists.split_by_node \
            if node_world_size == 1 else wds.shardlists.single_node_only

        tars = [
            osp.join(self.root_dir, 'data', i)
            for i in os.listdir(osp.join(self.root_dir, 'data'))
            if osp.splitext(i)[1] == '.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print_log(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        # Drop incomplete samples before decoding, then filter on the
        # metadata stored in each sample's json entry.
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # batched() returns a new pipeline, so the result must be reassigned.
        dset = dset.batched(batch_size, partial=False)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader

    def filter_size(self, x):
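        """Keep samples whose json metadata passes the size and watermark
        thresholds; samples with missing or unreadable metadata are
        rejected whenever the corresponding threshold is active."""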
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and \
                        x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
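        """Keep only samples that carry both an image and a caption."""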
        try:
            return ('jpg' in x) and ('txt' in x)
        except Exception:
            return False

    def train_dataloader(self):
        # self.train / self.validation / self.test are assumed to be config
        # sections exposing batch_size and num_workers.
        return self.make_loader(
            self.train.batch_size, self.train.num_workers)

    def val_dataloader(self):
        return self.make_loader(
            self.validation.batch_size, self.validation.num_workers, train=False)

    def test_dataloader(self):
        return self.make_loader(
            self.test.batch_size, self.test.num_workers, train=False)


def customized_postprocess(element):
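    """Map a decoded sample to an (image, caption, key) tuple, rescaling
    the image tensor from [0, 1] to [-1, 1]."""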
    return element['jpg'] * 2 - 1, element['txt'], element['__key__']


def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
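    """Collate a list of sample dicts into a single dict of batched values.

    Only keys present in every sample survive. For example,
    [{'jpg': t0, 'txt': 'a'}, {'jpg': t1, 'txt': 'b'}] becomes
    {'jpg': torch.stack([t0, t1]), 'txt': ['a', 'b']}.
    """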
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [] for key in keys}

    for s in samples:
        for key in batched:
            batched[key].append(s[key])

    result = {}
    for key in batched:
        first = batched[key][0]
        if isinstance(first, (int, float)) and combine_scalars:
            result[key] = np.array(batched[key])
        elif isinstance(first, torch.Tensor) and combine_tensors:
            result[key] = torch.stack(batched[key])
        elif isinstance(first, np.ndarray) and combine_tensors:
            result[key] = np.array(batched[key])
        else:
            # Values that are not combined stay as plain lists (e.g. captions).
            result[key] = list(batched[key])
    return result


def customized_postprocess_sdofficial(element):
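    """Like customized_postprocess, but returns a dict and drops the key."""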
    return {
        'jpg': element['jpg'] * 2 - 1,
        'txt': element['txt'], }


@regdataset()
class laion2b_webdataset_sdofficial(laion2b_webdataset):
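    """Variant yielding dict batches via dict_collation_fn, with the shuffle
    buffer and shard splitting hardcoded rather than read from cfg."""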
    def make_loader(self, batch_size, num_workers, train=True):
        cfg = self.cfg
        self.root_dir = cfg.root_dir

        interpolation_mode = tvtrans.InterpolationMode.BICUBIC
        if train:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.RandomCrop(cfg.scale),
                tvtrans.ToTensor(), ]
        else:
            trans = [
                tvtrans.Resize(cfg.scale, interpolation=interpolation_mode),
                tvtrans.CenterCrop(cfg.scale),
                tvtrans.ToTensor(), ]

        trans = tvtrans.Compose(trans)

        trans_dict = {'jpg': trans}
        postprocess = customized_postprocess_sdofficial

        # Unlike the parent class, the shuffle buffer is fixed and shards
        # are always split across ranks.
        shuffle = 10000
        shardshuffle = shuffle > 0
        nodesplitter = wds.shardlists.split_by_node

        tars = [
            osp.join(self.root_dir, 'data', i)
            for i in os.listdir(osp.join(self.root_dir, 'data'))
            if osp.splitext(i)[1] == '.tar']
        tars = sorted(tars)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)

        print_log(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
        self.min_size = cfg.get('min_size', None)
        self.max_pwatermark = cfg.get('max_pwatermark', None)
        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**trans_dict, handler=wds.warn_and_continue))

        if postprocess is not None:
            dset = dset.map(postprocess)

        # As above, batched() returns a new pipeline and must be reassigned.
        dset = dset.batched(
            batch_size, partial=False, collation_fn=dict_collation_fn)

        loader = wds.WebLoader(
            dset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers, )
        return loader