| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import glob |
| import json |
| import os |
| import random |
|
|
| import torch |
| import torchvision |
| from einops import rearrange |
| from huggingface_hub import snapshot_download |
| from nemo.collections.diffusion.models.model import DiT7BConfig |
| from tqdm import tqdm |
| from transformers import T5EncoderModel, T5TokenizerFast |
|
|
| from cosmos1.utils import log |
|
|
|
|
def get_parser():
    """Build the command-line parser for the video-dataset caching script.

    Returns:
        argparse.ArgumentParser: parser with the tokenizer directory, dataset path,
        output path, prompt, chunk count and resize height/width options.
    """
    # (flag, type, default, help) -- registered in this exact order.
    option_specs = (
        ("--tokenizer_dir", str, "", "Path to the VAE model"),
        ("--dataset_path", str, "video_dataset", "Path to the dataset (a folder of videos)"),
        ("--output_path", str, "video_dataset_cached", "Path to the output directory"),
        ("--prompt", str, "a video of sks.", "Prompt for the video"),
        ("--num_chunks", int, 5, "Number of random chunks to sample per video"),
        ("--height", int, 704, "Height to resize video"),
        ("--width", int, 1280, "Width to resize video"),
    )
    cli = argparse.ArgumentParser(description="Process some configurations.")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli
|
|
|
|
def init_t5():
    """Load the T5 tokenizer and text encoder from the ``google-t5/t5-11b`` checkpoint.

    The encoder is moved to CUDA and switched to eval mode before being returned.

    Returns:
        tuple: ``(tokenizer, text_encoder)``.
    """
    checkpoint = "google-t5/t5-11b"
    t5_tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
    t5_encoder = T5EncoderModel.from_pretrained(checkpoint)
    t5_encoder.to("cuda")
    t5_encoder.eval()
    return t5_tokenizer, t5_encoder
|
|
|
|
def init_video_tokenizer(tokenizer_dir: str):
    """Build the Cosmos video tokenizer (VAE) whose weights live in ``tokenizer_dir``.

    Parameters:
        tokenizer_dir: directory passed as ``vae_path`` to ``DiT7BConfig``.

    Returns:
        The object produced by ``DiT7BConfig.configure_vae()``.
    """
    return DiT7BConfig(vae_path=tokenizer_dir).configure_vae()
|
|
|
|
@torch.no_grad()
def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512):
    """
    Encode a batch of text prompts to a batch of T5 embeddings, with padding zeroed.

    Parameters:
        tokenizer: T5 embedding tokenizer (must provide ``batch_encode_plus``).
        encoder: T5 embedding text encoder. Inputs are placed on ``encoder.device``
            when the encoder exposes one, otherwise on ``"cuda"``.
        prompts: A batch of text prompts.
        max_length: Sequence length of text embedding (defaults to 512). Every prompt
            is truncated or padded to exactly this many tokens.

    Returns:
        torch.Tensor: ``encoder(...).last_hidden_state`` with the embedding of every
        padding token (attention mask == 0) set to zero.
    """
    batch_encoding = tokenizer.batch_encode_plus(
        prompts,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_length=True,
        return_offsets_mapping=False,
    )

    # Follow the encoder's device instead of assuming the default CUDA device;
    # fall back to "cuda" for encoders that do not expose a ``device`` attribute.
    device = getattr(encoder, "device", "cuda")
    input_ids = batch_encoding.input_ids.to(device)
    attn_mask = batch_encoding.attention_mask.to(device)

    outputs = encoder(input_ids=input_ids, attention_mask=attn_mask)
    encoded_text = outputs.last_hidden_state

    # Zero the padding positions using the attention mask itself. Unlike zeroing
    # everything past sum(mask), this is correct for left- as well as right-padded
    # batches and needs no Python loop over the batch.
    is_padding = (attn_mask == 0).unsqueeze(-1)
    return encoded_text.masked_fill(is_padding, 0)
|
|
|
|
def _encode_prompt(tokenizer, text_encoder, prompt, max_length):
    """Encode ``prompt`` once into the T5 embedding and mask tensors saved for every clip.

    Returns:
        tuple: ``(embed, mask)``. ``embed`` is a CPU bfloat16 tensor of shape
        ``(max_length, dim)``, zero-filled past the encoded length. ``mask`` is an
        all-ones bfloat16 tensor of length ``max_length``.
    """
    encoded = encode_for_batch(tokenizer, text_encoder, [prompt], max_length=max_length)[0]
    encoded = encoded.to(device="cpu", dtype=torch.bfloat16)
    length, dim = encoded.shape
    embed = torch.zeros(max_length, dim, dtype=torch.bfloat16)
    embed[:length] = encoded
    # NOTE(review): the mask is all ones even over padded rows; this is unchanged from
    # the original script -- confirm it is what the training data loader expects.
    mask = torch.ones(max_length, dtype=torch.bfloat16)
    return embed, mask


def _encode_clip(vae, clip, height, width):
    """Resize a ``(T, H, W, C)`` clip, scale it to [-1, 1] and return its VAE latent on CPU.

    Pixel values are assumed to lie in [0, 255] (hence the ``/ 127.5 - 1`` scaling).
    """
    clip = rearrange(clip, "t h w c -> t c h w")
    clip = torchvision.transforms.functional.resize(clip, [height, width])
    # Add a batch dimension and move time after channels: (1, C, T, H, W).
    clip = rearrange(clip, "(b t) c h w -> b c t h w", b=1)
    clip = clip.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) / 127.5 - 1.0
    return vae.encode(clip).cpu()


def _save_sample(output_path, index, latent, t5_embed, t5_mask, info):
    """Write one cached sample as ``{index}.*`` files inside ``output_path``."""
    torch.save(latent[0], os.path.join(output_path, f"{index}.video_latent.pth"))
    torch.save(t5_embed, os.path.join(output_path, f"{index}.t5_text_embeddings.pth"))
    torch.save(t5_mask, os.path.join(output_path, f"{index}.t5_text_mask.pth"))
    with open(os.path.join(output_path, f"{index}.info.json"), "w") as json_file:
        json.dump(info, json_file)


def main(args):
    """Cache VAE latents and T5 prompt embeddings for random clips of every video.

    Every ``*.mp4`` directly inside ``args.dataset_path`` is processed in sorted order.
    For each one, ``args.num_chunks`` clips of ``vae.video_vae.pixel_chunk_duration``
    frames are taken at random start frames. Each clip is resized to
    ``args.height`` x ``args.width`` and encoded with the Cosmos video tokenizer.
    Videos shorter than one clip are logged and skipped.

    For sample number ``i`` the following files are written to ``args.output_path``:

    - ``{i}.video_latent.pth``: the first (only) batch element of the VAE latent.
    - ``{i}.t5_text_embeddings.pth``: the T5 embedding of ``args.prompt``.
    - ``{i}.t5_text_mask.pth``: an all-ones bfloat16 mask.
    - ``{i}.info.json``: source video size/fps, clip length, file name and start frame.

    Args:
        args: namespace produced by ``get_parser()``. If ``args.tokenizer_dir`` is
            empty, it is overwritten with the downloaded tokenizer snapshot path.

    Raises:
        ValueError: if ``args.dataset_path`` contains no ``.mp4`` files.
    """
    os.makedirs(args.output_path, exist_ok=True)

    # Validate the dataset before loading multi-GB models so a bad path fails fast.
    # Sorting makes sample numbering reproducible (glob order is arbitrary).
    files = sorted(glob.glob(os.path.join(args.dataset_path, "*.mp4")))
    if not files:
        raise ValueError(f"Dataset path {args.dataset_path} does not contain any .mp4 files.")

    tokenizer, text_encoder = init_t5()

    if args.tokenizer_dir == "":
        args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    vae = init_video_tokenizer(args.tokenizer_dir)

    t5_embedding_max_length = 512
    chunk_duration = vae.video_vae.pixel_chunk_duration
    cnt = 0

    with torch.no_grad():
        # The prompt is identical for every clip, so run the T5 encoder only once.
        t5_embed, t5_mask = _encode_prompt(tokenizer, text_encoder, args.prompt, t5_embedding_max_length)

        for video_path in tqdm(files):
            # pts_unit="sec" reads the same full video without the 'pts' deprecation warning.
            video, _, meta = torchvision.io.read_video(video_path, pts_unit="sec")
            T, H, W, _ = video.shape

            if T < chunk_duration:
                log.info(f"Video {video_path} is shorter than {chunk_duration} frames. Skipped.")
                continue

            for _ in range(args.num_chunks):
                start_idx = random.randint(0, T - chunk_duration)
                clip = video[start_idx : start_idx + chunk_duration]
                latent = _encode_clip(vae, clip, args.height, args.width)

                # NOTE(review): height/width record the source resolution, not the
                # args.height/args.width the latent was encoded at.
                info = {
                    "height": H,
                    "width": W,
                    "fps": meta["video_fps"],
                    "num_frames": chunk_duration,
                    "video_path": os.path.basename(video_path),
                    "start_frame": start_idx,
                }
                _save_sample(args.output_path, cnt, latent, t5_embed, t5_mask, info)
                cnt += 1
|
|
|
|
if __name__ == "__main__":
    # Parse the command line and run the caching pipeline.
    cli_args = get_parser().parse_args()
    main(cli_args)
|
|