| import argparse |
| import multiprocessing |
| import os |
| import shutil |
| from functools import partial |
| from io import BytesIO |
| from multiprocessing import Process, Queue |
| from os.path import exists, join |
| from pathlib import Path |
|
|
| import lmdb |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision.datasets import LSUNClass |
| from torchvision.transforms import functional as trans_fn |
| from tqdm import tqdm |
|
|
|
|
def resize_and_convert(img, size, resample, quality=100):
    """Encode a PIL image as WEBP bytes, optionally resizing first.

    If *size* is not None, the image's shorter side is resized to *size*
    (using *resample*) and then center-cropped to a square. If *size* is
    None the image is encoded at its original resolution.

    Returns the WEBP-compressed image as a ``bytes`` object.
    """
    if size is not None:
        img = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)

    encoded = BytesIO()
    img.save(encoded, format="webp", quality=quality)
    return encoded.getvalue()
|
|
|
|
def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode *img* once per entry in *sizes*.

    Returns a list of WEBP byte strings, one per target size, in the
    same order as *sizes*.
    """
    return [resize_and_convert(img, s, resample, quality) for s in sizes]
|
|
|
|
def resize_worker(idx, img, sizes, resample):
    """Worker entry point: convert *img* to RGB and encode it at every
    size in *sizes*.

    Returns ``(idx, encoded_list)`` so results produced out of order can
    be matched back to their source index.
    """
    rgb = img.convert("RGB")
    encoded = resize_multiple(rgb, sizes=sizes, resample=resample)
    return idx, encoded
|
|
|
|
class ConvertDataset(Dataset):
    """Wrap an image dataset, re-encoding each item as WEBP bytes.

    Each ``__getitem__`` call pulls a PIL image from the underlying
    dataset and returns it WEBP-compressed (resized/cropped to a square
    of ``size`` when ``size`` is not None).
    """

    def __init__(self, data, size) -> None:
        self.data = data  # underlying indexable dataset yielding PIL images
        self.size = size  # target square size, or None to keep original resolution

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img = self.data[index]
        # Renamed from `bytes`, which shadowed the builtin type.
        webp_bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100)
        return webp_bytes
|
|
|
|
class ImageFolder(Dataset):
    """Dataset of PIL images loaded lazily from ``folder/*.{ext}``.

    Paths are sorted so iteration order is deterministic across runs.
    """

    def __init__(self, folder, ext='jpg'):
        super().__init__()
        # Path.glob already yields Path objects; no need to wrap it in a
        # list comprehension or to f-string an already-usable folder arg.
        self.paths = sorted(Path(folder).glob(f'*.{ext}'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # The original's single-argument os.path.join() was a no-op;
        # Image.open accepts a Path directly.
        # NOTE(review): PIL opens lazily, so the file handle stays open
        # until the pixel data is actually read downstream.
        return Image.open(self.paths[index])
|
|
|
|
| if __name__ == "__main__": |
| from tqdm import tqdm |
|
|
| out_path = 'datasets/celeba.lmdb' |
| in_path = 'datasets/celeba' |
| ext = 'jpg' |
| size = None |
|
|
| dataset = ImageFolder(in_path, ext) |
| print('len:', len(dataset)) |
| dataset = ConvertDataset(dataset, size) |
| loader = DataLoader(dataset, |
| batch_size=50, |
| num_workers=12, |
| collate_fn=lambda x: x, |
| shuffle=False) |
|
|
| target = os.path.expanduser(out_path) |
| if os.path.exists(target): |
| shutil.rmtree(target) |
|
|
| with lmdb.open(target, map_size=1024**4, readahead=False) as env: |
| with tqdm(total=len(dataset)) as progress: |
| i = 0 |
| for batch in loader: |
| with env.begin(write=True) as txn: |
| for img in batch: |
| key = f"{size}-{str(i).zfill(7)}".encode("utf-8") |
| |
| txn.put(key, img) |
| i += 1 |
| progress.update() |
| |
| |
| |
| |
|
|
| with env.begin(write=True) as txn: |
| txn.put("length".encode("utf-8"), str(i).encode("utf-8")) |
|
|