File size: 2,763 Bytes
"""CC12M image-text data preparation.
Dataset: https://huggingface.co/datasets/opendiffusionai/cc12m-4mp-realistic
Storage: ~250 GB for images
License: MIT (dataset), images vary
"""
import torch, io
from dataclasses import dataclass
from PIL import Image
from torchvision import transforms
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB
@dataclass
class CC12MConfig:
    """Settings consumed by ``CC12MStream``.

    Attributes:
        batch_size: Number of (image, caption) pairs per yielded batch.
        shuffle_buffer: Buffer size handed to the streaming shuffle.
        split: Dataset split name passed to ``load_dataset``.
        image_size: Side length, in pixels, of the square resize target.
        max_samples: When non-zero, only this many examples are taken from
            the stream; ``0`` leaves the stream uncapped.
    """

    # NOTE: field order is part of the positional constructor signature.
    batch_size: int = 4
    shuffle_buffer: int = 5000
    split: str = "train"
    image_size: int = 224
    max_samples: int = 0
class CC12MStream:
    """Stream padded (image, caption) batches from the CC12M dataset.

    Images are decoded from raw bytes, resized to a square of
    ``cfg.image_size`` pixels, converted to tensors and normalized.
    Captions are tokenized at byte level: BOS, the UTF-8 bytes of the
    caption, then EOS. The dataset itself is opened lazily on first use.
    """

    # Nominal size of the full corpus; reported when no sample cap is set.
    _NOMINAL_SAMPLES = 12_000_000

    def __init__(self, cfg: "CC12MConfig"):
        """Store *cfg*, look up special token ids and build the image transform."""
        self.cfg = cfg
        self._ds = None  # populated by _lazy_init() on first use
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._image_tok = SPECIAL_VOCAB['IMAGE']
        self._transform = transforms.Compose([
            transforms.Resize((cfg.image_size, cfg.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _lazy_init(self) -> None:
        """Open the streaming dataset once; later calls are no-ops."""
        if self._ds is not None:
            return
        # Imported here so that constructing the stream does not require
        # the `datasets` package until data is actually requested.
        from datasets import load_dataset
        ds = load_dataset("opendiffusionai/cc12m-4mp-realistic",
                          split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            # The cap is applied before shuffling, so shuffling only
            # reorders the first `max_samples` examples of the stream.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _decode_image(self, img_bytes) -> "torch.Tensor | None":
        """Decode raw image bytes into a normalized tensor.

        Returns:
            The transformed image tensor, or ``None`` when the payload
            cannot be decoded or transformed.
        """
        try:
            # The context manager closes the source image handle; convert()
            # returns an independent RGB image that outlives it.
            with Image.open(io.BytesIO(img_bytes)) as source:
                rgb = source.convert("RGB")
            return self._transform(rgb)
        except Exception:
            # Deliberate best-effort: a single bad image must not stop the
            # stream, so the caller skips anything that fails here.
            return None

    def _encode_caption(self, caption) -> "torch.Tensor":
        """Return BOS + UTF-8 bytes of *caption* + EOS as a 1-D long tensor.

        A ``None`` caption (key present but null) is treated as empty.
        """
        raw = (caption or "").encode("utf-8")
        return torch.tensor([self._bos, *raw, self._eos], dtype=torch.long)

    def _collate(self, batch):
        """Stack images and right-pad captions with PAD to a common length."""
        imgs = torch.stack([img for img, _ in batch])
        max_len = max(txt.numel() for _, txt in batch)
        texts = torch.stack([
            torch.cat([txt, txt.new_full((max_len - txt.numel(),), self._pad)])
            for _, txt in batch
        ])
        return imgs, texts

    def batches(self):
        """Yield ``(imgs, texts)`` batches of exactly ``cfg.batch_size`` examples.

        Yields:
            ``imgs``: float tensor of shape
            ``(batch_size, 3, image_size, image_size)``.
            ``texts``: long tensor of shape ``(batch_size, max_len)``, where
            ``max_len`` is the longest token sequence in the batch and
            shorter sequences are right-padded with the PAD id.

        Raises:
            ValueError: If ``cfg.batch_size`` is smaller than 1.

        Examples whose image fails to decode are skipped. Trailing examples
        that do not fill a complete batch are not yielded.
        """
        batch_size = self.cfg.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self._lazy_init()
        pending = []
        for example in self._ds:
            img_tensor = self._decode_image(example["image"]["bytes"])
            if img_tensor is None:
                continue
            pending.append((img_tensor, self._encode_caption(example.get("text"))))
            if len(pending) == batch_size:
                yield self._collate(pending)
                pending = []

    def num_samples(self) -> int:
        """Return an upper bound on the number of examples in the stream.

        This is the nominal corpus size, reduced to ``cfg.max_samples`` when
        a positive cap is configured. The number actually yielded can be
        lower because undecodable images and the trailing partial batch are
        dropped.
        """
        limit = self.cfg.max_samples
        if limit > 0:
            return min(limit, self._NOMINAL_SAMPLES)
        return self._NOMINAL_SAMPLES
|