|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tool for creating ZIP/PNG based datasets."""
|
|
|
| from collections.abc import Iterator
|
| from dataclasses import dataclass
|
| import functools
|
| import io
|
| import json
|
| import os
|
| import re
|
| import zipfile
|
| from pathlib import Path
|
| from typing import Callable, Optional, Tuple, Union
|
| import click
|
| import numpy as np
|
| import PIL.Image
|
| import torch
|
| from tqdm import tqdm
|
|
|
| from encoders import StabilityVAEEncoder
|
| from utils import load_encoders
|
| from torchvision.transforms import Normalize
|
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
# Per-channel RGB mean/std used by CLIP's image preprocessing (inputs in [0, 1]).
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
|
|
def preprocess_raw_image(x, enc_type):
    """Normalize a raw image batch in [0, 255] for the given pretrained encoder.

    Args:
        x: image tensor with pixel values in [0, 255]; assumed NCHW with
           square spatial dims (resolution read from the last axis) — TODO confirm.
        enc_type: encoder identifier; substring matching selects the recipe.

    Returns:
        The normalized (and, for clip/dinov2/jepa, resized) tensor. An
        unrecognized enc_type is returned unchanged, still in [0, 255].
    """
    resolution = x.shape[-1]
    if 'clip' in enc_type:
        x = x / 255.
        # CLIP backbones expect 224px inputs (scaled by resolution // 256).
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    else:
        # All remaining encoders share ImageNet normalization; only dinov2 and
        # jepa additionally resize. Pairs are listed in the original branch
        # order so substring-match precedence is preserved exactly.
        for key, needs_resize in (('mocov3', False), ('mae', False),
                                  ('dinov2', True), ('dinov1', False), ('jepa', True)):
            if key in enc_type:
                x = x / 255.
                x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
                if needs_resize:
                    x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
                break
    return x
|
|
|
|
|
|
|
|
|
@dataclass
class ImageEntry:
    """A single dataset image plus its (optional) integer class label."""
    # Decoded RGB pixels; produced via PIL .convert('RGB'), so shape (H, W, 3), dtype uint8.
    img: np.ndarray
    # Class index, or None when the dataset carries no labels.
    label: Optional[int]
|
|
|
|
|
|
|
|
|
|
|
def parse_tuple(s: str) -> Tuple[int, int]:
    """Parse a resolution string like '256x256' or '256,256' into two ints.

    Raises click.ClickException when the string does not match.
    """
    match = re.match(r'^(\d+)[x,](\d+)$', s)
    if match is None:
        raise click.ClickException(f'cannot parse tuple {s}')
    return int(match.group(1)), int(match.group(2))
|
|
|
|
|
|
|
def maybe_min(a: int, b: Optional[int]) -> int:
    """Return min(a, b), treating b=None as 'no limit'."""
    return a if b is None else min(a, b)
|
|
|
|
|
|
|
def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last '.' (the whole name when there is no dot)."""
    return str(name).rpartition('.')[-1]
|
|
|
|
|
|
|
def is_image_ext(fname: Union[str, Path]) -> bool:
    """True when fname has an extension PIL recognises (PIL.Image.init() must have run)."""
    return f'.{file_ext(fname).lower()}' in PIL.Image.EXTENSION
|
|
|
|
|
|
|
def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Open a directory tree of images as a (count, lazy iterator) pair.

    Labels come from a top-level 'dataset.json' if present; otherwise, when
    the images span more than one top-level subdirectory, each subdirectory
    name is assigned a class index.
    """
    # Collect every image file under source_dir; sort for a deterministic order.
    all_fnames = []
    for root, _dirs, files in os.walk(source_dir):
        all_fnames.extend(os.path.join(root, f) for f in files)
    input_images = sorted(f for f in all_fnames if is_image_ext(f))

    # Archive-relative names always use forward slashes (zip/json convention).
    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Explicit labels from dataset.json, if present.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No explicit labels: derive them from top-level directory names, but only
    # when there is more than one directory (a single class is meaningless).
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        # Slice up-front: the old post-yield break emitted one image even
        # when max_images was 0.
        for fname in input_images[:max_idx]:
            img = np.array(PIL.Image.open(fname).convert('RGB'))
            yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
    return max_idx, iterate_images()
|
|
|
|
|
|
|
def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Open a ZIP archive of images as a (count, lazy iterator) pair.

    Labels are read from a 'dataset.json' member if the archive has one.
    """
    # Scan the archive once up-front for the member list and labels.
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        # Reopen lazily so the returned iterator owns its own handle.
        with zipfile.ZipFile(source, mode='r') as z:
            # Slice up-front: the old post-yield break emitted one image even
            # when max_images was 0.
            for fname in input_images[:max_idx]:
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file).convert('RGB'))
                    yield ImageEntry(img=img, label=labels.get(fname))
    return max_idx, iterate_images()
|
|
|
|
|
|
|
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build an image transform mapping an HWC uint8 array to an HWC uint8 array.

    Args:
        transform: None (plain resize), 'center-crop', 'center-crop-wide', or
            'center-crop-dhariwal'.
        output_width, output_height: target size; required for all named
            transforms, optional for the plain resize.

    Returns:
        A callable; it may return None to signal that an image is rejected
        (only 'center-crop-wide' does this).

    Raises:
        click.ClickException: when a required resolution is missing or invalid.
    """
    def scale(width, height, img):
        # Plain resize; a None width/height keeps that dimension unchanged.
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img, 'RGB')
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        # Crop the largest centered square, then resize to the target size.
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Wide crop: scale to full width, letterbox vertically on a square canvas.
        # Returns None (image rejected) when the source is too small.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        # NOTE(review): square canvas assumes width >= height — confirm for tall targets.
        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    def center_crop_imagenet(image_size: int, arr: np.ndarray):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        pil_image = PIL.Image.fromarray(arr)
        # Repeated 2x BOX downsampling until within 2x of target reduces aliasing.
        while min(*pil_image.size) >= 2 * image_size:
            new_size = tuple(x // 2 for x in pil_image.size)
            assert len(new_size) == 2
            pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)

        scale = image_size / min(*pil_image.size)
        new_size = tuple(round(x * scale) for x in pil_image.size)
        assert len(new_size) == 2
        pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            # Fixed: message was missing the space before 'transform'.
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    if transform == 'center-crop-dhariwal':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        if output_width != output_height:
            raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_imagenet, output_width)
    assert False, 'unknown transform'
|
|
|
|
|
|
|
def open_dataset(source, *, max_images: Optional[int]):
    """Open `source` (a directory or a .zip archive) as a (count, iterator) pair."""
    if os.path.isdir(source):
        return open_image_folder(source, max_images=max_images)
    if not os.path.isfile(source):
        raise click.ClickException(f'Missing input file or directory: {source}')
    if file_ext(source) != 'zip':
        raise click.ClickException(f'Only zip archives are supported: {source}')
    return open_image_zip(source, max_images=max_images)
|
|
|
|
|
|
|
def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    """Prepare `dest` for writing and return (root_dir, write_fn, close_fn).

    A '.zip' destination writes members into an uncompressed archive (root_dir
    is ''); anything else is treated as an output directory, which must be
    empty if it already exists.
    """
    # Extension check inlined: text after the last '.'.
    if str(dest).split('.')[-1] == 'zip':
        parent = os.path.dirname(dest)
        if parent != '':
            os.makedirs(parent, exist_ok=True)
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close

    # Directory output: refuse to clobber a non-empty folder.
    if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
        raise click.ClickException('--dest folder must be empty')
    os.makedirs(dest, exist_ok=True)

    def folder_write_bytes(fname: str, data: Union[bytes, str]):
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        payload = data.encode('utf8') if isinstance(data, str) else data
        with open(fname, 'wb') as fout:
            fout.write(payload)
    return dest, folder_write_bytes, lambda: None
|
|
|
|
|
|
|
@click.group()
def cmdline():
    '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
    # Torchrun-style launchers export WORLD_SIZE > 1; this tool is single-process only.
    world_size = os.environ.get('WORLD_SIZE', '1')
    if world_size != '1':
        raise click.ClickException('Distributed execution is not supported.')
|
|
|
|
|
|
|
|
|
|
|
|
|
@cmdline.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--enc-type', help='Feature encoder type', metavar='PATH', type=str, default='dinov2-vit-b')
@click.option('--resolution', help='Input image resolution', metavar='INT', type=int, default=256)
def encode(
    source: str,
    dest: str,
    max_images: Optional[int],
    enc_type,
    resolution
):
    """Encode pixel data to VAE latents."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load_encoders returns parallel lists; this tool uses a single encoder.
    encoder, encoder_type, architectures = load_encoders(enc_type, device, resolution)
    encoder, encoder_type, architectures = encoder[0], encoder_type[0], architectures[0]
    print("Encoder is over!!!")

    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("Data is over!!!")
    labels = []

    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        with torch.no_grad():
            # HWC uint8 -> NCHW on the selected device. Previously hard-coded
            # to 'cuda', which crashed on CPU-only machines despite the device
            # selection above.
            img_tensor = torch.tensor(image.img).to(device).permute(2, 0, 1).unsqueeze(0)
            raw_image_ = preprocess_raw_image(img_tensor, encoder_type)
            z = encoder.forward_features(raw_image_)
            if 'dinov2' in encoder_type:
                z = z['x_norm_patchtokens']
            # NOTE: two unused accumulator lists were removed here; they pinned
            # every feature tensor in (GPU) memory for the whole run.
            z = z.detach().cpu().numpy()

        # Shard outputs into subfolders of up to 1000 entries each.
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img-feature-{idx_str}.npy'

        f = io.BytesIO()
        np.save(f, z)
        save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    # Emit labels only when every image had one; otherwise mark the dataset unlabeled.
    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()
|
|
|
if __name__ == "__main__":
    # Entry point: dispatch to the click command group defined above.
    cmdline()
|
|
|
|
|
|
|
|
|