| import datetime |
| import os |
| import os.path as osp |
| import random |
| import subprocess |
| from functools import partial |
| from typing import Optional |
| import time |
|
|
| import pytz |
|
|
| from infinity.dataset.dataset_t2i_iterable import T2IIterableDataset,SRIterableDataset |
| import pdb |
|
|
| try: |
| from grp import getgrgid |
| from pwd import getpwuid |
| except: |
| pass |
| import PIL.Image as PImage |
| from PIL import ImageFile |
| import numpy as np |
| from torchvision.transforms import transforms |
| from torchvision.transforms.functional import resize, to_tensor |
| import torch.distributed as tdist |
|
|
| from torchvision.transforms import InterpolationMode |
| bicubic = InterpolationMode.BICUBIC |
| lanczos = InterpolationMode.LANCZOS |
| PImage.MAX_IMAGE_PIXELS = (1024 * 1024 * 1024 // 4 // 3) * 5 |
| ImageFile.LOAD_TRUNCATED_IMAGES = False |
|
|
|
|
| def time_str(fmt='[%m-%d %H:%M:%S]'): |
| return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt) |
|
|
|
|
| def normalize_01_into_pm1(x): |
| return x.add(x).add_(-1) |
|
|
|
|
| def denormalize_pm1_into_01(x): |
| return x.add(1).mul_(0.5) |
|
|
|
|
| def center_crop_arr(pil_image, image_size): |
| """ |
| Center cropping implementation from ADM. |
| https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 |
| """ |
| while min(*pil_image.size) >= 2 * image_size: |
| pil_image = pil_image.resize( |
| tuple(x // 2 for x in pil_image.size), resample=PImage.BOX |
| ) |
| |
| scale = image_size / min(*pil_image.size) |
| pil_image = pil_image.resize( |
| tuple(round(x * scale) for x in pil_image.size), resample=PImage.LANCZOS |
| ) |
| |
| arr = np.array(pil_image) |
| crop_y = (arr.shape[0] - image_size) // 2 |
| crop_x = (arr.shape[1] - image_size) // 2 |
| return PImage.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) |
|
|
|
|
| class RandomResize: |
| def __init__(self, mid_reso, final_reso, interpolation): |
| ub = max(round((mid_reso + (mid_reso-final_reso) / 8) / 4) * 4, mid_reso) |
| self.reso_lb, self.reso_ub = final_reso, ub |
| self.interpolation = interpolation |
| |
| def __call__(self, img): |
| return resize(img, size=random.randint(self.reso_lb, self.reso_ub), interpolation=self.interpolation) |
| |
| def __repr__(self): |
| return f'RandomResize(reso=({self.reso_lb}, {self.reso_ub}), interpolation={self.interpolation})' |
|
|
|
|
| def load_save(reso=512): |
| import os |
| from PIL import Image as PImage |
| from torchvision.transforms import transforms, InterpolationMode |
| aug = transforms.Compose([ |
| transforms.Resize(512, interpolation=InterpolationMode.LANCZOS), |
| transforms.CenterCrop((512, 512)) |
| ]) |
| src_folder = r'C:\Users\16333\Pictures\imgs_to_visual_v2' |
| ls = [os.path.join(src_folder, x) for x in ('1.jpg', '2.jpg', '3.png', '4.png', '5.png')] |
| print(ls) |
| imgs = [] |
| for i, fname in enumerate(ls): |
| assert os.path.exists(fname) |
| with PImage.open(fname) as img: |
| img = img.convert('RGB') |
| img = aug(img) |
| imgs.append(img) |
| dst_d, dst_f = os.path.split(fname) |
| dst = os.path.join(dst_d, f'crop{dst_f.replace(".jpg", ".png")}') |
| img.save(dst) |
| |
| W, H = imgs[0].size |
| WW = W * len(imgs) |
| new_im = PImage.new('RGB', (WW, H)) |
| x_offset = 0 |
| for img in imgs: |
| new_im.paste(img, (x_offset, 0)) |
| x_offset += W |
| dst = os.path.join(src_folder, f'junfeng.png') |
| new_im.save(dst) |
|
|
|
|
| def print_aug(transform, label): |
| print(f'Transform {label} = ') |
| if hasattr(transform, 'transforms'): |
| for t in transform.transforms: |
| print(t) |
| else: |
| print(transform) |
| print('---------------------------\n') |
|
|
|
|
| def build_t2i_dataset( |
| args, |
| data_path: str, |
| data_load_reso: int, |
| max_caption_len: int, |
| short_prob=0.2, |
| load_vae_instead_of_image=False |
| ): |
| if args.use_streaming_dataset: |
| return T2IIterableDataset( |
| data_path, |
| max_caption_len=max_caption_len, |
| short_prob=short_prob, |
| load_vae_instead_of_image=load_vae_instead_of_image, |
| buffersize=args.iterable_data_buffersize, |
| pn=args.pn, |
| online_t5=args.online_t5, |
| batch_size=args.batch_size, |
| num_replicas=tdist.get_world_size(), |
| rank=tdist.get_rank(), |
| dataloader_workers=args.workers, |
| dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus, |
| enable_dynamic_length_prompt=args.enable_dynamic_length_prompt, |
| seed=args.seed if args.seed is not None else int(time.time()), |
| ) |
| else: |
| raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported') |
|
|
| def build_sr_dataset( |
| args, |
| data_path: str, |
| data_load_reso: int, |
| max_caption_len: int, |
| short_prob=0.2, |
| load_vae_instead_of_image=False |
| ): |
| if args.use_streaming_dataset: |
| return SRIterableDataset( |
| data_path, |
| max_caption_len=max_caption_len, |
| short_prob=short_prob, |
| load_vae_instead_of_image=load_vae_instead_of_image, |
| buffersize=args.iterable_data_buffersize, |
| pn=args.pn, |
| online_t5=args.online_t5, |
| batch_size=args.batch_size, |
| num_replicas=tdist.get_world_size(), |
| rank=tdist.get_rank(), |
| dataloader_workers=args.workers, |
| dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus, |
| enable_dynamic_length_prompt=args.enable_dynamic_length_prompt, |
| seed=args.seed if args.seed is not None else int(time.time()), |
| |
| crop_type='center', |
| use_hflip=True, |
| blur_kernel_size=41, |
| kernel_list=['iso','aniso'], |
| kernel_prob=[0.5,0.5], |
| blur_sigma=[0.1,12], |
| downsample_range=[1,12], |
| noise_range=[0,15], |
| jpeg_range=[30,100], |
| raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), |
| ) |
| else: |
| raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported') |
|
|
|
|
| def pil_load(path: str, proposal_size): |
| with open(path, 'rb') as f: |
| img: PImage.Image = PImage.open(f) |
| w: int = img.width |
| h: int = img.height |
| sh: int = min(h, w) |
| if sh > proposal_size: |
| ratio: float = proposal_size / sh |
| w = round(ratio * w) |
| h = round(ratio * h) |
| img.draft('RGB', (w, h)) |
| img = img.convert('RGB') |
| return img |
|
|
|
|
| def rewrite(im: PImage, file: str, info: str): |
| kw = dict(quality=100) |
| if file.lower().endswith('.tif') or file.lower().endswith('.tiff'): |
| kw['compression'] = 'none' |
| elif file.lower().endswith('.webp'): |
| kw['lossless'] = True |
| |
| st = os.stat(file) |
| uname = getpwuid(st.st_uid).pw_name |
| gname = getgrgid(st.st_gid).gr_name |
| mode = oct(st.st_mode)[-3:] |
| |
| local_file = osp.basename(file) |
| im.save(local_file, **kw) |
| print(f'************* <REWRITE: {info}> ************* @ {file}') |
| subprocess.call(f'sudo mv {local_file} {file}; sudo chown {uname}:{gname} {file}; sudo chmod {mode} {file}', shell=True) |
|
|