# Source: ARBS/training/data/prepare_webvid.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""WebVid-10M video data preparation with pig-vae latent encoding.

Dataset: https://huggingface.co/datasets/TempoFunk/webvid-10M
Storage: ~5 GB (CLIP features) or on the order of terabytes (raw video)
License: research use only

Note: to stay within the 450 GB storage budget, download CLIP features or
pre-encoded latents rather than raw video files.
"""
import torch
from dataclasses import dataclass
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB
@dataclass
class WebVidConfig:
    """Configuration for the WebVid-10M streaming loader (``WebVidStream``).

    Field order defines the positional constructor signature, so keep it
    stable when adding new fields (append them at the end).
    """

    # --- streaming / batching -------------------------------------------
    batch_size: int = 2          # examples per yielded batch
    shuffle_buffer: int = 2_000  # buffer size for the streaming shuffle
    split: str = "train"         # dataset split name passed to load_dataset
    max_samples: int = 0         # cap on streamed examples; 0 disables the cap

    # --- latent geometry ------------------------------------------------
    n_frames: int = 4            # frames sampled per clip
    latent_h: int = 32           # height of the placeholder latents
    latent_w: int = 32           # width of the placeholder latents
class WebVidStream:
    """Stream WebVid-10M captions paired with pig-vae video latents.

    Captions are tokenized as raw UTF-8 bytes wrapped in BOS/EOS tokens.
    When an example has a usable video, its frames are resampled to
    ``cfg.n_frames`` and encoded with the pig-vae encoder on CUDA. Otherwise,
    random Gaussian latents of shape ``(16, n_frames, latent_h, latent_w)``
    are used as a placeholder.
    """

    def __init__(self, cfg: "WebVidConfig"):
        self.cfg = cfg
        self._ds = None   # streaming dataset, opened lazily by _lazy_init()
        self._vae = None  # frozen VAE, loaded lazily by _get_vae()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']
        self._video_tok = SPECIAL_VOCAB['VIDEO']  # not used inside this class

    def _lazy_init(self):
        """Open the streaming dataset on first use; later calls are no-ops."""
        if self._ds is not None:
            return
        from datasets import load_dataset
        ds = load_dataset("TempoFunk/webvid-10M",
                          split=self.cfg.split, streaming=True)
        if self.cfg.max_samples:
            # take() runs before shuffle(), so the capped subset is a fixed
            # prefix of the stream that is then shuffled.
            ds = ds.take(self.cfg.max_samples)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vae(self):
        """Return the pig-vae model, loading it on CUDA and freezing it on first call."""
        if self._vae is None:
            from arbitor.encoders.pig_vae import load_vae
            self._vae = load_vae(device="cuda")
            self._vae.eval()
            for p in self._vae.parameters():
                p.requires_grad = False
        return self._vae

    @staticmethod
    def _resample_frames(frames, n_frames):
        """Select ``n_frames`` evenly spaced frames along dim 1.

        Clips longer than ``n_frames`` are subsampled. Shorter clips repeat
        frames, so every sample ends up with the same time dimension and can
        be stacked into a batch. Indices are rounded to the nearest frame.
        """
        total = frames.shape[1]
        if total == n_frames:
            return frames
        idx = torch.linspace(0, total - 1, n_frames).round().long()
        return frames[:, idx]

    def _encode_video(self, video):
        """Encode one example's video to latents; return None if it has no frames."""
        frames = torch.tensor(video["array"], dtype=torch.float32)
        if frames.dim() == 4:
            # NOTE(review): assumes the array is (T, H, W, C), permuted here to
            # (C, T, H, W) -- confirm against the dataset schema.
            frames = frames.permute(3, 0, 1, 2)
        if frames.dim() < 2 or frames.shape[1] == 0:
            return None  # nothing to encode; the caller uses placeholder latents
        frames = self._resample_frames(frames, self.cfg.n_frames)
        # NOTE(review): pixel values are passed to the VAE without
        # normalization -- confirm the input range pig-vae expects.
        frames = frames.unsqueeze(0).cuda()
        vae = self._get_vae()
        with torch.no_grad():
            return vae.encode(frames).cpu()

    def _collate(self, batch):
        """Right-pad captions with PAD and stack captions and latents into tensors."""
        texts = [text for text, _ in batch]
        max_len = max(t.numel() for t in texts)
        text_batch = torch.stack([
            torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)])
            for t in texts
        ])
        latent_batch = torch.stack([latents for _, latents in batch])
        return text_batch, latent_batch

    def batches(self):
        """Yield ``(text_batch, latent_batch)`` tuples of ``cfg.batch_size`` examples.

        ``text_batch`` is a long tensor of shape ``(batch_size, max_len)``,
        right-padded with PAD. ``latent_batch`` stacks the per-example
        latents along a new leading dim. If the stream ends with fewer than
        ``batch_size`` examples buffered, that partial batch is dropped.
        """
        self._lazy_init()
        bs = self.cfg.batch_size
        buf = []
        for example in self._ds:
            # A present-but-None caption must not reach .encode().
            caption = example.get("text") or example.get("caption") or ""
            text_t = torch.tensor(
                [self._bos, *caption.encode("utf-8"), self._eos],
                dtype=torch.long,
            )

            latents = None
            video = example.get("video")
            if video is not None:
                latents = self._encode_video(video)
            if latents is None:
                # Placeholder latents; the 16-channel layout is hard-coded here.
                latents = torch.randn(1, 16, self.cfg.n_frames,
                                      self.cfg.latent_h, self.cfg.latent_w)
            buf.append((text_t, latents.squeeze(0)))

            if len(buf) >= bs:
                batch = buf[:bs]
                buf = buf[bs:]
                yield self._collate(batch)

    def num_samples(self) -> int:
        """Nominal dataset size (WebVid-10M); not affected by ``max_samples``."""
        return 10_000_000