| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import json |
| import math |
| import os |
| from pathlib import Path |
| from typing import List |
|
|
| import numpy as np |
| import torch |
| import torchvision |
| from PIL import Image |
|
|
| from .ar_config_inference_inference_config import SamplingConfig |
| from .log import log |
|
|
| _IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"] |
| _VIDEO_EXTENSIONS = [".mp4"] |
| _SUPPORTED_CONTEXT_LEN = [1, 9] |
| NUM_TOTAL_FRAMES = 33 |
|
|
|
|
def add_common_arguments(parser):
    """Register the CLI options shared by all inference entry points.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos")
    parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        help="Input path for input image or video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Input folder containing all input images or videos",
    )
    # Only the context lengths the models were trained with are accepted.
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=9,
        help="Number of input frames for world generation",
        choices=_SUPPORTED_CONTEXT_LEN,
    )
    # Sampling hyperparameters.
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")
    # Flags that move individual model components off the GPU after they have
    # produced their outputs, trading speed for memory headroom.
    offload_options = (
        ("--offload_guardrail_models", "Offload guardrail models after inference"),
        ("--offload_diffusion_decoder", "Offload diffusion decoder after inference"),
        ("--offload_ar_model", "Offload AR model after inference"),
        ("--offload_tokenizer", "Offload discrete tokenizer model after inference"),
    )
    for flag, help_text in offload_options:
        parser.add_argument(flag, action="store_true", help=help_text)
|
|
|
|
def validate_args(args: argparse.Namespace, inference_type: str):
    """Validate command line arguments for base and video2world generation.

    Mutates ``args`` in place (may reset ``num_input_frames``; always sets
    ``data_resolution``), creates the output folder, and returns the
    SamplingConfig built from the sampling-related arguments.
    """
    assert inference_type in (
        "base",
        "video2world",
    ), "Invalid inference_type, must be 'base' or 'video2world'"

    # Image-style inputs can only contribute a single context frame.
    if args.input_type in ("image", "text_and_image") and args.num_input_frames != 1:
        args.num_input_frames = 1
        log.info(f"Set num_input_frames to 1 for {args.input_type} input")

    # Smaller checkpoints are noticeably less reliable with image input; warn once.
    if args.num_input_frames == 1:
        reliability_warnings = (
            (
                "4B",
                "The failure rate for 4B model with image input is ~15%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details.",
            ),
            (
                "5B",
                "The failure rate for 5B model with image input is ~7%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details.",
            ),
        )
        for size_tag, message in reliability_warnings:
            if size_tag in args.ar_model_dir:
                log.warning(message)
                break

    # Some visual input is mandatory; a prompt is additionally required for
    # single-sample video2world runs (batch runs carry prompts in the JSONL).
    assert (
        args.input_image_or_video_path or args.batch_input_path
    ), "--input_image_or_video_path or --batch_input_path must be provided."
    if inference_type == "video2world" and not args.batch_input_path:
        assert args.prompt, "--prompt is required for single video generation."
    args.data_resolution = [640, 1024]

    # Multi-GPU inference is not wired up yet.
    world_size = int(os.getenv("WORLD_SIZE", 1))
    assert world_size <= 1, "We support only single GPU inference for now"

    # Make sure the output destination exists before generation starts.
    Path(args.video_save_folder).mkdir(parents=True, exist_ok=True)

    return SamplingConfig(
        echo=True,
        temperature=args.temperature,
        top_p=args.top_p,
        compile_sampling=True,
    )
|
|
|
|
def resize_input(video: torch.Tensor, resolution: list[int]):
    r"""
    Aspect-ratio-preserving resize followed by a center crop to ``resolution``.

    Args:
        video (torch.Tensor): Input video tensor; dims 2 and 3 are treated as
            height and width (i.e. channel-first frames)
        resolution (list[int]): Target (height, width)
    Returns:
        Cropped video tensor at exactly ``resolution``
    """
    src_h, src_w = video.shape[2], video.shape[3]
    tgt_h, tgt_w = resolution

    # Scale up just enough that both target dimensions are covered, then crop
    # the overshoot symmetrically.
    scale = max(tgt_w / src_w, tgt_h / src_h)
    resized_hw = (int(math.ceil(scale * src_h)), int(math.ceil(scale * src_w)))
    resized = torchvision.transforms.functional.resize(video, resized_hw)
    return torchvision.transforms.functional.center_crop(resized, resolution)
|
|
|
|
def load_image_from_list(flist, data_resolution: List[int]) -> dict:
    """
    Load images as static videos, keyed by file basename.

    Each image becomes a clip of NUM_TOTAL_FRAMES identical frames, scaled to
    [-1, 1], resized/center-cropped to ``data_resolution``, and returned with
    a leading batch dim. Paths without a recognized image extension are
    silently skipped.

    Args:
        flist (List[str]): List of image paths
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input images
    """
    loaded = dict()
    for img_path in flist:
        if os.path.splitext(img_path)[1] not in _IMAGE_EXTENSIONS:
            continue

        # NOTE(review): the image is not forced to RGB, so palette/RGBA files
        # keep their native channel count — confirm against upstream inputs.
        frame = torchvision.transforms.functional.to_tensor(Image.open(img_path))
        # Repeat the single frame into a static clip, then map [0, 1] -> [-1, 1].
        clip = frame.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1)
        clip = clip * 2 - 1

        log.debug(
            f"Resizing input image of shape ({clip.shape[2]}, {clip.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
        )
        clip = resize_input(clip, data_resolution)
        # (T, C, H, W) -> (1, C, T, H, W)
        loaded[os.path.basename(img_path)] = clip.transpose(0, 1).unsqueeze(0)

    return loaded
|
|
|
|
def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict:
    """
    Read the batch of input images listed in a JSONL file.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (list[int]): Data resolution

    Returns:
        Dict containing input images
    """
    # Each JSONL record carries the image location under the "visual_input" key.
    with open(batch_input_path, "r") as f:
        image_paths = [json.loads(line.strip())["visual_input"] for line in f]

    return load_image_from_list(image_paths, data_resolution=data_resolution)
|
|
|
|
def read_input_image(input_path: str, data_resolution: List[int]) -> dict:
    """
    Read a single input image.

    Args:
        input_path (str): Path to input image
        data_resolution (List[int]): Data resolution
    Returns:
        Dict containing input image
    """
    # Delegate to the batch loader with a one-element list.
    return load_image_from_list([input_path], data_resolution=data_resolution)
|
|
|
|
def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    r"""
    Read the batch of input videos listed in a JSONL file.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (list[int]): Data resolution
        num_input_frames (int): Number of frames in context
    Returns:
        Dict containing input videos
    """
    # Each JSONL record carries the video location under the "visual_input" key.
    with open(batch_input_path, "r") as f:
        video_paths = [json.loads(line.strip())["visual_input"] for line in f]

    return load_videos_from_list(video_paths, data_resolution=data_resolution, num_input_frames=num_input_frames)
|
|
|
|
def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Read a single input video.

    Args:
        input_path (str): Path to input video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context
    Returns:
        Dict containing input video
    """
    # Delegate to the batch loader with a one-element list.
    return load_videos_from_list([input_path], data_resolution=data_resolution, num_input_frames=num_input_frames)
|
|
|
|
def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Function to load videos from a list of video paths.

    Each video is normalized to [-1, 1], truncated to its last
    ``num_input_frames`` frames, padded back to NUM_TOTAL_FRAMES by repeating
    the final frame, resized/center-cropped to ``data_resolution``, and stored
    as a (1, C, T, H, W) tensor keyed by the file's basename. Videos shorter
    than ``num_input_frames`` are skipped with a warning; paths without a
    recognized video extension are silently skipped.

    Args:
        flist (List[str]): List of video paths
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context
    Returns:
        Dict containing input videos
    """
    all_videos = dict()

    for video_path in flist:
        ext = os.path.splitext(video_path)[-1]
        if ext not in _VIDEO_EXTENSIONS:
            continue

        video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
        # Map uint8 [0, 255] to float [-1, 1].
        video = video.float() / 255.0
        video = video * 2 - 1

        nframes_in_video = video.shape[0]
        if nframes_in_video < num_input_frames:
            fname = os.path.basename(video_path)
            # Fixed typo in the warning message: "requried" -> "required".
            log.warning(
                f"Video {fname} has {nframes_in_video} frames, less than the required {num_input_frames} frames. Skipping."
            )
            continue

        # Keep only the most recent context frames...
        video = video[-num_input_frames:, :, :, :]

        # ...and pad back to NUM_TOTAL_FRAMES by repeating the last frame.
        video = torch.cat(
            (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)),
            dim=0,
        )

        # (T, H, W, C) from read_video -> (T, C, H, W) for the resize transforms.
        video = video.permute(0, 3, 1, 2)

        log.debug(
            f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
        )
        video = resize_input(video, data_resolution)

        # (T, C, H, W) -> (1, C, T, H, W)
        fname = os.path.basename(video_path)
        all_videos[fname] = video.transpose(0, 1).unsqueeze(0)

    return all_videos
|
|
|
|
def load_vision_input(
    input_type: str,
    batch_input_path: str,
    input_image_or_video_path: str,
    data_resolution: List[int],
    num_input_frames: int,
):
    """
    Load the visual conditioning input(s) for generation.

    Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model.
    Args:
        input_type (str): Type of input
        batch_input_path (str): Folder containing input images or videos
        input_image_or_video_path (str): Path to input image or video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context
    Returns:
        Dict containing input videos
    """
    # A non-empty batch path takes precedence over the single-input path.
    if batch_input_path:
        log.info(f"Reading batch inputs from path: {batch_input_path}")
        if input_type in ("image", "text_and_image"):
            return read_input_images(batch_input_path, data_resolution=data_resolution)
        if input_type in ("video", "text_and_video"):
            return read_input_videos(
                batch_input_path,
                data_resolution=data_resolution,
                num_input_frames=num_input_frames,
            )
        raise ValueError(f"Invalid input type {input_type}")

    if input_type in ("image", "text_and_image"):
        return read_input_image(input_image_or_video_path, data_resolution=data_resolution)
    if input_type in ("video", "text_and_video"):
        return read_input_video(
            input_image_or_video_path,
            data_resolution=data_resolution,
            num_input_frames=num_input_frames,
        )
    raise ValueError(f"Invalid input type {input_type}")
|
|
|
|
def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]:
    """
    Convert output tensors into uint8 numpy arrays for saving.

    Values are scaled by 255 and truncated to uint8; dim 0 is moved to the
    last position (channel-first -> channel-last frames — presumably
    (C, T, H, W) float tensors in [0, 1]; confirm against the decoder output).

    Args:
        video_batch (List[torch.Tensor]): List of output tensors
    Returns:
        List of numpy arrays
    """
    converted = []
    for video in video_batch:
        frames = (video * 255).to(torch.uint8)
        converted.append(frames.permute(1, 2, 3, 0).cpu().numpy())
    return converted
|
|