# =========================================================
# ZERO GPU PATCHED + ALL TASKS ENABLED
# Hugging Face Spaces Compatible
# =========================================================

from __future__ import annotations

import gc
import json
import os
import random
import threading
import time
import traceback
from copy import deepcopy
from pathlib import Path

import gradio as gr
import spaces
import torch

# =========================================================
# ENV
# =========================================================
# These must be set before the libraries that read them are imported.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TRANSFORMERS_NO_FLASH_ATTN"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# =========================================================
# LOGIN
# =========================================================

from huggingface_hub import login, snapshot_download

HF_TOKEN = os.getenv("HF_TOKEN")

if HF_TOKEN:
    login(token=HF_TOKEN)

# =========================================================
# IMPORTS
# =========================================================

from safetensors.torch import load_file
from transformers import set_seed
from transformers.utils import is_flash_attn_2_available
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLVisionConfig,
)

from common.utils.logging import get_logger
from config.config_factory import (
    DataArguments,
    InferenceArguments,
    ModelArguments,
)
from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, simple_custom_collate
from data.datasets_custom import ValidationDataset
from inference_lance import (
    apply_inference_defaults,
    clean_memory,
    init_from_model_path_if_needed,
    validate_on_fixed_batch,
)
from modeling.lance import (
    Lance,
    LanceConfig,
    Qwen2ForCausalLM,
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2Config
from modeling.vae.wan.model import WanVideoVAE
from modeling.vit.qwen2_5_vl_vit import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

print("FlashAttention2 available:", is_flash_attn_2_available())

# =========================================================
# PERFORMANCE
# =========================================================

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# =========================================================
# PATHS
# =========================================================

REPO_ROOT = Path(__file__).resolve().parent

TMP_ROOT = REPO_ROOT / "tmps"
TMP_ROOT.mkdir(parents=True, exist_ok=True)

TMP_INPUT_DIR = TMP_ROOT / "inputs"
TMP_INPUT_DIR.mkdir(parents=True, exist_ok=True)

RESULTS_ROOT = TMP_ROOT / "results"
RESULTS_ROOT.mkdir(parents=True, exist_ok=True)

# =========================================================
# MODEL DOWNLOAD
# =========================================================

MODEL_REPO = "bytedance-research/Lance"
MODEL_CACHE_DIR = REPO_ROOT / "downloads"

snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=[
        "Lance_3B_Video/*",
    ],
    local_dir=str(MODEL_CACHE_DIR),
    local_dir_use_symlinks=False,
    token=HF_TOKEN,
)

DEFAULT_MODEL_PATH = str(
    MODEL_CACHE_DIR / "Lance_3B_Video"
)

print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)

# =========================================================
# DEFAULTS
# =========================================================

DEFAULT_TIMESTEPS = 30
DEFAULT_TIMESTEP_SHIFT = 3.5
DEFAULT_CFG_TEXT_SCALE = 4.0
DEFAULT_HEIGHT = 480
DEFAULT_WIDTH = 848
DEFAULT_NUM_FRAMES = 50
DEFAULT_RESOLUTION = "video_480p"
DEFAULT_TASK = "t2v"
DEFAULT_SEED = -1
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"

# =========================================================
# TASKS
# =========================================================

TASK_T2V = "t2v"
TASK_T2I = "t2i"
TASK_IMAGE_EDIT = "image_edit"
TASK_VIDEO_EDIT = "video_edit"
TASK_X2T_IMAGE = "x2t_image"
TASK_X2T_VIDEO = "x2t_video"

TASK_CHOICES = [
    TASK_T2V,
    TASK_T2I,
    TASK_IMAGE_EDIT,
    TASK_VIDEO_EDIT,
    TASK_X2T_IMAGE,
    TASK_X2T_VIDEO,
]

VIDEO_RESOLUTION_CHOICES = [
    "video_192p",
    "video_360p",
    "video_480p",
]

# =========================================================
# HELPERS
# =========================================================


def normalize_seed(seed: int | None) -> int:
    # An empty field (None) or any negative value means "pick a random seed".
    if seed is None or int(seed) < 0:
        return random.randint(0, 2**31 - 1)
    return int(seed)


def normalize_task(task: str) -> str:
    # Unknown or empty task names fall back to the default task.
    task = (task or DEFAULT_TASK).strip().lower()
    if task not in TASK_CHOICES:
        return DEFAULT_TASK
    return task


# =========================================================
# PIPELINE
# =========================================================


class LancePipeline:
    def __init__(self):
        self.initialized = False
        self.lock = threading.Lock()
        self.logger = get_logger("lance")

    def initialize(self):
        # The lock makes initialization run once even if requests overlap.
        with self.lock:
            if self.initialized:
                return

            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")

            model_args = ModelArguments(
                model_path=DEFAULT_MODEL_PATH,
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )

            data_args = DataArguments()

            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=True,
                use_KVcache=True,
            )

            apply_inference_defaults(
                model_args,
                data_args,
                inference_args,
            )

            set_seed(42)

            # Language model
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )
            language_model = Qwen2ForCausalLM(llm_config)

            # Vision transformer (eager attention, since FlashAttention is disabled above)
            vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
                model_args.vit_path
            )
            vit_config._attn_implementation = "eager"

            vit_model = Qwen2_5_VisionTransformerPretrainedModel(
                vit_config
            )
            vit_weights = load_file(
                str(Path(model_args.vit_path) / "vit.safetensors")
            )
            vit_model.load_state_dict(vit_weights, strict=True)
            clean_memory(vit_weights)

            # Video VAE
            vae_model = WanVideoVAE()
            vae_config = deepcopy(vae_model.vae_config)

            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )

            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )
            model = model.to(
                device="cuda",
                dtype=torch.bfloat16,
            )

            tokenizer = Qwen2Tokenizer.from_pretrained(
                model_args.model_path
            )
            tokenizer, new_token_ids, _ = add_special_tokens(
                tokenizer
            )

            init_from_model_path_if_needed(
                model,
                model_args,
            )

            image_token_id = model.language_model.config.video_token_id

            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config
            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id
            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args
            self.initialized = True

            print("Lance initialized successfully")

    def generate(
        self,
        task,
        prompt,
        input_image,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        # Returns (video_path, image_path, text, status, logs).
        # NOTE: `resolution` and `cfg_text_scale` are accepted from the UI
        # but are not forwarded to `inference_args` in this function.
        task = normalize_task(task)

        actual_seed = normalize_seed(seed)
        set_seed(actual_seed)

        # One results directory per request, named by timestamp.
        save_dir = RESULTS_ROOT / str(time.time())
        save_dir.mkdir(parents=True, exist_ok=True)

        # Copy the base arguments so per-request overrides do not leak.
        inference_args = deepcopy(
            self.base_inference_args
        )
        inference_args.video_height = int(height)
        inference_args.video_width = int(width)
        inference_args.num_frames = int(num_frames)
        inference_args.validation_num_timesteps = int(
            validation_num_timesteps
        )
        inference_args.validation_timestep_shift = float(
            validation_timestep_shift
        )
        inference_args.task = task

        prompt_file = TMP_INPUT_DIR / "prompt.json"

        # =====================================================
        # PAYLOADS
        # =====================================================

        if task == TASK_T2V:
            payload = {
                "000000.mp4": prompt
            }

        elif task == TASK_T2I:
            payload = {
                "000000.png": prompt
            }

        elif task == TASK_IMAGE_EDIT:
            payload = {
                "000000": {
                    "interleave_array": [
                        input_image,
                        [prompt, ""]
                    ],
                    "element_dtype_array": [
                        "image",
                        "text"
                    ],
                    "istarget_in_interleave": [
                        0,
                        1
                    ],
                }
            }

        elif task == TASK_VIDEO_EDIT:
            payload = {
                "000000": {
                    "interleave_array": [
                        input_video,
                        [prompt, ""]
                    ],
                    "element_dtype_array": [
                        "video",
                        "text"
                    ],
                    "istarget_in_interleave": [
                        0,
                        1
                    ],
                }
            }

        elif task == TASK_X2T_IMAGE:
            payload = {
                "000000": {
                    "interleave_array": [
                        input_image,
                        [
                            "Describe the image",
                            question,
                            ""
                        ]
                    ],
                    "element_dtype_array": [
                        "image",
                        "text"
                    ],
                    "istarget_in_interleave": [
                        0,
                        1
                    ],
                }
            }

        elif task == TASK_X2T_VIDEO:
            payload = {
                "000000": {
                    "interleave_array": [
                        input_video,
                        [
                            "Describe the video",
                            question,
                            ""
                        ]
                    ],
                    "element_dtype_array": [
                        "video",
                        "text"
                    ],
                    "istarget_in_interleave": [
                        0,
                        1
                    ],
                }
            }

        else:
            return (
                None,
                None,
                "",
                "Invalid task",
                "",
            )

        with open(prompt_file, "w") as f:
            json.dump(payload, f)

        dataset_config = DataConfig.from_yaml(
            str(prompt_file)
        )

        val_dataset = ValidationDataset(
            jsonl_path=str(prompt_file),
            tokenizer=self.tokenizer,
            data_args=self.base_data_args,
            model_args=self.base_model_args,
            training_args=inference_args,
            new_token_ids=self.new_token_ids,
            dataset_config=dataset_config,
            local_rank=0,
            world_size=1,
        )

        val_data_cpu = simple_custom_collate(
            [val_dataset[0]]
        )

        validate_on_fixed_batch(
            fsdp_model=self.model,
            vae_model=self.vae_model,
            tokenizer=self.tokenizer,
            val_data_cpu=val_data_cpu,
            training_args=inference_args,
            model_args=self.base_model_args,
            inference_args=inference_args,
            new_token_ids=self.new_token_ids,
            image_token_id=self.image_token_id,
            device="cuda",
            save_source_video=False,
            save_path_gen=str(save_dir),
            save_path_gt="",
        )

        # Release GPU memory before returning to the UI.
        clean_memory()
        gc.collect()
        torch.cuda.empty_cache()

        videos = list(save_dir.glob("*.mp4"))
        images = list(save_dir.glob("*.png"))

        if len(videos) > 0:
            return (
                str(videos[0]),
                None,
                "",
                "Success",
                "",
            )

        if len(images) > 0:
            return (
                None,
                str(images[0]),
                "",
                "Success",
                "",
            )

        if task in [TASK_X2T_IMAGE, TASK_X2T_VIDEO]:
            return (
                None,
                None,
                "Understanding complete",
                "Success",
                "",
            )

        return (
            None,
            None,
            "",
            "No output generated",
            "",
        )


# =========================================================
# GLOBAL
# =========================================================

PIPELINE = LancePipeline()

# =========================================================
# GPU RUN
# =========================================================


@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_image,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    try:
        PIPELINE.initialize()

        return PIPELINE.generate(
            task=task,
            prompt=prompt,
            input_image=input_image,
            input_video=input_video,
            question=question,
            height=height,
            width=width,
            num_frames=num_frames,
            seed=seed,
            resolution=resolution,
            validation_num_timesteps=validation_num_timesteps,
            validation_timestep_shift=validation_timestep_shift,
            cfg_text_scale=cfg_text_scale,
        )
    except Exception as exc:
        # Report the failure in the status line and put the traceback
        # in the Logs box instead of leaving it empty.
        return (
            None,
            None,
            "",
            f"Failed: {exc}",
            traceback.format_exc(),
        )


# =========================================================
# UI
# =========================================================

with gr.Blocks(title="Lance ZeroGPU") as demo:
    gr.Markdown("# Lance Multi-Task ZeroGPU")

    with gr.Row():
        with gr.Column():
            task = gr.Dropdown(
                label="Task",
                choices=TASK_CHOICES,
                value=DEFAULT_TASK,
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=5,
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )
            input_video = gr.Video(
                label="Input Video",
            )
            question = gr.Textbox(
                label="Question",
                lines=3,
            )
            height = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_HEIGHT,
                label="Height",
            )
            width = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_WIDTH,
                label="Width",
            )
            num_frames = gr.Slider(
                minimum=1,
                maximum=121,
                step=1,
                value=DEFAULT_NUM_FRAMES,
                label="Frames",
            )
            seed = gr.Number(
                label="Seed",
                value=DEFAULT_SEED,
                precision=0,
            )
            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )
            validation_num_timesteps = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=DEFAULT_TIMESTEPS,
                label="Timesteps",
            )
            validation_timestep_shift = gr.Number(
                label="Timestep Shift",
                value=DEFAULT_TIMESTEP_SHIFT,
            )
            cfg_text_scale = gr.Number(
                label="CFG Text Scale",
                value=DEFAULT_CFG_TEXT_SCALE,
            )
            run_button = gr.Button(
                "Run",
                variant="primary",
            )

        with gr.Column():
            output_video = gr.Video(
                label="Generated Video"
            )
            output_image = gr.Image(
                label="Generated Image"
            )
            output_text = gr.Textbox(
                label="Text Output",
                lines=8,
            )
            status = gr.Markdown()
            logs = gr.Textbox(
                label="Logs",
                lines=15,
            )

    run_button.click(
        fn=run_task,
        inputs=[
            task,
            prompt,
            input_image,
            input_video,
            question,
            height,
            width,
            num_frames,
            seed,
            resolution,
            validation_num_timesteps,
            validation_timestep_shift,
            cfg_text_scale,
        ],
        outputs=[
            output_video,
            output_image,
            output_text,
            status,
            logs,
        ],
    )

# =========================================================
# LAUNCH
# =========================================================

demo.queue(
    max_size=4,
    default_concurrency_limit=1,
).launch(
    server_name="0.0.0.0",
    server_port=7860,
)