| from functools import partial |
| from io import BytesIO |
| import pathlib |
| from typing import Any |
|
|
| import datasets |
| from loguru import logger |
| from PIL import Image |
| import requests |
| from tqdm.auto import tqdm |
|
|
| from src import config |
|
|
|
|
def _save_resized_image(
    example: dict[str, Any],
    size: tuple[int, int],
    path: pathlib.Path,
    timeout: float = 30.0,
) -> None:
    """Download the image at ``example["url"]``, resize it and save it in ``path``.

    The target file name is the text after the last ``/`` of the URL. If a
    file with that name already exists in ``path`` nothing is downloaded.

    A failure to download, decode or write one image is logged and the
    function returns normally, so one bad URL does not abort the caller.

    Args:
        example: Mapping that holds the image URL under the key ``"url"``.
        size: Target ``(width, height)`` passed to ``Image.resize``.
        path: Directory the resized image is written to.
        timeout: Seconds to wait for the HTTP request before giving up.
    """
    image_url = example["url"]
    image_path = path / image_url.rsplit("/", 1)[-1]
    if image_path.exists():
        return

    try:
        # Without a timeout a single stalled server blocks the worker forever.
        response = requests.get(image_url, timeout=timeout)
        # Reject 4xx/5xx replies instead of trying to decode an error page.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as image:
            image_resized = image.resize(size)
    except (requests.RequestException, OSError) as error:
        logger.error(f"Failed to fetch image {image_url}: {error}")
        return

    try:
        image_resized.save(image_path)
    except (OSError, ValueError) as error:
        # Remove a partially written file; otherwise the exists() check above
        # would skip this URL on every later run.
        image_path.unlink(missing_ok=True)
        logger.error(f"Failed to save image {image_path}: {error}")
|
|
|
|
def _get_images(
    dataset: datasets.Dataset,
    path: pathlib.Path,
    *,
    size: tuple[int, int] = (256, 256),
    num_proc: int = 128,
) -> None:
    """Download and resize the image of every example in ``dataset``.

    Applies ``_save_resized_image`` to the dataset through ``dataset.map``.
    The mapped dataset returned by ``map`` is discarded; the call is made
    only for the files written into ``path``.

    Args:
        dataset: Dataset whose examples are passed to ``_save_resized_image``.
        path: Directory the resized images are written to.
        size: Target ``(width, height)`` for every image.
        num_proc: Number of worker processes handed to ``dataset.map``.
    """
    save_resized_image = partial(_save_resized_image, path=path, size=size)
    dataset.map(save_resized_image, num_proc=num_proc)
|
|
|
|
def _check_corrupt_images(image_file: pathlib.Path) -> None:
    """Log an error if ``image_file`` cannot be opened and verified by PIL.

    Args:
        image_file: Path of the file to check.
    """
    try:
        with Image.open(image_file) as img:
            img.verify()
    except (IOError, SyntaxError) as e:
        # Include the reason so the log shows why the file was rejected.
        logger.error(f"Corrupt image: {image_file} ({e})")
|
|
|
|
if __name__ == "__main__":
    # Resolve the dataset name from the trainer configuration and load its
    # training split.
    trainer_config = config.TrainerConfig()
    dataset_name = trainer_config._data_config.dataset
    train_split = datasets.load_dataset(dataset_name, split="train")

    # Make sure the download directory exists, then fetch and resize images.
    download_dir = config.IMAGE_DOWNLOAD_PATH
    download_dir.mkdir(parents=True, exist_ok=True)
    _get_images(train_split, download_dir)

    # Check every entry in the download directory and log the ones PIL rejects.
    for image_file in tqdm(download_dir.iterdir()):
        _check_corrupt_images(image_file)
|
|