"""WebVid-10M video data preparation with pig-vae latent encoding.

Dataset: https://huggingface.co/datasets/TempoFunk/webvid-10M
Storage: ~5 GB (CLIP features) or ~TB (raw video)
License: Research only

Note: For 450GB constraint, we download CLIP features or pre-encoded latents
rather than raw video files.
"""
import os
import sys
from dataclasses import dataclass
from typing import Iterator, List, Tuple

import torch

# Put the directory two levels above this file on sys.path before importing
# `arbitor` (presumably so the package resolves when this file is run
# directly -- confirm against the repository layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB  # noqa: E402

_DATASET_ID = "TempoFunk/webvid-10M"
# Nominal size of the full dataset, as reported by num_samples().
_DATASET_NUM_SAMPLES = 10_000_000


@dataclass
class WebVidConfig:
    """Settings consumed by :class:`WebVidStream`."""

    batch_size: int = 2          # examples per yielded batch
    shuffle_buffer: int = 2000   # buffer_size handed to the streaming shuffle
    split: str = "train"         # split name handed to load_dataset
    max_samples: int = 0         # cap on streamed examples; 0 disables the cap
    n_frames: int = 4            # frames kept per clip before VAE encoding
    latent_h: int = 32           # height of the placeholder latent
    latent_w: int = 32           # width of the placeholder latent
    latent_channels: int = 16    # channels of the placeholder latent
    device: str = "cuda"         # device used for the VAE and its input frames


class WebVidStream:
    """Stream (caption tokens, video latents) batches from WebVid-10M.

    The dataset and the VAE are both created lazily, on first use, so that
    constructing a stream performs no I/O and touches no accelerator.
    """

    def __init__(self, cfg: WebVidConfig):
        self.cfg = cfg
        self._ds = None
        self._vae = None
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._video_tok = SPECIAL_VOCAB['VIDEO']

    def _lazy_init(self) -> None:
        """Open the streaming dataset once; later calls are no-ops."""
        if self._ds is not None:
            return
        from datasets import load_dataset

        ds = load_dataset(_DATASET_ID, split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            # The cap is applied before shuffling, so the same leading
            # examples are selected and then shuffled among themselves.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vae(self):
        """Return the frozen VAE in eval mode, loading it on first call."""
        if self._vae is None:
            from arbitor.encoders.pig_vae import load_vae

            self._vae = load_vae(device=self.cfg.device)
            self._vae.eval()
            for p in self._vae.parameters():
                p.requires_grad = False
        return self._vae

    def _encode_caption(self, example) -> torch.Tensor:
        """Turn an example's caption into BOS + UTF-8 bytes + EOS token ids."""
        # `or ""` also covers a key that is present but holds None, which
        # would otherwise fail on `.encode`.
        caption = example.get("text") or example.get("caption") or ""
        raw = caption.encode("utf-8")
        return torch.tensor(
            [self._bos] + list(raw) + [self._eos], dtype=torch.long
        )

    def _encode_video(self, video) -> torch.Tensor:
        """Encode one clip with the VAE and return the latents on the CPU.

        The returned tensor still carries the leading batch axis of size 1
        that was added before encoding.
        """
        frames = torch.tensor(video["array"], dtype=torch.float32)
        if frames.dim() == 4:
            # Moves the last axis to the front; assumes the array is laid
            # out as (T, H, W, C) so the result is (C, T, H, W) -- TODO confirm.
            frames = frames.permute(3, 0, 1, 2)
        n_total = frames.shape[1]
        if n_total > 0 and n_total != self.cfg.n_frames:
            # Resample to exactly n_frames evenly spaced frames. For clips
            # shorter than n_frames this repeats frames, which keeps every
            # clip the same length so latents can be stacked into a batch.
            idx = torch.linspace(0, n_total - 1, self.cfg.n_frames).long()
            frames = frames[:, idx]
        frames = frames.unsqueeze(0).to(self.cfg.device)
        vae = self._get_vae()
        with torch.no_grad():
            return vae.encode(frames).cpu()

    def _collate(
        self, items: List[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Right-pad captions with PAD to a common length and stack both parts."""
        texts = [text for text, _ in items]
        max_len = max(t.numel() for t in texts)
        text_batch = torch.stack([
            torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)])
            for t in texts
        ])
        latent_batch = torch.stack([latent for _, latent in items])
        return text_batch, latent_batch

    def batches(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield ``(text_batch, latent_batch)`` pairs of ``batch_size`` examples.

        ``text_batch`` is a long tensor of padded caption tokens and
        ``latent_batch`` stacks the per-example latents. Examples left over
        when the stream ends (fewer than ``batch_size``) are not yielded.
        """
        self._lazy_init()
        pending = []
        for example in self._ds:
            text_t = self._encode_caption(example)
            video = example.get("video")
            if video is not None:
                latents = self._encode_video(video)
            else:
                # NOTE(review): examples without a "video" entry get
                # random-noise latents as a placeholder -- confirm this is
                # intended for the run using this stream.
                latents = torch.randn(
                    1,
                    self.cfg.latent_channels,
                    self.cfg.n_frames,
                    self.cfg.latent_h,
                    self.cfg.latent_w,
                )
            pending.append((text_t, latents.squeeze(0)))
            if len(pending) >= self.cfg.batch_size:
                yield self._collate(pending)
                pending = []

    def num_samples(self) -> int:
        """Return the number of examples the stream draws from.

        This is the nominal dataset size, reduced to ``max_samples`` when
        that cap is set, since the stream is truncated to that many examples.
        """
        cap = self.cfg.max_samples
        if cap > 0:
            return min(cap, _DATASET_NUM_SAMPLES)
        return _DATASET_NUM_SAMPLES