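"""WebVid-10M video data preparation with pig-vae latent encoding.

Dataset: https://huggingface.co/datasets/TempoFunk/webvid-10M
Storage: ~5 GB (CLIP features) or on the order of terabytes (raw video)
License: research use only

Note: to stay within a 450 GB storage budget, we download CLIP features or
pre-encoded latents rather than raw video files. The dataset is read in
streaming mode; examples without a decoded "video" field fall back to random
placeholder latents.
"""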
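import os
import sys
from dataclasses import dataclass

import torch

# Make `arbitor` importable when this file is run directly: the directory two
# levels up is assumed to be the repository root that contains the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB  # noqa: E402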

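# Sketch (illustrative only; `buffer_shuffle` is a hypothetical helper that
# WebVidStream does not use): a pure-Python approximation of the fixed-size
# shuffle buffer that `IterableDataset.shuffle(buffer_size=...)` applies to a
# streaming dataset. A larger `shuffle_buffer` mixes examples from further apart
# in the stream at the cost of holding more examples in memory. The real
# `datasets` implementation also shuffles shard order, which is not modeled here.
def buffer_shuffle(stream, buffer_size, seed=42):
    """Yield the items of `stream` in approximately shuffled order."""
    import random

    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buf = []
    for item in stream:
        if len(buf) < buffer_size:
            # Fill the buffer before emitting anything.
            buf.append(item)
            continue
        # Emit a random buffered item and put the incoming item in its slot.
        i = rng.randrange(buffer_size)
        yield buf[i]
        buf[i] = item
    # Stream exhausted: drain the remaining buffer in random order.
    rng.shuffle(buf)
    yield from buf
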

@dataclass
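class WebVidConfig:
    batch_size: int = 2
    shuffle_buffer: int = 2000  # examples held in the streaming shuffle buffer
    split: str = "train"
    max_samples: int = 0  # 0 means no limit
    n_frames: int = 4  # frames sampled per clip
    latent_h: int = 32  # latent size used for the random placeholder latents
    latent_w: int = 32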

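# Sketch (illustrative only; `sample_frame_indices` is a hypothetical helper,
# not part of the original pipeline): the uniform temporal subsampling done with
# `torch.linspace(0, T - 1, n_frames).long()` in `WebVidStream.batches`, written
# with exact integer arithmetic. It always returns `n_frames` indices that
# include the first and last frame; clips shorter than `n_frames` repeat frames.
# e.g. sample_frame_indices(10, 4) -> [0, 3, 6, 9]
def sample_frame_indices(num_frames, n_frames):
    """Return `n_frames` evenly spaced frame indices in [0, num_frames - 1]."""
    if num_frames < 1 or n_frames < 1:
        raise ValueError("num_frames and n_frames must be positive")
    if n_frames == 1:
        return [0]
    # floor(i * (T - 1) / (n - 1)), computed without floating-point error.
    return [i * (num_frames - 1) // (n_frames - 1) for i in range(n_frames)]
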

class WebVidStream:
    def __init__(self, cfg: WebVidConfig):
        self.cfg = cfg
        self._ds = None
        self._vae = None
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._video_tok = SPECIAL_VOCAB['VIDEO']

    def _lazy_init(self):
        if self._ds is not None:
            return
        from datasets import load_dataset
        ds = load_dataset("TempoFunk/webvid-10M",
                          split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vae(self):
        if self._vae is None:
            from arbitor.encoders.pig_vae import load_vae
            self._vae = load_vae(device="cuda")
            self._vae.eval()
            for p in self._vae.parameters():
                p.requires_grad = False
        return self._vae

    def batches(self):
        self._lazy_init()
        buf = []
        for example in self._ds:
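            # The caption field name varies; WebVid's metadata stores it in "name".
            caption = example.get("text") or example.get("caption") or example.get("name") or ""
            raw = caption.encode("utf-8")
            # Byte-level tokenization: one token per UTF-8 byte, framed by BOS/EOS.
            text_tokens = [self._bos] + list(raw) + [self._eos]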
            text_t = torch.tensor(text_tokens, dtype=torch.long)

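            video = example.get("video")
            if video is not None:
                # Assumes channels-last frames (T, H, W, C); permute to (C, T, H, W).
                frames = torch.tensor(video["array"], dtype=torch.float32)
                if frames.dim() == 4:
                    frames = frames.permute(3, 0, 1, 2)
                # Resample to exactly n_frames evenly spaced frames. Short clips
                # repeat frames, so every latent has the same temporal length.
                if frames.shape[1] != self.cfg.n_frames:
                    idx = torch.linspace(0, frames.shape[1] - 1, self.cfg.n_frames).long()
                    frames = frames[:, idx]
                # Frames are not resized spatially: clips in one batch must share
                # a resolution for torch.stack below to succeed.
                frames = frames.unsqueeze(0).cuda()  # add batch dim: (1, C, T, H, W)
                vae = self._get_vae()
                with torch.no_grad():
                    latents = vae.encode(frames).cpu()
            else:
                # Placeholder: random latents when the example carries no decoded
                # "video" field. The 16 channels are hard-coded and must match the
                # VAE's latent channel count.
                latents = torch.randn(1, 16, self.cfg.n_frames, self.cfg.latent_h, self.cfg.latent_w)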

            buf.append((text_t, latents.squeeze(0)))
            if len(buf) >= self.cfg.batch_size:
                batch = buf[:self.cfg.batch_size]
                buf = buf[self.cfg.batch_size:]
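                # Right-pad the byte-token captions with PAD to the longest one.
                texts = [b[0] for b in batch]
                max_len = max(t.numel() for t in texts)
                text_batch = torch.stack([
                    torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)])
                    for t in texts
                ])
                latent_batch = torch.stack([b[1] for b in batch])
                yield text_batch, latent_batch
        # A trailing partial batch (fewer than batch_size examples) is dropped.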

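    def num_samples(self) -> int:
        # Nominal WebVid-10M size, capped by max_samples when it is set.
        if self.cfg.max_samples:
            return min(self.cfg.max_samples, 10_000_000)
        return 10_000_000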
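

# Sketch (illustrative only; `encode_caption` and `pad_batch` are hypothetical
# helpers): the caption handling of `WebVidStream.batches` in pure Python, with
# a None caption treated as empty. Each caption becomes its raw UTF-8 bytes
# (values 0-255) framed by BOS/EOS ids, and every sequence in a batch is
# right-padded with the PAD id to the longest length. Real special ids come from
# `arbitor.config.SPECIAL_VOCAB` and are assumed to lie outside the byte range;
# the ids in any example call (e.g. 256/257/258) are placeholders.
def encode_caption(caption, bos, eos):
    """Return [bos, *utf8_bytes, eos] for a caption string."""
    return [bos] + list((caption or "").encode("utf-8")) + [eos]


def pad_batch(seqs, pad):
    """Right-pad each token list with `pad` to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad] * (max_len - len(s)) for s in seqs]
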