# ========================================================= # ZERO GPU PATCHED VERSION FOR HUGGING FACE SPACES # Based on original Lance app.py # ========================================================= from __future__ import annotations import concurrent.futures import gc import json import os os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "0" import random import threading import time import traceback from collections import deque from copy import deepcopy from datetime import datetime from pathlib import Path from typing import Optional import gradio as gr import spaces import torch from huggingface_hub import login from safetensors.torch import load_file from transformers import set_seed from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLVisionConfig, ) from common.utils.logging import get_logger from common.utils.misc import AutoEncoderParams, tuple_mul from config.config_factory import ( DataArguments, InferenceArguments, ModelArguments, ) from data.data_utils import add_special_tokens from data.dataset_base import DataConfig, simple_custom_collate from data.datasets_custom import ValidationDataset from inference_lance import ( PROMPT_JSON_FILENAME, apply_inference_defaults, clean_memory, init_from_model_path_if_needed, save_prompt_results, validate_on_fixed_batch, ) from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM from modeling.qwen2 import Qwen2Tokenizer from modeling.qwen2.modeling_qwen2 import Qwen2Config from modeling.vae.wan.model import WanVideoVAE from modeling.vit.qwen2_5_vl_vit import ( Qwen2_5_VisionTransformerPretrainedModel, ) # ========================================================= # HF TOKEN # ========================================================= HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN: login(token=HF_TOKEN) # ========================================================= # PERFORMANCE SETTINGS # ========================================================= torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # ========================================================= # PATHS # ========================================================= # ========================================================= # PATHS # ========================================================= REPO_ROOT = Path(__file__).resolve().parent GRADIO_TMP_ROOT = REPO_ROOT / "tmps" / "gradio_t2v_v2t" TMP_INPUT_DIR = GRADIO_TMP_ROOT / "inputs" RESULTS_ROOT = GRADIO_TMP_ROOT / "results" GLOBAL_RECORDS_FILE = GRADIO_TMP_ROOT / "generation_records.jsonl" RUN_RECORD_FILENAME = "generation_record.json" # ========================================================= # MODEL DOWNLOAD # ========================================================= from huggingface_hub import snapshot_download MODEL_REPO = "bytedance-research/Lance" MODEL_CACHE_DIR = REPO_ROOT / "downloads" snapshot_download( repo_id=MODEL_REPO, local_dir=str(MODEL_CACHE_DIR), local_dir_use_symlinks=False, token=HF_TOKEN, resume_download=True, ) DEFAULT_MODEL_PATH = str( MODEL_CACHE_DIR / "Lance_3B_Video" ) print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH) print("FILES =", os.listdir(DEFAULT_MODEL_PATH)) DEFAULT_VIT_TYPE = "qwen_2_5_vl_original" DEFAULT_TASK = "t2v" DEFAULT_TIMESTEPS = 30 DEFAULT_TIMESTEP_SHIFT = 3.5 DEFAULT_CFG_TEXT_SCALE = 4.0 DEFAULT_RESOLUTION = "video_480p" DEFAULT_BASIC_SEED = -1 DEFAULT_HEIGHT = 480 DEFAULT_WIDTH = 848 DEFAULT_NUM_FRAMES = 50 DEFAULT_QUEUE_SIZE = 4 USE_KVCACHE = True TEXT_TEMPLATE = True RECORD_WRITE_LOCK = threading.Lock() TASK_T2V = "t2v" TASK_V2T = "v2t" TASK_X2T_VIDEO = "x2t_video" TASK_CHOICES = [TASK_T2V, TASK_V2T] VIDEO_RESOLUTION_CHOICES = [ "video_192p", "video_360p", "video_480p", ] V2T_SYSTEM_PROMPT = ( "Watch the video carefully and answer the question." ) # ========================================================= # HELPERS # ========================================================= def ensure_dirs(): TMP_INPUT_DIR.mkdir(parents=True, exist_ok=True) RESULTS_ROOT.mkdir(parents=True, exist_ok=True) def normalize_seed(seed: int): if seed == -1: return random.randint(0, 2**31 - 1) return seed def normalize_task(task: str): task = (task or DEFAULT_TASK).strip().lower() if task == TASK_V2T: return TASK_X2T_VIDEO return task # ========================================================= # PIPELINE # ========================================================= class LanceT2VV2TPipeline: def __init__(self): self.initialized = False self.logger = get_logger("lance_zerogpu") self.model = None self.vae_model = None self.vae_config = None self.tokenizer = None self.new_token_ids = None self.image_token_id = None self.base_model_args = None self.base_data_args = None self.base_inference_args = None self.lock = threading.Lock() def initialize(self): with self.lock: if self.initialized: return ensure_dirs() if not torch.cuda.is_available(): raise RuntimeError("CUDA unavailable") model_args = ModelArguments( model_path=str(DEFAULT_MODEL_PATH), vit_type=DEFAULT_VIT_TYPE, llm_qk_norm=True, llm_qk_norm_und=True, llm_qk_norm_gen=True, tie_word_embeddings=False, max_num_frames=121, max_latent_size=64, latent_patch_size=[1,1,1], ) data_args = DataArguments() inference_args = InferenceArguments( validation_num_timesteps=DEFAULT_TIMESTEPS, validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT, copy_init_moe=True, visual_und=True, visual_gen=True, vae_model_type="wan", apply_qwen_2_5_vl_pos_emb=True, apply_chat_template=False, cfg_type=0, validation_data_seed=42, video_height=DEFAULT_HEIGHT, video_width=DEFAULT_WIDTH, num_frames=DEFAULT_NUM_FRAMES, task=DEFAULT_TASK, save_path_gen=str(RESULTS_ROOT), resolution=DEFAULT_RESOLUTION, text_template=TEXT_TEMPLATE, use_KVcache=USE_KVCACHE, ) apply_inference_defaults( model_args, data_args, inference_args, ) set_seed(inference_args.global_seed) llm_config = Qwen2Config.from_json_file( str(Path(model_args.model_path) / "llm_config.json") ) language_model = Qwen2ForCausalLM(llm_config) vit_config = Qwen2_5_VLVisionConfig.from_pretrained( model_args.vit_path ) vit_model = Qwen2_5_VisionTransformerPretrainedModel( vit_config ) vit_weights = load_file( str(Path(model_args.vit_path) / "vit.safetensors") ) vit_model.load_state_dict(vit_weights, strict=True) clean_memory(vit_weights) vae_model = WanVideoVAE() vae_config = deepcopy(vae_model.vae_config) config = LanceConfig( visual_gen=True, visual_und=True, llm_config=llm_config, vit_config=vit_config, vae_config=vae_config, latent_patch_size=model_args.latent_patch_size, max_num_frames=model_args.max_num_frames, max_latent_size=model_args.max_latent_size, vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side, connector_act=model_args.connector_act, interpolate_pos=model_args.interpolate_pos, timestep_shift=inference_args.timestep_shift, ) model = Lance( language_model=language_model, vit_model=vit_model, vit_type=model_args.vit_type, config=config, training_args=inference_args, ) model = model.half().to("cuda") tokenizer = Qwen2Tokenizer.from_pretrained( model_args.model_path ) tokenizer, new_token_ids, num_new_tokens = ( add_special_tokens(tokenizer) ) init_from_model_path_if_needed( model, model_args, ) image_token_id = model.language_model.config.video_token_id model.eval() self.model = model self.vae_model = vae_model self.vae_config = vae_config self.tokenizer = tokenizer self.new_token_ids = new_token_ids self.image_token_id = image_token_id self.base_model_args = model_args self.base_data_args = data_args self.base_inference_args = inference_args self.initialized = True print("Lance initialized successfully") def generate( self, task, prompt, input_video, question, height, width, num_frames, seed, resolution, validation_num_timesteps, validation_timestep_shift, cfg_text_scale, ): self.initialize() task = normalize_task(task) actual_seed = normalize_seed(int(seed)) try: save_dir = RESULTS_ROOT / str(time.time()) save_dir.mkdir(parents=True, exist_ok=True) inference_args = deepcopy( self.base_inference_args ) inference_args.video_height = int(height) inference_args.video_width = int(width) inference_args.num_frames = int(num_frames) inference_args.validation_num_timesteps = ( validation_num_timesteps ) inference_args.validation_timestep_shift = ( validation_timestep_shift ) inference_args.task = task prompt_file = TMP_INPUT_DIR / "prompt.json" if task == TASK_T2V: payload = { "000000.mp4": prompt } else: payload = { "000000": { "interleave_array": [ input_video, [ V2T_SYSTEM_PROMPT, question, "" ] ], "element_dtype_array": [ "video", "text" ], "istarget_in_interleave": [ 0, 1 ], } } with open(prompt_file, "w") as f: json.dump(payload, f) dataset_config = DataConfig.from_yaml( str(prompt_file) ) val_dataset = ValidationDataset( jsonl_path=str(prompt_file), tokenizer=self.tokenizer, data_args=self.base_data_args, model_args=self.base_model_args, training_args=inference_args, new_token_ids=self.new_token_ids, dataset_config=dataset_config, local_rank=0, world_size=1, ) val_data_cpu = simple_custom_collate( [val_dataset[0]] ) validate_on_fixed_batch( fsdp_model=self.model, vae_model=self.vae_model, tokenizer=self.tokenizer, val_data_cpu=val_data_cpu, training_args=inference_args, model_args=self.base_model_args, inference_args=inference_args, new_token_ids=self.new_token_ids, image_token_id=self.image_token_id, device="cuda", save_source_video=False, save_path_gen=str(save_dir), save_path_gt="", ) clean_memory() gc.collect() torch.cuda.empty_cache() videos = list(save_dir.glob("*.mp4")) if task == TASK_T2V: if len(videos) == 0: return None, "", "No video generated", "" return ( str(videos[0]), "", "Generation complete", "Success", ) return ( None, "Understanding complete", "Success", "", ) except Exception: err = traceback.format_exc() print(err) return None, "", "Generation failed", err # ========================================================= # GLOBAL PIPELINE # ========================================================= PIPELINE = LanceT2VV2TPipeline() # ========================================================= # SPACES GPU FUNCTION # ========================================================= @spaces.GPU(duration=300) def run_task( task, prompt, input_video, question, height, width, num_frames, seed, resolution, validation_num_timesteps, validation_timestep_shift, cfg_text_scale, ): return PIPELINE.generate( task=task, prompt=prompt, input_video=input_video, question=question, height=height, width=width, num_frames=num_frames, seed=seed, resolution=resolution, validation_num_timesteps=validation_num_timesteps, validation_timestep_shift=validation_timestep_shift, cfg_text_scale=cfg_text_scale, ) # ========================================================= # UI # ========================================================= with gr.Blocks(title="Lance ZeroGPU") as demo: gr.Markdown("# Lance T2V/V2T ZeroGPU") with gr.Row(): with gr.Column(): task = gr.Dropdown( label="Task", choices=TASK_CHOICES, value=DEFAULT_TASK, ) prompt = gr.Textbox( label="Prompt", lines=6, ) input_video = gr.Video( label="Input Video", ) question = gr.Textbox( label="Question", lines=3, ) height = gr.Slider( minimum=192, maximum=1024, step=16, value=DEFAULT_HEIGHT, label="Height", ) width = gr.Slider( minimum=192, maximum=1024, step=16, value=DEFAULT_WIDTH, label="Width", ) num_frames = gr.Slider( minimum=1, maximum=121, step=1, value=DEFAULT_NUM_FRAMES, label="Frames", ) seed = gr.Number( label="Seed", value=-1, precision=0, ) resolution = gr.Dropdown( label="Resolution", choices=VIDEO_RESOLUTION_CHOICES, value=DEFAULT_RESOLUTION, ) validation_num_timesteps = gr.Slider( minimum=1, maximum=100, step=1, value=DEFAULT_TIMESTEPS, label="Timesteps", ) validation_timestep_shift = gr.Number( label="Timestep Shift", value=DEFAULT_TIMESTEP_SHIFT, ) cfg_text_scale = gr.Number( label="CFG Text Scale", value=DEFAULT_CFG_TEXT_SCALE, ) run_button = gr.Button( "Run", variant="primary", ) with gr.Column(): output_video = gr.Video( label="Generated Video" ) output_text = gr.Textbox( label="Text Output", lines=8, ) status = gr.Markdown() logs = gr.Textbox( label="Logs", lines=20, ) run_button.click( fn=run_task, inputs=[ task, prompt, input_video, question, height, width, num_frames, seed, resolution, validation_num_timesteps, validation_timestep_shift, cfg_text_scale, ], outputs=[ output_video, output_text, status, logs, ], ) # ========================================================= # LAUNCH # ========================================================= demo.queue( max_size=4, default_concurrency_limit=1, ).launch( server_name="0.0.0.0", server_port=7860, )