| """WebVid-10M video data preparation with pig-vae latent encoding. |
| |
| Dataset: https://huggingface.co/datasets/TempoFunk/webvid-10M |
| Storage: ~5 GB (CLIP features) or on the order of terabytes (raw video) |
| License: Research only |
| |
| Note: To stay within the 450 GB storage budget, prefer CLIP features or |
| pre-encoded latents over raw video files. |
| """ |
| import torch |
| from dataclasses import dataclass |
| import sys, os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| from arbitor.config import SPECIAL_VOCAB |
|
|
|
|
@dataclass
class WebVidConfig:
    """Settings for streaming WebVid-10M batches.

    Fields:
        batch_size: examples per yielded batch; must be >= 1.
        shuffle_buffer: buffer size passed to the streaming dataset's shuffle.
        split: dataset split name passed to ``load_dataset``.
        max_samples: if non-zero, only the first ``max_samples`` examples are
            streamed; must be >= 0 (0 means "no limit").
        n_frames: number of frames each clip is resampled to; must be >= 1.
        latent_h: height of the placeholder latents used for examples without video.
        latent_w: width of the placeholder latents used for examples without video.

    Raises:
        ValueError: if ``batch_size`` or ``n_frames`` is < 1, or
            ``max_samples`` is negative.
    """
    batch_size: int = 2
    shuffle_buffer: int = 2000
    split: str = "train"
    max_samples: int = 0
    n_frames: int = 4
    latent_h: int = 32
    latent_w: int = 32

    def __post_init__(self):
        # Fail fast here instead of deep inside the streaming loop, where a
        # non-positive batch_size would surface as max() of an empty sequence.
        for name in ("batch_size", "n_frames"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if self.max_samples < 0:
            raise ValueError(
                f"max_samples must be >= 0 (0 = no limit), got {self.max_samples}")
|
|
|
|
class WebVidStream:
    """Stream WebVid-10M examples as padded ``(text_batch, latent_batch)`` pairs.

    Captions become byte-level token ids wrapped in BOS/EOS and right-padded
    with PAD. When an example carries a ``"video"`` entry, its frames are
    resampled to ``cfg.n_frames`` and encoded with the frozen pig-vae on CUDA;
    otherwise a random placeholder latent is used.
    """

    def __init__(self, cfg: "WebVidConfig"):
        self.cfg = cfg
        self._ds = None   # streaming dataset, built on first use by _lazy_init()
        self._vae = None  # pig-vae encoder, loaded on first use by _get_vae()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._video_tok = SPECIAL_VOCAB['VIDEO']

    def _lazy_init(self):
        """Open the streaming dataset once; later calls are no-ops."""
        if self._ds is not None:
            return
        from datasets import load_dataset
        ds = load_dataset("TempoFunk/webvid-10M",
                          split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            # Restrict to the first max_samples examples, then shuffle within them.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vae(self):
        """Return the pig-vae encoder, loading it on CUDA and freezing it on first call."""
        if self._vae is None:
            from arbitor.encoders.pig_vae import load_vae
            self._vae = load_vae(device="cuda")
            self._vae.eval()
            for p in self._vae.parameters():
                p.requires_grad = False
        return self._vae

    def _caption_tokens(self, example):
        """Return ``[BOS, *utf8_bytes(caption), EOS]`` as a 1-D long tensor.

        Uses ``"text"`` if truthy, else ``"caption"``. A missing or ``None``
        caption becomes the empty string instead of raising AttributeError.
        """
        caption = example.get("text") or example.get("caption") or ""
        tokens = [self._bos, *caption.encode("utf-8"), self._eos]
        return torch.tensor(tokens, dtype=torch.long)

    def _sample_frames(self, frames):
        """Resample dim 1 (time) of ``frames`` to exactly ``cfg.n_frames`` entries.

        Evenly spaced indices are used in both directions: long clips are
        subsampled and short clips repeat frames. Every encoded clip therefore
        has the same temporal length and can be stacked into one batch. Empty
        clips are returned unchanged.
        """
        n_src = frames.shape[1]
        if n_src == 0 or n_src == self.cfg.n_frames:
            return frames
        idx = torch.linspace(0, n_src - 1, self.cfg.n_frames).long()
        return frames[:, idx]

    def _encode_video(self, video):
        """Encode one example's ``video["array"]`` into pig-vae latents, returned on CPU."""
        frames = torch.tensor(video["array"], dtype=torch.float32)
        if frames.dim() == 4:
            # NOTE(review): assumes channel-last (T, H, W, C) input, moved to
            # (C, T, H, W) — confirm against the dataset's video decoding.
            frames = frames.permute(3, 0, 1, 2)
        # NOTE(review): pixel values are passed through unscaled — confirm the
        # value range pig-vae expects.
        frames = self._sample_frames(frames).unsqueeze(0).cuda()
        vae = self._get_vae()
        with torch.no_grad():
            return vae.encode(frames).cpu()

    def _collate(self, batch):
        """Stack ``(text, latent)`` pairs, right-padding texts with PAD to a common length."""
        texts = [text for text, _ in batch]
        max_len = max(t.numel() for t in texts)
        text_batch = torch.stack([
            torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)])
            for t in texts
        ])
        latent_batch = torch.stack([latent for _, latent in batch])
        return text_batch, latent_batch

    def batches(self, drop_last: bool = True):
        """Yield ``(text_batch, latent_batch)`` tuples of ``cfg.batch_size`` examples.

        Args:
            drop_last: if True (the default, matching the previous behavior), a
                final batch smaller than ``cfg.batch_size`` is discarded. If
                False, it is yielded as well.
        """
        self._lazy_init()
        buf = []
        for example in self._ds:
            text_t = self._caption_tokens(example)
            video = example.get("video")
            if video is not None:
                latents = self._encode_video(video)
            else:
                # NOTE(review): examples without a "video" entry get random-noise
                # latents — confirm this placeholder is intended for training.
                # Its shape must also match the VAE output, or torch.stack fails
                # on a batch that mixes both kinds.
                latents = torch.randn(1, 16, self.cfg.n_frames,
                                      self.cfg.latent_h, self.cfg.latent_w)
            buf.append((text_t, latents.squeeze(0)))
            if len(buf) >= self.cfg.batch_size:
                yield self._collate(buf)
                buf = []
        if buf and not drop_last:
            yield self._collate(buf)

    def num_samples(self) -> int:
        """Return how many examples the stream draws from.

        This is ``cfg.max_samples`` when set (capped at the nominal 10M);
        otherwise it is the nominal WebVid-10M size.
        """
        total = 10_000_000
        if self.cfg.max_samples:
            return min(self.cfg.max_samples, total)
        return total
|
|