| # ========================================================= |
| # ZERO GPU PATCHED VERSION FOR HUGGING FACE SPACES |
| # Based on original Lance app.py |
| # ========================================================= |
|
|
| from __future__ import annotations |
|
|
| import concurrent.futures |
| import gc |
| import json |
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
| import random |
| import threading |
| import time |
| import traceback |
|
|
| from collections import deque |
| from copy import deepcopy |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Optional |
|
|
| import gradio as gr |
| import spaces |
| import torch |
|
|
| from huggingface_hub import login |
| from safetensors.torch import load_file |
| from transformers import set_seed |
| from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( |
| Qwen2_5_VLVisionConfig, |
| ) |
|
|
| from common.utils.logging import get_logger |
| from common.utils.misc import AutoEncoderParams, tuple_mul |
|
|
| from config.config_factory import ( |
| DataArguments, |
| InferenceArguments, |
| ModelArguments, |
| ) |
|
|
| from data.data_utils import add_special_tokens |
| from data.dataset_base import DataConfig, simple_custom_collate |
| from data.datasets_custom import ValidationDataset |
|
|
| from inference_lance import ( |
| PROMPT_JSON_FILENAME, |
| apply_inference_defaults, |
| clean_memory, |
| init_from_model_path_if_needed, |
| save_prompt_results, |
| validate_on_fixed_batch, |
| ) |
|
|
| from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM |
| from modeling.qwen2 import Qwen2Tokenizer |
| from modeling.qwen2.modeling_qwen2 import Qwen2Config |
| from modeling.vae.wan.model import WanVideoVAE |
| from modeling.vit.qwen2_5_vl_vit import ( |
| Qwen2_5_VisionTransformerPretrainedModel, |
| ) |
|
|
| # ========================================================= |
| # HF TOKEN |
| # ========================================================= |
|
|
# Optional Hugging Face access token, read from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Authenticate only when a token is actually configured; otherwise the
# hub is accessed anonymously.
if HF_TOKEN:
    login(token=HF_TOKEN)
|
|
| # ========================================================= |
| # PERFORMANCE SETTINGS |
| # ========================================================= |
|
|
# Enable TF32 for both CUDA matmul and cuDNN kernels.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
|
|
# =========================================================
# PATHS
# =========================================================
|
|
# Directory containing this file; every other path is derived from it.
REPO_ROOT = Path(__file__).resolve().parent

# Scratch area for the Gradio app and its two sub-directories.
GRADIO_TMP_ROOT = REPO_ROOT.joinpath("tmps", "gradio_t2v_v2t")
TMP_INPUT_DIR = GRADIO_TMP_ROOT.joinpath("inputs")
RESULTS_ROOT = GRADIO_TMP_ROOT.joinpath("results")

# Record-keeping file locations / names.
GLOBAL_RECORDS_FILE = GRADIO_TMP_ROOT.joinpath("generation_records.jsonl")
RUN_RECORD_FILENAME = "generation_record.json"
|
|
| # ========================================================= |
| # MODEL DOWNLOAD |
| # ========================================================= |
|
|
from huggingface_hub import snapshot_download

# Hub repository holding the Lance checkpoints and the local folder it is
# mirrored into.
MODEL_REPO = "bytedance-research/Lance"
MODEL_CACHE_DIR = REPO_ROOT / "downloads"

# Download (or reuse) the repository snapshot at import time.
# `local_dir_use_symlinks` and `resume_download` are deliberately not
# passed: current huggingface_hub releases deprecate and ignore both
# (local-dir downloads no longer use symlinks and always resume), and they
# are scheduled for removal.
snapshot_download(
    repo_id=MODEL_REPO,
    local_dir=str(MODEL_CACHE_DIR),
    token=HF_TOKEN,
)

# Checkpoint sub-folder used by the pipeline; the listing below is a
# startup diagnostic and fails fast if the folder is missing.
DEFAULT_MODEL_PATH = str(MODEL_CACHE_DIR / "Lance_3B_Video")
print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
print("FILES =", os.listdir(DEFAULT_MODEL_PATH))
# --- Task identifiers -------------------------------------------------
TASK_T2V = "t2v"
TASK_V2T = "v2t"
TASK_X2T_VIDEO = "x2t_video"
TASK_CHOICES = [TASK_T2V, TASK_V2T]

# --- Model / sampling defaults ---------------------------------------
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
DEFAULT_TASK = "t2v"
DEFAULT_TIMESTEPS = 30
DEFAULT_TIMESTEP_SHIFT = 3.5
DEFAULT_CFG_TEXT_SCALE = 4.0
DEFAULT_BASIC_SEED = -1

# --- Video geometry defaults -----------------------------------------
DEFAULT_RESOLUTION = "video_480p"
DEFAULT_HEIGHT = 480
DEFAULT_WIDTH = 848
DEFAULT_NUM_FRAMES = 50
VIDEO_RESOLUTION_CHOICES = ["video_192p", "video_360p", "video_480p"]

# --- App behaviour ---------------------------------------------------
DEFAULT_QUEUE_SIZE = 4
USE_KVCACHE = True
TEXT_TEMPLATE = True

# Lock intended to guard writes to the generation record files.
RECORD_WRITE_LOCK = threading.Lock()

# Instruction placed ahead of the user's question in V2T requests.
V2T_SYSTEM_PROMPT = "Watch the video carefully and answer the question."
|
|
| # ========================================================= |
| # HELPERS |
| # ========================================================= |
|
|
def ensure_dirs():
    """Create the input and results scratch directories if they are missing."""
    for directory in (TMP_INPUT_DIR, RESULTS_ROOT):
        directory.mkdir(parents=True, exist_ok=True)
|
|
def normalize_seed(seed: Optional[int]) -> int:
    """Return a usable, non-negative seed.

    Args:
        seed: Requested seed. ``None`` or any negative value (the UI uses
            ``-1``) means "pick one at random".

    Returns:
        ``seed`` unchanged when it is non-negative, otherwise a random
        integer in ``[0, 2**31 - 1]``.
    """
    # Negative values other than -1 used to be passed through unchanged,
    # which seeding code downstream cannot accept; treat them all as
    # "random".
    if seed is None or seed < 0:
        return random.randint(0, 2**31 - 1)
    return seed
|
|
def normalize_task(task: str):
    """Canonicalize a task name coming from the UI.

    An empty/``None`` task falls back to ``DEFAULT_TASK``. The value is
    stripped and lower-cased, and the UI alias ``TASK_V2T`` is mapped to
    ``TASK_X2T_VIDEO``; any other value is returned as cleaned.
    """
    cleaned = (task or DEFAULT_TASK).strip().lower()
    return TASK_X2T_VIDEO if cleaned == TASK_V2T else cleaned
|
|
| # ========================================================= |
| # PIPELINE |
| # ========================================================= |
| |
class LanceT2VV2TPipeline:
    """Lazily-initialized Lance pipeline for text-to-video and video-to-text.

    The model, VAE and tokenizer are built by :meth:`initialize` on the first
    call to :meth:`generate` and reused afterwards. :meth:`generate` always
    returns a 4-tuple ``(video_path, text, status, logs)`` matching the four
    Gradio outputs, including on failure.
    """

    def __init__(self):
        # Flipped to True once initialize() has completed successfully.
        self.initialized = False

        self.logger = get_logger("lance_zerogpu")

        # Model-side state, populated by initialize().
        self.model = None
        self.vae_model = None
        self.vae_config = None
        self.tokenizer = None
        self.new_token_ids = None
        self.image_token_id = None

        # Argument templates; the inference args are deep-copied per request.
        self.base_model_args = None
        self.base_data_args = None
        self.base_inference_args = None

        # Serializes initialize() so the model is only built once.
        self.lock = threading.Lock()

    def initialize(self):
        """Build the model stack once; subsequent calls return immediately.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        with self.lock:
            if self.initialized:
                return

            ensure_dirs()

            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")

            model_args = ModelArguments(
                model_path=str(DEFAULT_MODEL_PATH),
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )

            data_args = DataArguments()

            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=TEXT_TEMPLATE,
                use_KVcache=USE_KVCACHE,
            )

            apply_inference_defaults(model_args, data_args, inference_args)

            set_seed(inference_args.global_seed)

            # Language model is built from its JSON config (no weights yet).
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )
            language_model = Qwen2ForCausalLM(llm_config)

            # Vision tower: config plus weights from vit.safetensors.
            vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
                model_args.vit_path
            )
            vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
            vit_weights = load_file(
                str(Path(model_args.vit_path) / "vit.safetensors")
            )
            vit_model.load_state_dict(vit_weights, strict=True)
            clean_memory(vit_weights)

            vae_model = WanVideoVAE()
            vae_config = deepcopy(vae_model.vae_config)

            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )

            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )

            # Cast to fp16 and move to the GPU before the checkpoint load
            # below (order kept as in the original).
            model = model.half().to("cuda")

            tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)

            # The third return value (number of added tokens) is not needed.
            tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

            init_from_model_path_if_needed(model, model_args)

            image_token_id = model.language_model.config.video_token_id

            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config
            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id

            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args

            self.initialized = True

            print("Lance initialized successfully")

    def _build_inference_args(
        self,
        task,
        height,
        width,
        num_frames,
        actual_seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
    ):
        """Return a per-request copy of the base inference arguments."""
        inference_args = deepcopy(self.base_inference_args)

        inference_args.video_height = int(height)
        inference_args.video_width = int(width)
        inference_args.num_frames = int(num_frames)

        # Keep the configured defaults when the UI sends an empty value.
        if validation_num_timesteps is not None:
            inference_args.validation_num_timesteps = int(
                validation_num_timesteps
            )
        if validation_timestep_shift is not None:
            inference_args.validation_timestep_shift = float(
                validation_timestep_shift
            )

        if resolution:
            inference_args.resolution = resolution

        # NOTE(review): both seed fields are set because it is not visible
        # here which one validate_on_fixed_batch() consumes — confirm.
        inference_args.global_seed = actual_seed
        inference_args.validation_data_seed = actual_seed

        inference_args.task = task

        return inference_args

    @staticmethod
    def _write_prompt_file(task, prompt, input_video, question):
        """Write the request payload to the prompt file and return its path."""
        if task == TASK_T2V:
            payload = {"000000.mp4": prompt}
        else:
            payload = {
                "000000": {
                    "interleave_array": [
                        input_video,
                        [V2T_SYSTEM_PROMPT, question, ""],
                    ],
                    "element_dtype_array": ["video", "text"],
                    "istarget_in_interleave": [0, 1],
                }
            }

        prompt_file = TMP_INPUT_DIR / "prompt.json"
        with open(prompt_file, "w", encoding="utf-8") as f:
            json.dump(payload, f)

        return prompt_file

    @staticmethod
    def _read_text_result(save_dir):
        """Return the text stored for a V2T run, or ``""`` when none is found.

        NOTE(review): assumes validate_on_fixed_batch() writes its text
        output to ``save_dir / PROMPT_JSON_FILENAME``; that helper is not
        visible here — confirm the location and format. The raw file
        content is returned unparsed.
        """
        result_file = Path(save_dir) / PROMPT_JSON_FILENAME
        if not result_file.is_file():
            return ""
        return result_file.read_text(encoding="utf-8")

    def generate(
        self,
        task,
        prompt,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Run one T2V or V2T request.

        Args:
            task: Task name from the UI; normalized with ``normalize_task``.
            prompt: Text prompt (T2V).
            input_video: Path of the uploaded video (V2T).
            question: Question about the video (V2T).
            height, width, num_frames: Output video geometry.
            seed: Requested seed; ``None`` or ``-1`` picks a random one.
            resolution: Resolution preset name.
            validation_num_timesteps: Number of sampling steps.
            validation_timestep_shift: Timestep shift value.
            cfg_text_scale: Accepted for interface compatibility.
                NOTE(review): not forwarded — no matching inference
                argument is visible in this file; confirm the field name
                before wiring it through.

        Returns:
            ``(video_path_or_None, text, status, logs)``. Errors are
            reported through the tuple (traceback in ``logs``) rather than
            raised.
        """
        try:
            # Inside the try so that start-up failures also produce the
            # 4-tuple the UI expects.
            self.initialize()

            task = normalize_task(task)

            if task not in (TASK_T2V, TASK_X2T_VIDEO):
                return (
                    None,
                    "",
                    "Generation failed",
                    f"Unsupported task: {task}",
                )

            if task == TASK_X2T_VIDEO and not input_video:
                return (
                    None,
                    "",
                    "Generation failed",
                    "An input video is required for v2t",
                )

            # The seed used to be computed and then dropped; apply it.
            actual_seed = normalize_seed(-1 if seed is None else int(seed))
            set_seed(actual_seed)
            print("Using seed", actual_seed)

            save_dir = RESULTS_ROOT / str(time.time())
            save_dir.mkdir(parents=True, exist_ok=True)

            inference_args = self._build_inference_args(
                task,
                height,
                width,
                num_frames,
                actual_seed,
                resolution,
                validation_num_timesteps,
                validation_timestep_shift,
            )

            prompt_file = self._write_prompt_file(
                task, prompt, input_video, question
            )

            dataset_config = DataConfig.from_yaml(str(prompt_file))

            val_dataset = ValidationDataset(
                jsonl_path=str(prompt_file),
                tokenizer=self.tokenizer,
                data_args=self.base_data_args,
                model_args=self.base_model_args,
                training_args=inference_args,
                new_token_ids=self.new_token_ids,
                dataset_config=dataset_config,
                local_rank=0,
                world_size=1,
            )

            val_data_cpu = simple_custom_collate([val_dataset[0]])

            try:
                validate_on_fixed_batch(
                    fsdp_model=self.model,
                    vae_model=self.vae_model,
                    tokenizer=self.tokenizer,
                    val_data_cpu=val_data_cpu,
                    training_args=inference_args,
                    model_args=self.base_model_args,
                    inference_args=inference_args,
                    new_token_ids=self.new_token_ids,
                    image_token_id=self.image_token_id,
                    device="cuda",
                    save_source_video=False,
                    save_path_gen=str(save_dir),
                    save_path_gt="",
                )
            finally:
                # Release memory even when inference raises (e.g. OOM), so
                # one failed request does not starve the next.
                clean_memory()
                gc.collect()
                torch.cuda.empty_cache()

            if task == TASK_T2V:
                # Sorted so the chosen file is deterministic.
                videos = sorted(save_dir.glob("*.mp4"))

                if not videos:
                    return None, "", "No video generated", ""

                return (
                    str(videos[0]),
                    "",
                    "Generation complete",
                    "Success",
                )

            # Same slot layout as the T2V branch: (video, text, status, logs).
            return (
                None,
                self._read_text_result(save_dir),
                "Understanding complete",
                "Success",
            )

        except Exception:
            # Top-level boundary: report the traceback via the logs output.
            err = traceback.format_exc()

            print(err)

            return None, "", "Generation failed", err
| |
| # ========================================================= |
| # GLOBAL PIPELINE |
| # ========================================================= |
| |
# Single shared pipeline instance; constructing it does not load the model
# (that happens in LanceT2VV2TPipeline.initialize on first use).
PIPELINE = LanceT2VV2TPipeline()
| |
| # ========================================================= |
| # SPACES GPU FUNCTION |
| # ========================================================= |
| |
@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    """Gradio click handler: forward every UI value to ``PIPELINE.generate``.

    Returns whatever ``PIPELINE.generate`` returns, i.e. the values for the
    four output components.
    """
    request = {
        "task": task,
        "prompt": prompt,
        "input_video": input_video,
        "question": question,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "seed": seed,
        "resolution": resolution,
        "validation_num_timesteps": validation_num_timesteps,
        "validation_timestep_shift": validation_timestep_shift,
        "cfg_text_scale": cfg_text_scale,
    }
    return PIPELINE.generate(**request)
| |
| # ========================================================= |
| # UI |
| # ========================================================= |
| |
with gr.Blocks(title="Lance ZeroGPU") as demo:
    gr.Markdown("# Lance T2V/V2T ZeroGPU")

    with gr.Row():
        # Left column: request inputs and the run button.
        with gr.Column():
            task = gr.Dropdown(
                label="Task", choices=TASK_CHOICES, value=DEFAULT_TASK
            )
            prompt = gr.Textbox(label="Prompt", lines=6)
            input_video = gr.Video(label="Input Video")
            question = gr.Textbox(label="Question", lines=3)

            height = gr.Slider(
                minimum=192, maximum=1024, step=16,
                value=DEFAULT_HEIGHT, label="Height",
            )
            width = gr.Slider(
                minimum=192, maximum=1024, step=16,
                value=DEFAULT_WIDTH, label="Width",
            )
            num_frames = gr.Slider(
                minimum=1, maximum=121, step=1,
                value=DEFAULT_NUM_FRAMES, label="Frames",
            )

            seed = gr.Number(label="Seed", value=-1, precision=0)
            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )

            validation_num_timesteps = gr.Slider(
                minimum=1, maximum=100, step=1,
                value=DEFAULT_TIMESTEPS, label="Timesteps",
            )
            validation_timestep_shift = gr.Number(
                label="Timestep Shift", value=DEFAULT_TIMESTEP_SHIFT
            )
            cfg_text_scale = gr.Number(
                label="CFG Text Scale", value=DEFAULT_CFG_TEXT_SCALE
            )

            run_button = gr.Button("Run", variant="primary")

        # Right column: results.
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_text = gr.Textbox(label="Text Output", lines=8)
            status = gr.Markdown()
            logs = gr.Textbox(label="Logs", lines=20)

    # Input order must match run_task's positional parameters; output order
    # must match the tuple it returns.
    run_inputs = [
        task,
        prompt,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ]
    run_outputs = [output_video, output_text, status, logs]

    run_button.click(fn=run_task, inputs=run_inputs, outputs=run_outputs)
| |
| # ========================================================= |
| # LAUNCH |
| # ========================================================= |
| |
# Queue requests (one at a time) and start the server on all interfaces.
# The queue length comes from DEFAULT_QUEUE_SIZE instead of a duplicated
# literal so the constant is the single place to change it.
demo.queue(
    max_size=DEFAULT_QUEUE_SIZE,
    default_concurrency_limit=1,
).launch(
    server_name="0.0.0.0",
    server_port=7860,
)