import concurrent.futures
import glob
import itertools
import json
import math
import os
import pickle
import random
import re
import time
from functools import partial
from multiprocessing import cpu_count
from os import path as osp
from typing import Dict, List, Sequence, Tuple, Union

import cv2
import numpy as np
import pandas as pd
import torch
import torch.distributed as tdist
import torchvision.transforms as transforms
import tqdm
from PIL import Image as PImage
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision.transforms.functional import to_tensor

from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, get_h_div_w_template2indices, h_div_w_templates
from infinity.utils.large_file_util import get_part_jsonls, split_large_txt_files
from utils.degradation import (
    random_add_gaussian_noise,
    random_add_jpg_compression,
    random_mixed_kernels,
)
from utils.image import augment, center_crop_arr, random_crop_arr


def center_crop_to_tensor_pm1(pil_image, mid_reso: int, final_reso: int):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    Then to_tensor and normalize to [-1, 1].
    """
    while min(*pil_image.size) >= 2 * mid_reso:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=PImage.BOX
        )
    if mid_reso == final_reso == pil_image.size[0] == pil_image.size[1]:
        im = to_tensor(pil_image)
    else:
        # resize the shorter edge to mid_reso
        scale = mid_reso / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=PImage.LANCZOS
        )
        # crop the center out
        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - final_reso) // 2
        crop_x = (arr.shape[1] - final_reso) // 2
        im = to_tensor(arr[crop_y: crop_y + final_reso, crop_x: crop_x + final_reso])
    return im.add(im).add_(-1)  # [0, 1] -> [-1, 1]


def transform(pil_img, tgt_h, tgt_w):
    """Resize so the image covers (tgt_h, tgt_w), center-crop, and normalize to [-1, 1]."""
    width, height = pil_img.size
    if width / height <= tgt_w / tgt_h:
        resized_width = tgt_w
        resized_height = int(tgt_w / (width / height))
    else:
        resized_height = tgt_h
        resized_width = int((width / height) * tgt_h)
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
    # crop the center out
    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
    return im.add(im).add_(-1)  # [0, 1] -> [-1, 1]


def process_short_text(short_text):
    """Drop everything after a '--' marker, keeping the text before it when non-empty."""
    if '--' in short_text:
        processed_text = short_text.split('--')[0]
        if processed_text:
            short_text = processed_text
    return short_text
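
# Illustrative sketch (not used by the training code): how `transform` above is expected to behave.
# The dummy image size and color are arbitrary; only functions defined in this module are assumed.
def _example_transform_usage():
    dummy = PImage.new('RGB', (640, 480), color=(128, 64, 32))  # hypothetical input image
    im = transform(dummy, tgt_h=256, tgt_w=256)
    # `transform` resizes so the image covers the target, center-crops, and maps [0, 1] -> [-1, 1]
    assert tuple(im.shape) == (3, 256, 256)
    assert im.min() >= -1.0 and im.max() <= 1.0
    return im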


class T2IIterableDataset(IterableDataset):
    """Streams (image, caption) training pairs from aspect-ratio-bucketed jsonl meta files."""

    def __init__(
        self,
        meta_folder: str,
        max_caption_len=512,
        short_prob=0.2,
        load_vae_instead_of_image=False,
        buffersize: int = 10000,
        seed: int = 0,
        pn: str = '',
        online_t5: bool = True,
        batch_size: int = 2,
        num_replicas: int = 1,
        rank: int = 0,
        dataloader_workers: int = 2,
        dynamic_resolution_across_gpus: bool = True,
        enable_dynamic_length_prompt: bool = True,
        **kwargs,
    ):
        self.meta_folder = meta_folder
        self.pn = pn
        self.online_t5 = online_t5
        self.buffer_size = buffersize
        self.num_replicas = num_replicas
        self.rank = rank
        self.worker_id = 0
        self.global_worker_id = 0
        self.dataloader_workers = max(1, dataloader_workers)
        self.max_caption_len = max_caption_len
        self.short_prob = short_prob
        self.load_vae_instead_of_image = load_vae_instead_of_image  # set to False in practice
        self.dynamic_resolution_across_gpus = dynamic_resolution_across_gpus
        self.enable_dynamic_length_prompt = enable_dynamic_length_prompt
        self.batch_size = batch_size
        print(f'self.dynamic_resolution_across_gpus: {self.dynamic_resolution_across_gpus}')
        print(f'self.enable_dynamic_length_prompt: {self.enable_dynamic_length_prompt}')
        print(f'self.buffer_size: {self.buffer_size}')
        self.shuffle = True
        self.global_workers = self.num_replicas * self.dataloader_workers
        self.h_div_w_template2generator, self.samples_div_gpus_workers_batchsize_2batches, total_samples = self.set_h_div_w_template2generator()
        self.split_meta_files()
        self.seed = seed
        self.epoch_worker_generator = None
        self.epoch_global_worker_generator = None
        self.set_epoch(0)
        print(f'num_replicas: {num_replicas}, rank: {rank}, dataloader_workers: {dataloader_workers}, seed: {seed}, samples_div_gpus_workers_batchsize_2batches: {self.samples_div_gpus_workers_batchsize_2batches}')
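    # Meta files are expected to be named '{h_div_w_template}_{num_of_samples}.jsonl'; for example,
    # a (hypothetical) '1.000_1200000.jsonl' would hold 1.2M samples in the 1.000 aspect-ratio
    # bucket. The parsing below relies only on that '{template}_{count}' stem.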
    def set_h_div_w_template2generator(self):
        samples_div_gpus_workers_batchsize_2batches = 0
        h_div_w_template2generator = {}
        total_samples = 0
        # first pass: count the total number of samples across all meta files
        for filepath in sorted(glob.glob(osp.join(self.meta_folder, '*.jsonl'))):
            filename = osp.basename(filepath)
            h_div_w_template, num_of_samples = osp.splitext(filename)[0].split('_')
            total_samples += int(num_of_samples)
        # second pass: keep only templates with at least one sample per global worker
        for filepath in sorted(glob.glob(osp.join(self.meta_folder, '*.jsonl'))):
            filename = osp.basename(filepath)
            h_div_w_template, num_of_samples = osp.splitext(filename)[0].split('_')
            num_of_samples = int(num_of_samples)
            if num_of_samples < self.global_workers:
                print(f'{filepath} has too few examples ({num_of_samples}, proportion: {num_of_samples/total_samples*100:.1f}%), < global workers ({self.global_workers})! Skip h_div_w_template: {h_div_w_template}')
                continue
            print(f'{filepath} has sufficient examples ({num_of_samples}), proportion: {num_of_samples/total_samples*100:.1f}%, >= global workers ({self.global_workers})! Preserve h_div_w_template: {h_div_w_template}')
            num_of_batches = max(1, num_of_samples // self.global_workers // self.batch_size)
            h_div_w_template2generator[h_div_w_template] = {
                'filepath': filepath,
                'num_of_samples': num_of_samples,
                'num_of_batches': num_of_batches,
            }
            samples_div_gpus_workers_batchsize_2batches += num_of_batches
        return h_div_w_template2generator, samples_div_gpus_workers_batchsize_2batches, total_samples

    def split_meta_files(self):
        print('[data preprocess] split_meta_files')

        def split_and_sleep(generator_info):
            missing, chunk_id2save_files = get_part_jsonls(generator_info['filepath'], generator_info['num_of_samples'], parts=self.num_replicas)
            if missing:
                tdist.barrier()
                if self.rank == 0:
                    split_large_txt_files(generator_info['filepath'], chunk_id2save_files)
                else:
                    sleep_time = int(generator_info['num_of_samples'] / 30000000 * 10)
                    print(f'[data preprocess] sleep {sleep_time} minutes awaiting rank0 split_meta_files...')
                    time.sleep(sleep_time * 60)
                tdist.barrier()
            generator_info['part_filepaths'] = sorted(list(chunk_id2save_files.values()))
            return generator_info

        with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count()) as executor:
            futures = {
                executor.submit(split_and_sleep, generator_info): h_div_w_template
                for h_div_w_template, generator_info in self.h_div_w_template2generator.items()
            }
            for future in concurrent.futures.as_completed(futures):
                h_div_w_template = futures[future]
                try:
                    self.h_div_w_template2generator[h_div_w_template] = future.result()
                except Exception as exc:
                    print(f'[data preprocess] h_div_w_template {h_div_w_template} generated an exception: {exc}')
        print('[data preprocess] split_meta_files done')

    def set_global_worker_id(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        else:
            worker_id = 0
            worker_total_num = 1
        assert worker_total_num == self.dataloader_workers, f'{worker_total_num} != {self.dataloader_workers}'
        self.worker_id = worker_id
        self.global_worker_id = self.rank * self.dataloader_workers + worker_id

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.set_generator()

    def set_generator(self):
        self.epoch_worker_generator = np.random.default_rng(self.seed + self.epoch + self.worker_id)
        self.epoch_global_worker_generator = np.random.default_rng(self.seed + self.epoch + self.global_worker_id)

    def get_h_div_w_template_2_unlearned_batches(self):
        h_div_w_template_2_unlearned_batches = {}
        total_unlearned_batches = 0
        for h_div_w_template, generator_info in self.h_div_w_template2generator.items():
            h_div_w_template_2_unlearned_batches[h_div_w_template] = generator_info['num_of_batches']
            total_unlearned_batches += generator_info['num_of_batches']
        self.total_unlearned_batches = total_unlearned_batches
        self.h_div_w_template_2_unlearned_batches = h_div_w_template_2_unlearned_batches
        assert self.total_unlearned_batches == self.samples_div_gpus_workers_batchsize_2batches

    def _next_h_div_w_template(self):
        while True:
            self.get_h_div_w_template_2_unlearned_batches()
            while self.total_unlearned_batches > 0:
                if self.dynamic_resolution_across_gpus:
                    i = self.epoch_global_worker_generator.integers(0, self.total_unlearned_batches)
                else:
                    i = self.epoch_worker_generator.integers(0, self.total_unlearned_batches)
                self.total_unlearned_batches -= 1
                for h_div_w_template, unlearned_batches in self.h_div_w_template_2_unlearned_batches.items():
                    if i < unlearned_batches:
                        yield h_div_w_template
                        self.h_div_w_template_2_unlearned_batches[h_div_w_template] -= 1
                        break
                    else:
                        i -= unlearned_batches
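    # __iter__ combines two sources of randomness: a fixed-size in-memory shuffle buffer per
    # aspect-ratio template (an item is sampled at a random slot and the slot is refilled from the
    # jsonl stream), and _next_h_div_w_template above, which picks the template of each batch with
    # probability proportional to its remaining number of batches.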
    def __iter__(self):
        self.set_global_worker_id()
        self.set_generator()
        # fill a shuffle buffer per aspect-ratio template, sized proportionally to its share of batches
        for h_div_w_template, generator_info in self.h_div_w_template2generator.items():
            proportion = generator_info['num_of_batches'] / self.samples_div_gpus_workers_batchsize_2batches
            h_div_w_buffer_size = int(self.buffer_size * proportion)
            h_div_w_buffer_size = min(max(1, h_div_w_buffer_size), generator_info['num_of_batches'] * self.batch_size)
            if 'mem_buffer' in generator_info:
                del generator_info['mem_buffer']
            mem_buffer = []
            for _ in range(h_div_w_buffer_size):
                mem_buffer.append(self.infinite_next(generator_info))
            generator_info['mem_buffer'] = mem_buffer

        next_h_div_w_template_iter = self._next_h_div_w_template()
        for _ in range(self.samples_div_gpus_workers_batchsize_2batches):
            batch_data = []
            h_div_w_template = next(next_h_div_w_template_iter)
            while len(batch_data) < self.batch_size:
                try:
                    generator_info = self.h_div_w_template2generator[h_div_w_template]
                    mem_buffer = generator_info['mem_buffer']
                    i = self.epoch_global_worker_generator.integers(0, len(mem_buffer))
                    data_item = mem_buffer[i]
                    mem_buffer[i] = self.infinite_next(generator_info)
                    ret, model_input = self.prepare_model_input(json.loads(data_item))
                    if ret:
                        c_, h_, w_ = model_input[1].shape[-3:]
                        if c_ != 3 or np.abs(h_ / w_ - float(h_div_w_template)) > 0.01:
                            print(f'Corrupt data item: {data_item}')
                        else:
                            batch_data.append(model_input)
                    del data_item
                except Exception as e:
                    print(e)
            captions = [item[0] for item in batch_data]
            images = torch.stack([item[1] for item in batch_data])
            yield (images, captions)
            del batch_data
            del images
            del captions

    def infinite_next(self, generator_info):
        try:
            if 'sub_iterator' not in generator_info:
                raise StopIteration
            return next(generator_info['sub_iterator'])
        except StopIteration:
            # (re)open this rank's part file and interleave its lines across dataloader workers
            if 'record_iterator' in generator_info:
                generator_info['record_iterator'].close()
            if 'sub_iterator' in generator_info:
                del generator_info['sub_iterator']
            part_filepath = generator_info['part_filepaths'][self.rank]
            generator_info['record_iterator'] = open(part_filepath, 'r')
            part_num_of_samples = int(osp.splitext(osp.basename(part_filepath))[0].split('_')[-1])
            generator_info['sub_iterator'] = itertools.islice(generator_info['record_iterator'], self.worker_id, part_num_of_samples, self.dataloader_workers)
            return next(generator_info['sub_iterator'])

    def __len__(self):
        return self.samples_div_gpus_workers_batchsize_2batches * self.dataloader_workers

    def total_samples(self):
        return self.samples_div_gpus_workers_batchsize_2batches * self.dataloader_workers * self.num_replicas * self.batch_size
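    # Caption sampling policy implemented below: with probability `short_prob` the short caption is
    # returned; otherwise, when dynamic-length prompts are enabled and the caption is not a raw
    # user prompt, a random-length sentence prefix of the long caption is drawn.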
    def get_text_input(self, long_text_input, short_text_input, long_text_type):
        random_value = self.epoch_global_worker_generator.random()
        if self.enable_dynamic_length_prompt and long_text_type != 'user_prompt':
            long_text_elems = [item for item in long_text_input.split('.') if item]
            if len(long_text_elems):
                first_sentence_words = [item for item in long_text_elems[0].split(' ') if item]
            else:
                first_sentence_words = []
            if len(first_sentence_words) >= 15:
                num_sentence4short_text = 1
            else:
                num_sentence4short_text = 2
            if not short_text_input:
                short_text_input = '.'.join(long_text_elems[:num_sentence4short_text])
            if random_value < self.short_prob:
                return short_text_input
            if len(long_text_elems) <= num_sentence4short_text:
                return long_text_input
            select_sentence_num = self.epoch_global_worker_generator.integers(num_sentence4short_text + 1, len(long_text_elems) + 1)
            return '.'.join(long_text_elems[:select_sentence_num])
        else:
            if short_text_input and random_value < self.short_prob:
                return short_text_input
            return long_text_input

    def prepare_model_input(self, data_item) -> Tuple:
        img_path, h_div_w = data_item['image_path'], data_item['h_div_w']
        short_text_input, long_text_input = data_item['text'], data_item['long_caption']
        long_text_type = data_item.get('long_caption_type', 'user_prompt')
        text_input = self.get_text_input(long_text_input, short_text_input, long_text_type)
        text_input = process_short_text(text_input)
        h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w - h_div_w_templates))]
        try:
            if self.load_vae_instead_of_image:
                img_B3HW = None
                vae_path = self.get_vae_path(img_path)
                with open(vae_path, 'rb') as f:
                    gt_ms_idx_Bl = pickle.load(f)
            else:
                gt_ms_idx_Bl = None
                with open(img_path, 'rb') as f:
                    img: PImage.Image = PImage.open(f)
                    img = img.convert('RGB')
                tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][self.pn]['pixel']
                img_B3HW = transform(img, tgt_h, tgt_w)
            if not self.online_t5:
                short_t5_path, long_t5_path = self.get_t5_path(img_path)
                if self.epoch_global_worker_generator.random() <= self.short_prob:
                    t5_path = short_t5_path
                else:
                    t5_path = long_t5_path
                t5_meta = np.load(t5_path)
                text_input = t5_meta['t5_feat'][:self.max_caption_len]  # L x C
        except Exception as e:
            print(f'input error: {e}, skip to another index')
            return False, None
        if self.load_vae_instead_of_image:
            return True, (text_input, *gt_ms_idx_Bl)
        else:
            return True, (text_input, img_B3HW)

    @staticmethod
    def collate_function(batch, online_t5: bool = False) -> None:
        pass
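
# Illustrative sketch (not executed anywhere): how an image's aspect ratio is mapped to a
# resolution bucket, mirroring prepare_model_input above. The h/w value and pn are made-up
# examples; `h_div_w_templates` and `dynamic_resolution_h_w` come from
# infinity.utils.dynamic_resolution, used exactly as in the dataset code.
def _example_resolution_lookup(h_div_w: float = 0.75, pn: str = '0.06M'):
    template = h_div_w_templates[np.argmin(np.abs(h_div_w - h_div_w_templates))]
    tgt_h, tgt_w = dynamic_resolution_h_w[template][pn]['pixel']
    return template, (tgt_h, tgt_w)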


class SRIterableDataset(T2IIterableDataset):
    """Variant of T2IIterableDataset that additionally yields a synthetically degraded low-quality image."""

    def __init__(
        self,
        *args,
        # my code: degradation configuration
        crop_type='center',
        use_hflip=True,
        blur_kernel_size=41,
        kernel_list=['iso', 'aniso'],
        kernel_prob=[0.5, 0.5],
        blur_sigma=[0.1, 12],
        downsample_range=[1, 12],
        noise_range=[0, 15],
        jpeg_range=[30, 100],
        raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # my code
        self.crop_type = crop_type
        assert self.crop_type in ["none", "center", "random"]
        self.use_hflip = use_hflip
        # degradation configurations
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range
        self.raw_scale_schedule = raw_scale_schedule
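    # Batches from this dataset are (hq_images, captions, lq_images): __iter__ below mirrors the
    # parent implementation but additionally stacks the degraded low-quality tensors produced by
    # get_img_gt_and_img_lq.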
    def __iter__(self):
        self.set_global_worker_id()
        self.set_generator()
        # fill a shuffle buffer per aspect-ratio template, as in the parent class
        for h_div_w_template, generator_info in self.h_div_w_template2generator.items():
            proportion = generator_info['num_of_batches'] / self.samples_div_gpus_workers_batchsize_2batches
            h_div_w_buffer_size = int(self.buffer_size * proportion)
            h_div_w_buffer_size = min(max(1, h_div_w_buffer_size), generator_info['num_of_batches'] * self.batch_size)
            if 'mem_buffer' in generator_info:
                del generator_info['mem_buffer']
            mem_buffer = []
            for _ in range(h_div_w_buffer_size):
                mem_buffer.append(self.infinite_next(generator_info))
            generator_info['mem_buffer'] = mem_buffer

        next_h_div_w_template_iter = self._next_h_div_w_template()
        for _ in range(self.samples_div_gpus_workers_batchsize_2batches):
            batch_data = []
            h_div_w_template = next(next_h_div_w_template_iter)
            while len(batch_data) < self.batch_size:
                try:
                    generator_info = self.h_div_w_template2generator[h_div_w_template]
                    mem_buffer = generator_info['mem_buffer']
                    i = self.epoch_global_worker_generator.integers(0, len(mem_buffer))
                    data_item = mem_buffer[i]
                    mem_buffer[i] = self.infinite_next(generator_info)
                    ret, model_input = self.prepare_model_input(json.loads(data_item))
                    if ret:
                        c_, h_, w_ = model_input[1].shape[-3:]
                        if c_ != 3 or np.abs(h_ / w_ - float(h_div_w_template)) > 0.01:
                            print(f'Corrupt data item: {data_item}')
                        else:
                            batch_data.append(model_input)
                    del data_item
                except Exception as e:
                    print(e)
            captions = [item[0] for item in batch_data]
            images = torch.stack([item[1] for item in batch_data])
            lq_images = torch.stack([item[2] for item in batch_data])
            yield (images, captions, lq_images)
            del batch_data
            del images
            del captions
            del lq_images
    def prepare_model_input(self, data_item) -> Tuple:
        img_path, h_div_w = data_item['image_path'], data_item['h_div_w']
        short_text_input, long_text_input = data_item['text'], data_item['long_caption']
        long_text_type = data_item.get('long_caption_type', 'user_prompt')
        text_input = self.get_text_input(long_text_input, short_text_input, long_text_type)
        text_input = process_short_text(text_input)
        h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w - h_div_w_templates))]
        try:
            if self.load_vae_instead_of_image:
                img_B3HW = None
                vae_path = self.get_vae_path(img_path)
                with open(vae_path, 'rb') as f:
                    gt_ms_idx_Bl = pickle.load(f)
            else:
                gt_ms_idx_Bl = None
                with open(img_path, 'rb') as f:
                    img: PImage.Image = PImage.open(f)
                    img = img.convert('RGB')
                # my code (important): build both the HQ target and its degraded LQ counterpart
                tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][self.pn]['pixel']
                img_B3HW, img_lq = self.get_img_gt_and_img_lq(img, tgt_h, tgt_w)
            if not self.online_t5:
                short_t5_path, long_t5_path = self.get_t5_path(img_path)
                if self.epoch_global_worker_generator.random() <= self.short_prob:
                    t5_path = short_t5_path
                else:
                    t5_path = long_t5_path
                t5_meta = np.load(t5_path)
                text_input = t5_meta['t5_feat'][:self.max_caption_len]  # L x C
        except Exception as e:
            print(f'input error: {e}, skip to another index')
            return False, None
        if self.load_vae_instead_of_image:
            return True, (text_input, *gt_ms_idx_Bl)
        else:
            return True, (text_input, img_B3HW, img_lq)
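    # HQ -> LQ degradation chain implemented below: random iso/aniso Gaussian blur, random
    # downsampling within `downsample_range`, additive Gaussian noise, JPEG compression, then a
    # resize back to the HQ resolution. All randomness is drawn per sample.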
    def get_img_gt_and_img_lq(self, pil_img, tgt_h, tgt_w):
        width, height = pil_img.size
        if width / height <= tgt_w / tgt_h:
            resized_width = tgt_w
            resized_height = int(tgt_w / (width / height))
        else:
            resized_height = tgt_h
            resized_width = int((width / height) * tgt_h)
        pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
        # crop the center out
        arr = np.array(pil_img)
        crop_y = (arr.shape[0] - tgt_h) // 2
        crop_x = (arr.shape[1] - tgt_w) // 2
        # my code
        im = arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w]
        # RGB (PIL) -> BGR float in [0, 1] for the cv2-based degradation helpers
        img_gt = (im[..., ::-1] / 255.0).astype(np.float32)
        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.use_hflip, rotation=False, return_status=False)
        h, w, _ = img_gt.shape

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None,
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
        # resize back to the original (HQ) size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # BGR -> RGB, [0, 1] -> [-1, 1], HWC float -> CHW tensor
        target = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        target = to_tensor(target)
        source = (img_lq[..., ::-1] * 2 - 1).astype(np.float32)
        source = to_tensor(source)
        return target, source

    @staticmethod
    def collate_function(batch, online_t5: bool = False) -> None:
        pass
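
# Illustrative sketch (not executed): constructing the SR dataset the same way the __main__ block
# below does for the T2I dataset. The meta_folder path is a placeholder; single-process defaults
# (num_replicas=1, rank=0) are assumed. Note that split_meta_files may still call
# torch.distributed barriers if the per-rank jsonl splits do not exist yet.
def _example_sr_dataloader(meta_folder: str = '/path/to/sr_jsonl_meta_folder'):
    dataset = SRIterableDataset(
        meta_folder=meta_folder,
        short_prob=0.2,
        buffersize=1000,
        pn='0.06M',
        batch_size=2,
        num_replicas=1,
        rank=0,
        dataloader_workers=1,
    )
    loader = DataLoader(dataset, batch_size=None, num_workers=1)
    # each batch is (hq_images (B, 3, H, W), captions list[str], lq_images (B, 3, H, W))
    hq_images, captions, lq_images = next(iter(loader))
    return hq_images, captions, lq_images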

if __name__ == '__main__':
    # torchrun --nnodes=1 --nproc-per-node=2 --master_addr=$METIS_WORKER_0_HOST --master_port=$METIS_WORKER_0_PORT dataset/dataset_t2i_iterable.py
    tdist.init_process_group(backend='nccl')
    batch_size = 2
    dataloader_workers = 12
    dataset = T2IIterableDataset(
        args=None,
        meta_folder='data/train_splits/xxx_pretrain/jsonl_files_filter_duplicate_captions',
        data_load_reso=None,
        max_caption_len=512,
        short_prob=1.0,
        load_vae_instead_of_image=False,
        buffersize=100000,
        seed=0,
        online_t5=True,
        pn='0.06M',
        batch_size=batch_size,
        num_replicas=8,  # tdist.get_world_size()
        rank=tdist.get_rank(),
        dataloader_workers=dataloader_workers,
    )
    dataloader = DataLoader(dataset, batch_size=None, num_workers=dataloader_workers)
    print(f'len(dataloader): {len(dataloader)}, len(dataset): {len(dataset)}, total_samples: {dataset.total_samples()}')
    t1 = time.time()
    h_div_w2samples = {}
    for ep in range(4):
        dataloader.dataset.set_epoch(ep)
        pbar = tqdm.tqdm(total=len(dataloader))
        for i, data in enumerate(dataloader):
            pbar.update(1)
            t2 = time.time()
            h_div_w = data[0].shape[-2] / data[0].shape[-1]
            h_div_w = f'{h_div_w:.3f}'
            if h_div_w not in h_div_w2samples:
                h_div_w2samples[h_div_w] = 0
            h_div_w2samples[h_div_w] += 1
            if (i + 1) % 100 == 0:
                total_samples = np.sum(list(h_div_w2samples.values()))
                print()
                for h_div_w, num in sorted(h_div_w2samples.items()):
                    print(f'h_div_w: {h_div_w}, samples: {num}, proportion: {num/total_samples*100:.1f}%')
                print()
            t1 = time.time()