File size: 2,763 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""CC12M image-text data preparation.

Dataset: https://huggingface.co/datasets/opendiffusionai/cc12m-4mp-realistic
Storage: ~250 GB for images
License: MIT (dataset); licenses of individual images vary
"""
import torch, io
from dataclasses import dataclass
from PIL import Image
from torchvision import transforms
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB


@dataclass
class CC12MConfig:
    """Settings for :class:`CC12MStream`.

    Attributes:
        batch_size: Samples per yielded batch; must be >= 1.
        shuffle_buffer: Buffer size passed to the streaming dataset's shuffle.
        split: Dataset split to stream.
        image_size: Side length (pixels) images are resized to; output is square.
            Must be >= 1.
        max_samples: Cap on the number of streamed examples; 0 means no cap.
            Must be >= 0.

    Raises:
        ValueError: if any of the constraints above is violated.
    """
    batch_size: int = 4
    shuffle_buffer: int = 5000
    split: str = "train"
    image_size: int = 224
    max_samples: int = 0

    def __post_init__(self):
        # Fail fast at construction time instead of deep inside the data
        # pipeline (e.g. torch.stack on an empty batch when batch_size == 0).
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.image_size < 1:
            raise ValueError(f"image_size must be >= 1, got {self.image_size}")
        if self.max_samples < 0:
            raise ValueError(
                f"max_samples must be >= 0 (0 = no limit), got {self.max_samples}"
            )


class CC12MStream:
    """Stream ``(images, tokens)`` batches from the CC12M-4MP-realistic dataset.

    Images are decoded with PIL, resized to a ``cfg.image_size`` square,
    converted to tensors and normalized with ImageNet mean/std. Captions are
    turned into the raw UTF-8 byte values wrapped in the BOS/EOS special
    tokens, then right-padded with PAD to the longest caption in the batch.
    """

    # Nominal CC12M size. NOTE(review): the filtered "4mp-realistic" subset
    # may hold fewer samples — confirm against the dataset card.
    _NOMINAL_SIZE = 12_000_000

    def __init__(self, cfg: "CC12MConfig"):
        self.cfg = cfg
        self._ds = None  # opened on first use by _lazy_init()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._image_tok = SPECIAL_VOCAB['IMAGE']
        self._transform = transforms.Compose([
            transforms.Resize((cfg.image_size, cfg.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _lazy_init(self):
        """Open the streaming dataset on first call; later calls are no-ops."""
        if self._ds is not None:
            return
        # Imported here so constructing the stream does not import ``datasets``.
        from datasets import load_dataset
        ds = load_dataset("opendiffusionai/cc12m-4mp-realistic",
                          split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            # Truncate before shuffling: the shuffle buffer then permutes a
            # fixed prefix of the stream rather than sampling the whole set.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _decode_image(self, image):
        """Decode one image into a normalized ``(3, S, S)`` float tensor.

        Accepts raw encoded bytes, a ``{"bytes": ...}`` dict (HF ``Image``
        feature with decoding disabled) or an already-decoded
        ``PIL.Image.Image`` (HF default when decoding is enabled).

        Returns:
            The transformed tensor, or ``None`` if the image cannot be decoded.
        """
        try:
            if isinstance(image, dict):
                image = image.get("bytes")
            if isinstance(image, Image.Image):
                pil = image.convert("RGB")
            else:
                # ``with`` releases the decoder's file handle once converted.
                with Image.open(io.BytesIO(image)) as im:
                    pil = im.convert("RGB")
            return self._transform(pil)
        except Exception:
            # Deliberate best-effort: corrupt/truncated/missing images are
            # skipped by the caller instead of aborting the whole stream.
            return None

    def _encode_caption(self, caption):
        """Return ``[BOS, *utf8_bytes(caption), EOS]`` as a 1-D long tensor.

        A ``None`` caption (missing or null text column) is treated as "".
        """
        raw = (caption or "").encode("utf-8")
        return torch.tensor([self._bos, *raw, self._eos], dtype=torch.long)

    def _collate(self, batch):
        """Stack images and right-pad token sequences with PAD to equal length."""
        imgs = torch.stack([img for img, _ in batch])
        texts = torch.nn.utils.rnn.pad_sequence(
            [toks for _, toks in batch], batch_first=True, padding_value=self._pad
        )
        return imgs, texts

    def batches(self, drop_last: bool = True):
        """Yield ``(imgs, texts)`` batches of ``cfg.batch_size`` samples.

        Samples whose image fails to decode are skipped.

        Args:
            drop_last: If True (default) a final batch smaller than
                ``cfg.batch_size`` is discarded; if False it is yielded too.

        Yields:
            imgs: float tensor of shape ``(B, 3, S, S)``.
            texts: long tensor of shape ``(B, T)``; T is the longest encoded
                caption (including BOS/EOS) in the batch, shorter rows padded.

        Raises:
            ValueError: if ``cfg.batch_size`` < 1.
        """
        bs = self.cfg.batch_size
        if bs < 1:
            raise ValueError(f"batch_size must be >= 1, got {bs}")
        self._lazy_init()
        buf = []
        for example in self._ds:
            img = self._decode_image(example["image"])
            if img is None:
                continue
            buf.append((img, self._encode_caption(example.get("text"))))
            if len(buf) == bs:
                yield self._collate(buf)
                buf = []
        if buf and not drop_last:
            yield self._collate(buf)

    def num_samples(self) -> int:
        """Approximate number of samples in the stream.

        Capped at ``cfg.max_samples`` when that limit is set (> 0).
        """
        limit = self.cfg.max_samples
        if limit > 0:
            return min(limit, self._NOMINAL_SIZE)
        return self._NOMINAL_SIZE