"""CC12M image-text data preparation.

Dataset: https://huggingface.co/datasets/opendiffusionai/cc12m-4mp-realistic
Storage: ~250 GB for images
License: MIT (dataset), images vary
"""
import io
import os
import sys
from dataclasses import dataclass

import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB
|
|
|
|
@dataclass
class CC12MConfig:
    """Settings consumed by ``CC12MStream`` when streaming CC12M batches."""

    # Number of (image, caption) pairs stacked into each yielded batch.
    batch_size: int = 4
    # Buffer size handed to the streaming dataset's shuffle().
    shuffle_buffer: int = 5000
    # Dataset split passed to load_dataset().
    split: str = "train"
    # Images are resized to (image_size, image_size) before tensor conversion.
    image_size: int = 224
    # When non-zero, only the first max_samples examples of the stream are
    # taken; 0 leaves the stream uncapped.
    max_samples: int = 0
|
|
|
|
class CC12MStream:
    """Stream batches of (image, caption-token) tensors from CC12M.

    The Hugging Face dataset is opened lazily, in streaming mode, the first
    time :meth:`batches` is iterated. Captions are tokenized as raw UTF-8
    bytes wrapped in the BOS/EOS ids from ``SPECIAL_VOCAB`` and right-padded
    with the PAD id so that each batch is rectangular.
    """

    # Hard-coded size reported by num_samples() when no cap is configured.
    _NOMINAL_SIZE = 12_000_000

    def __init__(self, cfg: "CC12MConfig"):
        """Store *cfg* and build the image preprocessing pipeline.

        The dataset itself is not opened here; see ``_lazy_init``.
        """
        self.cfg = cfg
        self._ds = None  # populated on first use by _lazy_init()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._image_tok = SPECIAL_VOCAB['IMAGE']
        # Resize to a square, convert to a float tensor, then apply fixed
        # per-channel mean/std normalization.
        self._transform = transforms.Compose([
            transforms.Resize((cfg.image_size, cfg.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _lazy_init(self):
        """Open the streaming dataset once; later calls are no-ops."""
        if self._ds is not None:
            return
        # Imported here so that `datasets` is only needed once streaming starts.
        from datasets import load_dataset

        ds = load_dataset(
            "opendiffusionai/cc12m-4mp-realistic",
            split=self.cfg.split,
            streaming=True,
        )
        # The cap is applied before shuffling, so a capped stream is a
        # shuffled view of the first max_samples examples.
        if self.cfg.max_samples:
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _decode_image(self, img_bytes):
        """Decode encoded image bytes into a transformed RGB tensor.

        Returns ``None`` instead of raising when the payload cannot be
        opened, converted or transformed, so the caller can skip the example.
        """
        try:
            pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            return self._transform(pil)
        except Exception:
            # Deliberate best-effort: one undecodable image must not stop
            # the stream.
            return None

    def batches(self):
        """Yield ``(images, texts)`` batches until the stream is exhausted.

        ``images`` stacks ``cfg.batch_size`` transformed image tensors.
        ``texts`` is a ``torch.long`` tensor with one row per caption:
        BOS, the caption's UTF-8 bytes, EOS, then PAD up to the longest
        caption in the batch.

        Examples whose image fails to decode are skipped. A trailing group
        of fewer than ``cfg.batch_size`` examples is not yielded.

        Raises:
            ValueError: if ``cfg.batch_size`` is smaller than 1.
        """
        batch_size = self.cfg.batch_size
        if batch_size < 1:
            # Fail before opening the dataset rather than inside torch.stack.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self._lazy_init()
        pending = []
        for example in self._ds:
            # NOTE(review): expects the undecoded {"bytes": ...} image layout;
            # confirm against the dataset schema.
            image = self._decode_image(example["image"]["bytes"])
            if image is None:
                continue

            # `or ""` also covers a "text" key that is present but None.
            caption = example.get("text") or ""
            tokens = torch.tensor(
                [self._bos, *caption.encode("utf-8"), self._eos],
                dtype=torch.long,
            )
            pending.append((image, tokens))
            if len(pending) < batch_size:
                continue

            images = [pair[0] for pair in pending]
            captions = [pair[1] for pair in pending]
            pending = []
            yield (
                torch.stack(images),
                pad_sequence(captions, batch_first=True, padding_value=self._pad),
            )

    def num_samples(self) -> int:
        """Return the number of examples the stream is expected to provide.

        This is the hard-coded nominal size, reduced to ``cfg.max_samples``
        when that cap is positive, because ``_lazy_init`` truncates the
        stream to the cap.
        """
        cap = self.cfg.max_samples
        if cap and cap > 0:
            return min(cap, self._NOMINAL_SIZE)
        return self._NOMINAL_SIZE
|
|