| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import gc |
| import json |
| import os |
| import random |
| import threading |
| import time |
| import traceback |
|
|
| from copy import deepcopy |
| from pathlib import Path |
|
|
| import gradio as gr |
| import spaces |
| import torch |
|
|
| |
| |
| |
|
|
# Process-wide environment knobs for the tokenizer / PyTorch / TensorFlow-backed
# libraries. They are applied before transformers and the Lance modules are
# imported further down.
# NOTE(review): torch is already imported above this point; confirm the
# PyTorch-related variable is still honoured when set this late.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "TRANSFORMERS_NO_FLASH_ATTN": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TF_ENABLE_ONEDNN_OPTS": "0",
    }
)
|
|
| |
| |
| |
|
|
| from huggingface_hub import ( |
| login, |
| snapshot_download, |
| hf_hub_download, |
| ) |
|
|
# Optional Hugging Face access token. It is also passed explicitly to the hub
# download calls further down; when it is set, the process logs in once here.
HF_TOKEN = os.environ.get("HF_TOKEN")

_has_hf_token = bool(HF_TOKEN)
if _has_hf_token:
    login(token=HF_TOKEN)
|
|
| |
| |
| |
|
|
| from safetensors.torch import load_file |
|
|
| from transformers import ( |
| set_seed, |
| AutoConfig, |
| ) |
|
|
| from transformers.utils import is_flash_attn_2_available |
|
|
| from common.utils.logging import get_logger |
|
|
| from config.config_factory import ( |
| DataArguments, |
| InferenceArguments, |
| ModelArguments, |
| ) |
|
|
| from data.data_utils import add_special_tokens |
| from data.dataset_base import DataConfig, simple_custom_collate |
| from data.datasets_custom import ValidationDataset |
|
|
| from inference_lance import ( |
| apply_inference_defaults, |
| clean_memory, |
| init_from_model_path_if_needed, |
| validate_on_fixed_batch, |
| ) |
|
|
| from modeling.lance import ( |
| Lance, |
| LanceConfig, |
| Qwen2ForCausalLM, |
| ) |
|
|
| from modeling.qwen2 import Qwen2Tokenizer |
| from modeling.qwen2.modeling_qwen2 import Qwen2Config |
|
|
| from modeling.vae.wan.model import WanVideoVAE |
|
|
| from modeling.vit.qwen2_5_vl_vit import ( |
| Qwen2_5_VisionTransformerPretrainedModel, |
| ) |
|
|
# Startup diagnostic: report whether transformers can use FlashAttention 2.
_flash_attn_2_ok = is_flash_attn_2_available()
print("FlashAttention2 available:", _flash_attn_2_ok)
|
|
| |
| |
| |
|
|
# Permit TensorFloat-32 for CUDA matmuls and cuDNN kernels; this trades a little
# float32 precision for throughput on GPUs that support TF32.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend
|
|
| |
| |
| |
|
|
# Directory containing this script; all scratch paths live beneath it.
REPO_ROOT = Path(__file__).resolve().parent

# Scratch layout:
#   tmps/          root for per-run artefacts
#   tmps/inputs/   prompt file handed to the dataset loader
#   tmps/results/  one sub-directory of generated outputs per request
TMP_ROOT = REPO_ROOT / "tmps"
TMP_INPUT_DIR = TMP_ROOT / "inputs"
RESULTS_ROOT = TMP_ROOT / "results"

for _scratch_dir in (TMP_ROOT, TMP_INPUT_DIR, RESULTS_ROOT):
    _scratch_dir.mkdir(parents=True, exist_ok=True)
del _scratch_dir
|
|
| |
| |
| |
|
|
# Hub repository holding the Lance checkpoints.
MODEL_REPO = "bytedance-research/Lance"

# Local mirror of the checkpoint files, kept next to this script.
MODEL_CACHE_DIR = REPO_ROOT / "downloads"

# Fetch only the 3B video checkpoint sub-folder at import time. Files are
# written directly into ``local_dir``. The former ``local_dir_use_symlinks``
# argument is deprecated and ignored by current huggingface_hub releases; it
# only produced a warning on every start, so it is no longer passed.
snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=["Lance_3B_Video/*"],
    local_dir=str(MODEL_CACHE_DIR),
    token=HF_TOKEN,
)

# Directory used as ``ModelArguments.model_path`` by the pipeline.
DEFAULT_MODEL_PATH = str(MODEL_CACHE_DIR / "Lance_3B_Video")

print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
|
|
| |
| |
| |
|
|
# Full Qwen2.5-VL checkpoint; the pipeline takes its vision-tower config and
# weights from here.
QWEN_VL_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"

# Sampler defaults: number of denoising steps, timestep shift and text CFG scale.
DEFAULT_TIMESTEPS, DEFAULT_TIMESTEP_SHIFT, DEFAULT_CFG_TEXT_SCALE = 30, 3.5, 4.0

# Output geometry defaults (pixel height, pixel width, frame count).
DEFAULT_HEIGHT, DEFAULT_WIDTH, DEFAULT_NUM_FRAMES = 480, 848, 50

DEFAULT_RESOLUTION = "video_480p"
DEFAULT_TASK = "t2v"

# -1 asks normalize_seed for a random seed.
DEFAULT_SEED = -1

DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
|
|
| |
| |
| |
|
|
# Task identifiers offered in the UI dropdown and dispatched on by
# LancePipeline.generate.
TASK_T2V, TASK_T2I = "t2v", "t2i"
TASK_IMAGE_EDIT, TASK_VIDEO_EDIT = "image_edit", "video_edit"
TASK_X2T_IMAGE, TASK_X2T_VIDEO = "x2t_image", "x2t_video"

# Dropdown order: generation, editing, then understanding tasks.
TASK_CHOICES = [
    TASK_T2V,
    TASK_T2I,
    TASK_IMAGE_EDIT,
    TASK_VIDEO_EDIT,
    TASK_X2T_IMAGE,
    TASK_X2T_VIDEO,
]

# Resolution buckets selectable in the UI.
VIDEO_RESOLUTION_CHOICES = [f"video_{px}p" for px in (192, 360, 480)]
|
|
| |
| |
| |
|
|
def normalize_seed(seed: int | None) -> int:
    """Return a concrete RNG seed for ``seed``.

    ``None`` (e.g. a cleared number field) or any negative value (the UI uses
    ``-1``) requests a fresh random seed in ``[0, 2**31 - 1]``.

    Non-negative seeds are reduced modulo ``2**32``. ``transformers.set_seed``
    also seeds NumPy, which rejects seeds outside ``[0, 2**32 - 1]``. Previously
    such values, and negatives other than -1, made generation fail.
    """
    if seed is None or seed < 0:
        return random.randint(0, 2**31 - 1)

    return seed % 2**32
|
|
|
|
def normalize_task(task: str) -> str:
    """Map a user-supplied task name onto one of ``TASK_CHOICES``.

    Surrounding whitespace and letter case are ignored. Empty/``None`` input and
    unknown names fall back to ``DEFAULT_TASK``.
    """
    task = (task or DEFAULT_TASK).strip().lower()

    # TASK_CHOICES is the single list of valid tasks shared with the UI
    # dropdown; validating against it avoids a second copy that could drift.
    return task if task in TASK_CHOICES else DEFAULT_TASK
|
|
| |
| |
| |
|
|
class LancePipeline:
    """Lazily built Lance model plus a single-sample inference entry point.

    ``initialize`` constructs the model, tokenizer and base argument objects on
    first use. ``generate`` writes a one-item prompt file for the requested
    task, runs ``validate_on_fixed_batch`` on it, and returns the media written
    to a per-request results directory.
    """

    def __init__(self):
        # Heavy state (model, tokenizer, argument objects) is filled in by
        # initialize(); construction itself is cheap.
        self.initialized = False

        # Serialises initialize() so the model is only built once.
        self.lock = threading.Lock()

        self.logger = get_logger("lance")

    # ------------------------------------------------------------------ #
    # Initialization
    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_qwen_vit_state_dict():
        """Return Qwen2.5-VL vision-tower weights keyed for the bare ViT module.

        The Qwen2.5-VL-7B-Instruct checkpoint is sharded (there is no single
        ``model.safetensors`` file). The shard index is therefore read, and only
        the shards holding vision-tower tensors are downloaded; the language
        model weights are never loaded.

        Checkpoint keys carry a ``visual.`` (or ``model.visual.``) prefix.
        That prefix is stripped so the keys line up with the parameter names
        of ``Qwen2_5_VisionTransformerPretrainedModel``.

        Raises:
            RuntimeError: if the index lists no vision-tower tensors.
        """
        from safetensors import safe_open

        index_path = hf_hub_download(
            repo_id=QWEN_VL_REPO,
            filename="model.safetensors.index.json",
            token=HF_TOKEN,
        )
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]

        vit_prefixes = ("model.visual.", "visual.")

        def _vit_key(checkpoint_key):
            # Bare ViT parameter name, or None for non-vision tensors.
            for prefix in vit_prefixes:
                if checkpoint_key.startswith(prefix):
                    return checkpoint_key[len(prefix):]
            return None

        keys_by_shard = {}
        for checkpoint_key, shard_name in weight_map.items():
            if _vit_key(checkpoint_key) is not None:
                keys_by_shard.setdefault(shard_name, []).append(checkpoint_key)

        if not keys_by_shard:
            raise RuntimeError(
                f"No vision-tower weights found in {QWEN_VL_REPO}"
            )

        state_dict = {}
        for shard_name, checkpoint_keys in sorted(keys_by_shard.items()):
            shard_path = hf_hub_download(
                repo_id=QWEN_VL_REPO,
                filename=shard_name,
                token=HF_TOKEN,
            )
            # safe_open reads individual tensors, so unrelated LM weights in
            # the same shard are not materialised.
            with safe_open(shard_path, framework="pt", device="cpu") as shard:
                for checkpoint_key in checkpoint_keys:
                    state_dict[_vit_key(checkpoint_key)] = shard.get_tensor(
                        checkpoint_key
                    )

        return state_dict

    def initialize(self):
        """Build the Lance model, tokenizer and base argument objects once.

        Safe to call on every request. The lock serialises concurrent callers,
        and the ``initialized`` flag turns later calls into no-ops. If any step
        raises, ``initialized`` stays False, so the next call retries.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        with self.lock:

            if self.initialized:
                return

            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")

            print("Initializing Lance...")

            model_args = ModelArguments(
                model_path=DEFAULT_MODEL_PATH,
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )

            # The vision tower is configured from the full Qwen2.5-VL repo.
            model_args.vit_path = QWEN_VL_REPO

            data_args = DataArguments()

            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=True,
                use_KVcache=True,
            )

            apply_inference_defaults(
                model_args,
                data_args,
                inference_args,
            )

            # Deterministic model construction; per-request seeds are applied
            # in generate().
            set_seed(42)

            # ---- language model ------------------------------------------
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )

            language_model = Qwen2ForCausalLM(llm_config)

            # ---- vision tower --------------------------------------------
            print("Loading Qwen2.5-VL config...")

            full_qwen_config = AutoConfig.from_pretrained(
                QWEN_VL_REPO,
                token=HF_TOKEN,
                trust_remote_code=True,
            )

            vit_config = full_qwen_config.vision_config

            # Force the eager attention implementation for the ViT.
            vit_config._attn_implementation = "eager"

            print("Creating vision transformer...")

            vit_model = Qwen2_5_VisionTransformerPretrainedModel(
                vit_config
            )

            print("Downloading Qwen weights...")

            vit_weights = self._load_qwen_vit_state_dict()

            print("Loading VIT weights...")

            missing, unexpected = vit_model.load_state_dict(
                vit_weights,
                strict=False,
            )

            print("Missing keys:", len(missing))
            print("Unexpected keys:", len(unexpected))

            clean_memory(vit_weights)
            # Tensors were copied into vit_model; drop the CPU copies.
            del vit_weights

            # ---- VAE -----------------------------------------------------
            vae_model = WanVideoVAE()

            vae_config = deepcopy(vae_model.vae_config)

            # ---- Lance wrapper -------------------------------------------
            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )

            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )

            print("Moving model to CUDA...")

            model = model.to(
                device="cuda",
                dtype=torch.bfloat16,
            )

            tokenizer = Qwen2Tokenizer.from_pretrained(
                model_args.model_path
            )

            tokenizer, new_token_ids, _ = add_special_tokens(
                tokenizer
            )

            # NOTE(review): presumably loads the Lance checkpoint found under
            # model_args.model_path into `model` — confirm in inference_lance.
            init_from_model_path_if_needed(
                model,
                model_args,
            )

            # NOTE(review): the "image" token id is read from video_token_id;
            # confirm this is intentional.
            image_token_id = model.language_model.config.video_token_id

            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config

            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id

            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args

            self.initialized = True

            print("Lance initialized successfully")

    # ------------------------------------------------------------------ #
    # Per-request helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _interleaved_payload(media, media_type, text_segments):
        """Build the single-sample interleaved entry (media followed by text)."""
        return {
            "000000": {
                "interleave_array": [
                    media,
                    text_segments,
                ],
                "element_dtype_array": [
                    media_type,
                    "text",
                ],
                "istarget_in_interleave": [
                    0,
                    1,
                ],
            }
        }

    @classmethod
    def _build_payload(cls, task, prompt, input_image, input_video, question):
        """Return the prompt-file payload for ``task``, or None if unknown."""
        if task == TASK_T2V:
            return {"000000.mp4": prompt}

        if task == TASK_T2I:
            return {"000000.png": prompt}

        if task == TASK_IMAGE_EDIT:
            return cls._interleaved_payload(input_image, "image", [prompt, ""])

        if task == TASK_VIDEO_EDIT:
            return cls._interleaved_payload(input_video, "video", [prompt, ""])

        if task == TASK_X2T_IMAGE:
            return cls._interleaved_payload(
                input_image, "image", ["Describe the image", question, ""]
            )

        if task == TASK_X2T_VIDEO:
            return cls._interleaved_payload(
                input_video, "video", ["Describe the video", question, ""]
            )

        return None

    @staticmethod
    def _missing_input_message(task, input_image, input_video):
        """Describe a required-but-missing media input, or return None."""
        if task in (TASK_IMAGE_EDIT, TASK_X2T_IMAGE) and not input_image:
            return f"task '{task}' requires an input image"

        if task in (TASK_VIDEO_EDIT, TASK_X2T_VIDEO) and not input_video:
            return f"task '{task}' requires an input video"

        return None

    def _request_inference_args(
        self,
        task,
        height,
        width,
        num_frames,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Copy the base InferenceArguments and apply this request's settings."""
        args = deepcopy(self.base_inference_args)

        args.video_height = int(height)
        args.video_width = int(width)
        args.num_frames = int(num_frames)

        # UI widgets may deliver floats; the step count must be an integer.
        args.validation_num_timesteps = int(validation_num_timesteps)

        # A cleared number field yields None; fall back to the default shift.
        args.validation_timestep_shift = float(
            DEFAULT_TIMESTEP_SHIFT
            if validation_timestep_shift is None
            else validation_timestep_shift
        )

        args.task = task

        # Apply the resolution bucket chosen in the UI.
        args.resolution = resolution or DEFAULT_RESOLUTION

        # NOTE(review): the CFG-scale field name on InferenceArguments is not
        # visible here. The UI value is only applied when such an attribute
        # exists, so no new field is invented. Confirm the name.
        if cfg_text_scale is not None and hasattr(args, "cfg_text_scale"):
            args.cfg_text_scale = float(cfg_text_scale)

        return args

    @staticmethod
    def _release_memory():
        """Free cached host/GPU memory after an inference attempt."""
        clean_memory()
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _collect_outputs(save_dir, task, seed):
        """Turn the files written to ``save_dir`` into the UI 5-tuple."""
        # Sorted for a deterministic choice when several files exist.
        videos = sorted(save_dir.glob("*.mp4"))
        images = sorted(save_dir.glob("*.png"))

        # The effective seed is reported so random (-1) runs can be reproduced.
        seed_log = f"seed: {seed}"

        if videos:
            return (str(videos[0]), None, "", "Success", seed_log)

        if images:
            return (None, str(images[0]), "", "Success", seed_log)

        if task in (TASK_X2T_IMAGE, TASK_X2T_VIDEO):
            return (None, None, "Understanding complete", "Success", seed_log)

        return (None, None, "", "No output generated", seed_log)

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #

    def generate(
        self,
        task,
        prompt,
        input_image,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Run one inference request and collect its outputs.

        Must be called after ``initialize``.

        Returns:
            A 5-tuple ``(video_path, image_path, text, status, logs)`` matching
            the Gradio output components; unused slots are ``None`` or ``""``.
            Errors are reported through ``status``/``logs`` rather than raised.
        """
        try:
            task = normalize_task(task)

            # Reject requests whose required media is missing before touching
            # the model, instead of failing deep inside the dataset code.
            missing_input = self._missing_input_message(
                task, input_image, input_video
            )
            if missing_input is not None:
                return (None, None, "", f"ERROR: {missing_input}", "")

            payload = self._build_payload(
                task, prompt, input_image, input_video, question
            )
            if payload is None:
                # Not reachable while normalize_task maps unknown names to
                # DEFAULT_TASK; kept as a defensive guard.
                return (None, None, "", "Invalid task", "")

            # A cleared seed field yields None; treat it like -1 (random).
            actual_seed = normalize_seed(-1 if seed is None else int(seed))

            set_seed(actual_seed)

            save_dir = RESULTS_ROOT / str(time.time())
            save_dir.mkdir(parents=True, exist_ok=True)

            inference_args = self._request_inference_args(
                task,
                height,
                width,
                num_frames,
                resolution,
                validation_num_timesteps,
                validation_timestep_shift,
                cfg_text_scale,
            )

            prompt_file = TMP_INPUT_DIR / "prompt.json"

            with open(prompt_file, "w", encoding="utf-8") as f:
                json.dump(payload, f)

            try:
                dataset_config = DataConfig.from_yaml(
                    str(prompt_file)
                )

                val_dataset = ValidationDataset(
                    jsonl_path=str(prompt_file),
                    tokenizer=self.tokenizer,
                    data_args=self.base_data_args,
                    model_args=self.base_model_args,
                    training_args=inference_args,
                    new_token_ids=self.new_token_ids,
                    dataset_config=dataset_config,
                    local_rank=0,
                    world_size=1,
                )

                val_data_cpu = simple_custom_collate(
                    [val_dataset[0]]
                )

                validate_on_fixed_batch(
                    fsdp_model=self.model,
                    vae_model=self.vae_model,
                    tokenizer=self.tokenizer,
                    val_data_cpu=val_data_cpu,
                    training_args=inference_args,
                    model_args=self.base_model_args,
                    inference_args=inference_args,
                    new_token_ids=self.new_token_ids,
                    image_token_id=self.image_token_id,
                    device="cuda",
                    save_source_video=False,
                    save_path_gen=str(save_dir),
                    save_path_gt="",
                )
            finally:
                # Free memory on failure too, so one bad request does not
                # leave the next one short of GPU memory.
                self._release_memory()

            return self._collect_outputs(save_dir, task, actual_seed)

        except Exception as e:

            traceback.print_exc()

            return (
                None,
                None,
                "",
                f"ERROR: {str(e)}",
                traceback.format_exc(),
            )
|
|
| |
| |
| |
|
|
# Single shared pipeline; the model is built lazily on the first request.
PIPELINE = LancePipeline()


@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_image,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    """Gradio click handler; requests a ZeroGPU slot for up to 300 s.

    Builds the model on first use, then delegates to ``PIPELINE.generate``.

    Returns:
        The 5-tuple ``(video_path, image_path, text, status, logs)``.
        Initialization failures are reported through the same tuple, so the
        status and logs components show them, instead of escaping as an
        unhandled exception.
    """
    try:
        PIPELINE.initialize()
    except Exception as e:
        traceback.print_exc()
        return (
            None,
            None,
            "",
            f"ERROR: {str(e)}",
            traceback.format_exc(),
        )

    return PIPELINE.generate(
        task=task,
        prompt=prompt,
        input_image=input_image,
        input_video=input_video,
        question=question,
        height=height,
        width=width,
        num_frames=num_frames,
        seed=seed,
        resolution=resolution,
        validation_num_timesteps=validation_num_timesteps,
        validation_timestep_shift=validation_timestep_shift,
        cfg_text_scale=cfg_text_scale,
    )
|
|
| |
| |
| |
|
|
# Gradio interface: inputs and controls on the left, outputs on the right.
# Components are laid out in creation order inside the Row/Column contexts.
with gr.Blocks(title="Lance ZeroGPU") as demo:

    gr.Markdown("# Lance Multi-Task ZeroGPU")

    with gr.Row():

        # Left column: task selection, conditioning inputs, sampling controls.
        with gr.Column():

            task = gr.Dropdown(
                label="Task",
                choices=TASK_CHOICES,
                value=DEFAULT_TASK,
            )

            # Text prompt used by the generation and edit tasks.
            prompt = gr.Textbox(
                label="Prompt",
                lines=5,
            )

            # type="filepath": the handler receives a path to the upload.
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )

            input_video = gr.Video(
                label="Input Video",
            )

            # Question text inserted into the x2t_* prompts.
            question = gr.Textbox(
                label="Question",
                lines=3,
            )

            height = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_HEIGHT,
                label="Height",
            )

            width = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_WIDTH,
                label="Width",
            )

            num_frames = gr.Slider(
                minimum=1,
                maximum=121,
                step=1,
                value=DEFAULT_NUM_FRAMES,
                label="Frames",
            )

            # -1 requests a random seed (see normalize_seed).
            seed = gr.Number(
                label="Seed",
                value=-1,
                precision=0,
            )

            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )

            validation_num_timesteps = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=DEFAULT_TIMESTEPS,
                label="Timesteps",
            )

            validation_timestep_shift = gr.Number(
                label="Timestep Shift",
                value=DEFAULT_TIMESTEP_SHIFT,
            )

            cfg_text_scale = gr.Number(
                label="CFG Text Scale",
                value=DEFAULT_CFG_TEXT_SCALE,
            )

            run_button = gr.Button(
                "Run",
                variant="primary",
            )

        # Right column: one slot per element of run_task's 5-tuple result.
        with gr.Column():

            output_video = gr.Video(
                label="Generated Video"
            )

            output_image = gr.Image(
                label="Generated Image"
            )

            output_text = gr.Textbox(
                label="Text Output",
                lines=8,
            )

            # Status line ("Success", "ERROR: ...", ...).
            status = gr.Markdown()

            # Traceback or diagnostic text.
            logs = gr.Textbox(
                label="Logs",
                lines=15,
            )

    # `inputs` order must match run_task's parameter order; `outputs` order
    # must match the (video, image, text, status, logs) tuple it returns.
    run_button.click(
        fn=run_task,
        inputs=[
            task,
            prompt,
            input_image,
            input_video,
            question,
            height,
            width,
            num_frames,
            seed,
            resolution,
            validation_num_timesteps,
            validation_timestep_shift,
            cfg_text_scale,
        ],
        outputs=[
            output_video,
            output_image,
            output_text,
            status,
            logs,
        ],
    )
|
|
| |
| |
| |
|
|
# Bounded request queue processed one job at a time: every request shares the
# single PIPELINE model instance.
demo.queue(max_size=4, default_concurrency_limit=1)

# Listen on all interfaces on the conventional Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)