"""CC12M image-text data preparation.

Streams image/caption pairs from the Hugging Face dataset
https://huggingface.co/datasets/opendiffusionai/cc12m-4mp-realistic and turns
them into batches of normalized image tensors plus byte-level caption token
sequences.

Storage: ~250 GB for images.
License: MIT (dataset); licenses of individual images vary.
"""
import io
import os
import sys
from dataclasses import dataclass

import torch
from PIL import Image
from torchvision import transforms

# Make the directory two levels above this file importable so the project
# package ``arbitor`` resolves when this module is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB  # noqa: E402  (needs the path tweak above)


@dataclass
class CC12MConfig:
    """Options for :class:`CC12MStream`."""

    batch_size: int = 4  # samples per yielded batch; must be > 0
    shuffle_buffer: int = 5000  # size of the streaming shuffle buffer
    split: str = "train"  # dataset split passed to ``load_dataset``
    image_size: int = 224  # images are resized to image_size x image_size
    max_samples: int = 0  # 0 means no cap; otherwise only the first N records are streamed
    seed: int = 42  # seed for the streaming shuffle buffer
    drop_last: bool = True  # drop the trailing incomplete batch (historical behavior)


class CC12MStream:
    """Iterate CC12M as batches of ``(images, caption_tokens)`` tensors.

    Captions are tokenized at the byte level: ``[BOS, *utf8_bytes, EOS]``,
    using the BOS/EOS/PAD ids from ``SPECIAL_VOCAB``.
    """

    # Hard-coded sample count reported by num_samples() when no cap is set.
    # NOTE(review): this is CC12M's nominal size; the filtered "4mp-realistic"
    # subset may be smaller -- TODO confirm.
    _NOMINAL_SIZE = 12_000_000

    def __init__(self, cfg: CC12MConfig):
        """Store the config, look up special token ids and build the image transform.

        Raises:
            ValueError: if ``cfg.batch_size`` is not a positive integer.
        """
        if cfg.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {cfg.batch_size}")
        self.cfg = cfg
        self._ds = None  # streaming dataset, opened lazily by _lazy_init()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._image_tok = SPECIAL_VOCAB['IMAGE']
        # Resize to a square (aspect ratio is not preserved), convert to a
        # float tensor, then normalize with the ImageNet channel mean/std.
        self._transform = transforms.Compose([
            transforms.Resize((cfg.image_size, cfg.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _lazy_init(self):
        """Open the shuffled streaming dataset on first use.

        The ``datasets`` import is deferred so importing this module stays cheap.
        """
        if self._ds is not None:
            return
        from datasets import load_dataset

        ds = load_dataset(
            "opendiffusionai/cc12m-4mp-realistic",
            split=self.cfg.split,
            streaming=True,
        )
        if self.cfg.max_samples:
            # Cap before shuffling: the cap selects the first N records of the
            # stream, and the shuffle buffer then reorders only those.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=self.cfg.seed)

    def _decode_image(self, img_bytes):
        """Decode and transform one image; return ``None`` if it cannot be used.

        Args:
            img_bytes: encoded image bytes, or an already decoded
                ``PIL.Image.Image``.

        Returns:
            The transformed image tensor, or ``None`` on any decode or transform
            failure.
        """
        try:
            if isinstance(img_bytes, Image.Image):
                pil = img_bytes.convert("RGB")
            else:
                # Context manager closes the underlying image handle promptly.
                with Image.open(io.BytesIO(img_bytes)) as im:
                    pil = im.convert("RGB")
            return self._transform(pil)
        except Exception:
            # Deliberate best-effort: corrupt, truncated or unsupported images
            # are skipped by the caller instead of aborting the whole stream.
            return None

    def _encode_caption(self, caption):
        """Return ``[BOS, *caption.encode('utf-8'), EOS]`` as a 1-D long tensor.

        A missing or ``None`` caption is treated as the empty string.
        """
        raw = (caption or "").encode("utf-8")
        return torch.tensor([self._bos, *raw, self._eos], dtype=torch.long)

    def _collate(self, batch):
        """Stack images and right-pad caption tensors with PAD to a common length."""
        imgs = torch.stack([img for img, _ in batch])
        texts = torch.nn.utils.rnn.pad_sequence(
            [text for _, text in batch],
            batch_first=True,
            padding_value=self._pad,
        )
        return imgs, texts

    def batches(self):
        """Yield ``(images, texts)`` batches from the shuffled stream.

        ``images`` stacks the transformed RGB tensors, giving shape
        ``(B, 3, image_size, image_size)``. ``texts`` is a long tensor of shape
        ``(B, L)``: each row is ``[BOS, caption utf-8 bytes..., EOS]``,
        right-padded with PAD to the longest caption in the batch.

        Samples whose image cannot be decoded are skipped. A trailing batch
        smaller than ``cfg.batch_size`` is yielded only when ``cfg.drop_last``
        is False.
        """
        self._lazy_init()
        size = self.cfg.batch_size
        buf = []
        for example in self._ds:
            image = example["image"]
            # The image column may arrive as an undecoded {"bytes": ...} dict,
            # raw bytes, or a decoded PIL image; _decode_image handles the latter two.
            payload = image["bytes"] if isinstance(image, dict) else image
            img_tensor = self._decode_image(payload)
            if img_tensor is None:
                continue
            buf.append((img_tensor, self._encode_caption(example.get("text"))))
            if len(buf) >= size:
                yield self._collate(buf)
                buf = []
        if buf and not self.cfg.drop_last:
            yield self._collate(buf)

    def num_samples(self) -> int:
        """Return a nominal sample count for progress and epoch bookkeeping.

        Returns ``cfg.max_samples`` when a cap is set, otherwise the hard-coded
        nominal size. Undecodable images are skipped, so the true number of
        yielded samples can be lower.
        """
        return self.cfg.max_samples or self._NOMINAL_SIZE