| import argparse |
| import multiprocessing |
| import os |
| import shutil |
| from functools import partial |
| from io import BytesIO |
| from multiprocessing import Process, Queue |
| from os.path import exists, join |
|
|
| import lmdb |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision.datasets import LSUNClass |
| from torchvision.transforms import functional as trans_fn |
| from tqdm import tqdm |
|
|
|
|
def resize_and_convert(img, size, resample, quality=100):
    """Resize *img* to ``size``, center-crop it square, and return the
    result encoded as WEBP bytes.

    Args:
        img: a PIL image.
        size: target edge length for both the resize and the crop.
        resample: PIL resampling filter (e.g. ``Image.LANCZOS``).
        quality: WEBP encoder quality, 0-100 (default 100).

    Returns:
        The encoded image as a ``bytes`` object.
    """
    cropped = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)
    out = BytesIO()
    cropped.save(out, format="webp", quality=quality)
    return out.getvalue()
|
|
|
|
def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode *img* once per entry in ``sizes``.

    Returns a list of WEBP-encoded byte strings, one per size, in the
    same order as ``sizes``.
    """
    return [
        resize_and_convert(img, size, resample, quality) for size in sizes
    ]
|
|
|
|
def resize_worker(idx, img, sizes, resample):
    """Convert *img* to RGB, encode it at every size in ``sizes``, and
    return ``(idx, encoded_list)`` so callers can restore ordering."""
    rgb = img.convert("RGB")
    return idx, resize_multiple(rgb, sizes=sizes, resample=resample)
|
|
|
|
class ConvertDataset(Dataset):
    """Wrap an indexable ``(image, label)`` dataset and yield each image
    re-encoded as 256px center-cropped WEBP bytes (quality 90).

    Labels are discarded; only the encoded image bytes are returned.
    """

    def __init__(self, data) -> None:
        # `data` must support len() and indexing, returning
        # (PIL image, label) pairs — e.g. torchvision's LSUNClass.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, _ = self.data[index]
        # renamed from `bytes`, which shadowed the builtin type
        img_bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
        return img_bytes
|
|
|
|
if __name__ == "__main__":
    # Convert LSUN's original lmdb into our own lmdb layout, which is
    # somehow more performant. Keys are "<size>-<7-digit index>" and a
    # final "length" key records the total image count.
    size = 256  # single named constant instead of the literal 256 in the key
    src_path = 'datasets/horse_train_lmdb'
    out_path = 'datasets/horse256.lmdb'

    dataset = LSUNClass(root=os.path.expanduser(src_path))
    dataset = ConvertDataset(dataset)
    # identity collate: each batch stays a plain list of webp byte strings
    loader = DataLoader(dataset,
                        batch_size=50,
                        num_workers=16,
                        collate_fn=lambda x: x)

    # start from a clean output directory
    target = os.path.expanduser(out_path)
    if os.path.exists(target):
        shutil.rmtree(target)

    # map_size = 1 TiB is an upper bound; lmdb allocates lazily
    with lmdb.open(target, map_size=1024**4, readahead=False) as env:
        with tqdm(total=len(dataset)) as progress:
            i = 0
            for batch in loader:
                # one write transaction per batch keeps commits cheap
                with env.begin(write=True) as txn:
                    for img in batch:
                        key = f"{size}-{str(i).zfill(7)}".encode("utf-8")
                        txn.put(key, img)
                        i += 1
                        progress.update()

        # store the total count so readers know the dataset length
        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
|