# Lance / appold2
# Nayefleb's picture
# Rename app.py to appold2
# 74a788c verified
# =========================================================
# ZERO GPU PATCHED VERSION FOR HUGGING FACE SPACES
# Based on original Lance app.py
# =========================================================
from __future__ import annotations
import concurrent.futures
import gc
import json
import os
# Process-wide environment flags. They are set here, before torch / transformers
# are imported further down, so those libraries see them at import time:
#   - TOKENIZERS_PARALLELISM=false : disable HF tokenizers' internal parallelism
#   - PYTORCH_ENABLE_MPS_FALLBACK=1 : allow CPU fallback for unsupported MPS ops
#   - CUDA_VISIBLE_DEVICES=0 : expose only the first GPU to this process
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
import random
import threading
import time
import traceback
from collections import deque
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional
import gradio as gr
import spaces
import torch
from huggingface_hub import login
from safetensors.torch import load_file
from transformers import set_seed
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig,
)
from common.utils.logging import get_logger
from common.utils.misc import AutoEncoderParams, tuple_mul
from config.config_factory import (
DataArguments,
InferenceArguments,
ModelArguments,
)
from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, simple_custom_collate
from data.datasets_custom import ValidationDataset
from inference_lance import (
PROMPT_JSON_FILENAME,
apply_inference_defaults,
clean_memory,
init_from_model_path_if_needed,
save_prompt_results,
validate_on_fixed_batch,
)
from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2Config
from modeling.vae.wan.model import WanVideoVAE
from modeling.vit.qwen2_5_vl_vit import (
Qwen2_5_VisionTransformerPretrainedModel,
)
# =========================================================
# HF TOKEN
# =========================================================
# Optional Hugging Face Hub authentication: only log in when an HF_TOKEN is
# present in the environment. The same token is also handed to
# snapshot_download further down (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# =========================================================
# PERFORMANCE SETTINGS
# =========================================================
# Permit TF32 math for both CUDA matmuls and cuDNN kernels (faster on Ampere+,
# with slightly reduced fp32 precision).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
# =========================================================
# PATHS
# =========================================================
# Everything this app writes lives under <repo>/tmps/gradio_t2v_v2t:
#   inputs/   - prompt JSON files fed to the dataset loader
#   results/  - per-request generation output directories
REPO_ROOT = Path(__file__).resolve().parent
GRADIO_TMP_ROOT = REPO_ROOT.joinpath("tmps", "gradio_t2v_v2t")
TMP_INPUT_DIR = GRADIO_TMP_ROOT.joinpath("inputs")
RESULTS_ROOT = GRADIO_TMP_ROOT.joinpath("results")
GLOBAL_RECORDS_FILE = GRADIO_TMP_ROOT.joinpath("generation_records.jsonl")
RUN_RECORD_FILENAME = "generation_record.json"
# =========================================================
# MODEL DOWNLOAD
# =========================================================
from huggingface_hub import snapshot_download
MODEL_REPO = "bytedance-research/Lance"
MODEL_CACHE_DIR = REPO_ROOT / "downloads"
# Download (or resume downloading) the whole model repository into a plain
# local directory at import time, so weights exist before the UI starts.
# NOTE(review): `local_dir_use_symlinks` and `resume_download` are deprecated
# in recent huggingface_hub releases (downloads resume automatically there);
# they are kept for older versions — confirm against the pinned version.
snapshot_download(
    repo_id=MODEL_REPO,
    local_dir=str(MODEL_CACHE_DIR),
    local_dir_use_symlinks=False,
    token=HF_TOKEN,
    resume_download=True,
)
DEFAULT_MODEL_PATH = str(MODEL_CACHE_DIR / "Lance_3B_Video")
print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
# Fail fast with an actionable message instead of the bare listdir error if
# the repository layout does not contain the expected model subdirectory.
if not os.path.isdir(DEFAULT_MODEL_PATH):
    raise FileNotFoundError(
        f"Model directory {DEFAULT_MODEL_PATH!r} not found after downloading "
        f"{MODEL_REPO!r} into {str(MODEL_CACHE_DIR)!r}; "
        "check the repository layout or download permissions."
    )
print("FILES =", sorted(os.listdir(DEFAULT_MODEL_PATH)))
# ---- Model / sampling defaults ---------------------------------------------
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
DEFAULT_TIMESTEPS = 30
DEFAULT_TIMESTEP_SHIFT = 3.5
DEFAULT_CFG_TEXT_SCALE = 4.0
DEFAULT_RESOLUTION = "video_480p"
# -1 is the "pick a random seed" sentinel handled by normalize_seed.
DEFAULT_BASIC_SEED = -1
# Default output video geometry (height, width, frame count).
DEFAULT_HEIGHT, DEFAULT_WIDTH, DEFAULT_NUM_FRAMES = 480, 848, 50
DEFAULT_QUEUE_SIZE = 4
USE_KVCACHE = True
TEXT_TEMPLATE = True
# NOTE(review): not referenced anywhere else in this file.
RECORD_WRITE_LOCK = threading.Lock()

# ---- Task identifiers ------------------------------------------------------
TASK_T2V = "t2v"
TASK_V2T = "v2t"
# Internal task id that the UI's "v2t" choice is mapped to by normalize_task.
TASK_X2T_VIDEO = "x2t_video"
TASK_CHOICES = [TASK_T2V, TASK_V2T]
DEFAULT_TASK = TASK_T2V

VIDEO_RESOLUTION_CHOICES = [f"video_{p}p" for p in (192, 360, 480)]
V2T_SYSTEM_PROMPT = "Watch the video carefully and answer the question."
# =========================================================
# HELPERS
# =========================================================
def ensure_dirs():
    """Create the prompt-input and results directories (and parents) if absent."""
    for directory in (TMP_INPUT_DIR, RESULTS_ROOT):
        directory.mkdir(parents=True, exist_ok=True)
def normalize_seed(seed: Optional[int]):
    """Return a concrete, non-negative RNG seed.

    ``-1`` (the UI's "random" sentinel), any other negative value, or ``None``
    is replaced by a fresh seed in ``[0, 2**31 - 1]``. Any other value is
    returned unchanged. Negative seeds are never passed through, because
    seeding numpy with a negative value raises.

    The replacement is drawn from ``random.SystemRandom`` (OS entropy) rather
    than the module-level ``random`` generator. ``transformers.set_seed``
    reseeds that global generator, which would otherwise make the "random"
    seed a deterministic function of the previously used seed.
    """
    if seed is None or seed < 0:
        return random.SystemRandom().randint(0, 2**31 - 1)
    return seed
def normalize_task(task: str):
    """Translate a UI task name into the internal task id.

    Empty/None input falls back to DEFAULT_TASK. The name is stripped and
    lowercased, then ``"v2t"`` is mapped to ``"x2t_video"``. Any other value
    is returned as-is.
    """
    cleaned = (task or DEFAULT_TASK).strip().lower()
    return TASK_X2T_VIDEO if cleaned == TASK_V2T else cleaned
# =========================================================
# PIPELINE
# =========================================================
class LanceT2VV2TPipeline:
    """Lazily initialised Lance pipeline for text-to-video (t2v) and
    video-to-text (v2t) requests coming from the Gradio UI.

    The model, VAE and tokenizer are built once by :meth:`initialize`. That
    method holds ``self.lock``, so concurrent first calls build them only once.
    Every :meth:`generate` call then reuses them.
    """

    def __init__(self):
        self.initialized = False
        self.logger = get_logger("lance_zerogpu")
        # All of the following are populated by initialize().
        self.model = None
        self.vae_model = None
        self.vae_config = None
        self.tokenizer = None
        self.new_token_ids = None
        self.image_token_id = None
        self.base_model_args = None
        self.base_data_args = None
        self.base_inference_args = None
        self.lock = threading.Lock()

    def initialize(self):
        """Build and cache the full model stack on first use; no-op afterwards.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        with self.lock:
            if self.initialized:
                return
            ensure_dirs()
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")
            model_args = ModelArguments(
                model_path=str(DEFAULT_MODEL_PATH),
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )
            data_args = DataArguments()
            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=TEXT_TEMPLATE,
                use_KVcache=USE_KVCACHE,
            )
            # NOTE(review): presumably fills in remaining argument fields
            # (e.g. model_args.vit_path used below) — confirm.
            apply_inference_defaults(
                model_args,
                data_args,
                inference_args,
            )
            set_seed(inference_args.global_seed)

            # Language model built from the checkpoint's llm_config.json.
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )
            language_model = Qwen2ForCausalLM(llm_config)

            # Vision tower: config plus a strict load of vit.safetensors.
            vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
                model_args.vit_path
            )
            vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
            vit_weights = load_file(
                str(Path(model_args.vit_path) / "vit.safetensors")
            )
            vit_model.load_state_dict(vit_weights, strict=True)
            clean_memory(vit_weights)

            vae_model = WanVideoVAE()
            vae_config = deepcopy(vae_model.vae_config)

            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )
            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )
            model = model.half().to("cuda")

            tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
            tokenizer, new_token_ids, _num_new_tokens = add_special_tokens(
                tokenizer
            )

            # NOTE(review): presumably loads the Lance checkpoint weights into
            # `model` from model_args.model_path — confirm.
            init_from_model_path_if_needed(
                model,
                model_args,
            )
            image_token_id = model.language_model.config.video_token_id
            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config
            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id
            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args
            self.initialized = True
            print("Lance initialized successfully")

    @staticmethod
    def _validate_request(task, prompt, input_video):
        """Return a user-facing error message when a required input is missing, else None."""
        if task == TASK_T2V:
            if not (prompt or "").strip():
                return "Please enter a prompt for text-to-video."
            return None
        if not input_video:
            return "Please upload an input video for video-to-text."
        return None

    @staticmethod
    def _build_payload(task, prompt, input_video, question):
        """Build the single-sample prompt JSON that ValidationDataset reads.

        For t2v this is ``{"000000.mp4": prompt}``. For any other task it is an
        interleaved [video, text] sample. The text element holds the system
        prompt, the question and an empty string, and
        ``istarget_in_interleave`` flags that text element.
        """
        if task == TASK_T2V:
            return {"000000.mp4": prompt}
        return {
            "000000": {
                "interleave_array": [
                    input_video,
                    [V2T_SYSTEM_PROMPT, question or "", ""],
                ],
                "element_dtype_array": ["video", "text"],
                "istarget_in_interleave": [0, 1],
            }
        }

    @staticmethod
    def _collect_text_outputs(save_dir):
        """Concatenate non-empty text-like files (.txt/.json/.jsonl) found
        anywhere under *save_dir*, in sorted path order, separated by newlines.

        NOTE(review): best-effort. SOURCE does not show where
        validate_on_fixed_batch writes the v2t answer, so once that file is
        known, narrow this to it.
        """
        pieces = []
        for path in sorted(Path(save_dir).rglob("*")):
            if path.is_file() and path.suffix.lower() in {".txt", ".json", ".jsonl"}:
                content = path.read_text(encoding="utf-8", errors="replace").strip()
                if content:
                    pieces.append(content)
        return "\n".join(pieces)

    def generate(
        self,
        task,
        prompt,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Run one t2v or v2t request.

        Returns:
            A 4-tuple matching the UI outputs
            ``(video_path_or_None, text_output, status, logs)``. Missing
            inputs and inference failures are reported through this tuple
            rather than raised.

        Raises:
            RuntimeError: propagated from initialize() when CUDA is unavailable.
        """
        import tempfile  # only needed here, for unique per-request directories

        task = normalize_task(task)
        error = self._validate_request(task, prompt, input_video)
        if error is not None:
            return None, "", error, ""

        self.initialize()
        prompt_file = None
        try:
            # A cleared Number field yields None; treat it like the -1 sentinel.
            actual_seed = normalize_seed(
                DEFAULT_BASIC_SEED if seed is None else int(seed)
            )
            # Seed python/numpy/torch so the user-facing seed controls sampling.
            set_seed(actual_seed)

            # Unique result dir and prompt file per request, so overlapping
            # requests cannot overwrite each other's inputs or outputs.
            save_dir = Path(tempfile.mkdtemp(prefix="run_", dir=str(RESULTS_ROOT)))
            prompt_file = TMP_INPUT_DIR / f"prompt_{save_dir.name}.json"

            inference_args = deepcopy(self.base_inference_args)
            inference_args.video_height = int(height)
            inference_args.video_width = int(width)
            inference_args.num_frames = int(num_frames)
            inference_args.validation_num_timesteps = int(validation_num_timesteps)
            inference_args.validation_timestep_shift = float(validation_timestep_shift)
            inference_args.resolution = resolution or DEFAULT_RESOLUTION
            inference_args.global_seed = actual_seed
            inference_args.task = task
            # NOTE(review): the InferenceArguments field consuming the CFG
            # scale is not visible here; set it only if such a field exists.
            if hasattr(inference_args, "cfg_text_scale"):
                inference_args.cfg_text_scale = float(cfg_text_scale)

            with open(prompt_file, "w", encoding="utf-8") as f:
                json.dump(self._build_payload(task, prompt, input_video, question), f)

            dataset_config = DataConfig.from_yaml(str(prompt_file))
            val_dataset = ValidationDataset(
                jsonl_path=str(prompt_file),
                tokenizer=self.tokenizer,
                data_args=self.base_data_args,
                model_args=self.base_model_args,
                training_args=inference_args,
                new_token_ids=self.new_token_ids,
                dataset_config=dataset_config,
                local_rank=0,
                world_size=1,
            )
            val_data_cpu = simple_custom_collate([val_dataset[0]])
            validate_on_fixed_batch(
                fsdp_model=self.model,
                vae_model=self.vae_model,
                tokenizer=self.tokenizer,
                val_data_cpu=val_data_cpu,
                training_args=inference_args,
                model_args=self.base_model_args,
                inference_args=inference_args,
                new_token_ids=self.new_token_ids,
                image_token_id=self.image_token_id,
                device="cuda",
                save_source_video=False,
                save_path_gen=str(save_dir),
                save_path_gt="",
            )

            run_info = f"Seed: {actual_seed}\nOutput dir: {save_dir}"
            if task == TASK_T2V:
                # Recursive search in case outputs land in a subdirectory.
                videos = sorted(save_dir.rglob("*.mp4"))
                if not videos:
                    return None, "", "No video generated", run_info
                return str(videos[0]), "", "Generation complete", run_info

            text = self._collect_text_outputs(save_dir)
            return None, text or "Understanding complete", "Success", run_info
        except Exception:
            err = traceback.format_exc()
            print(err)
            return None, "", "Generation failed", err
        finally:
            # Release GPU memory on success AND failure.
            if prompt_file is not None:
                prompt_file.unlink(missing_ok=True)
            clean_memory()
            gc.collect()
            torch.cuda.empty_cache()
# =========================================================
# GLOBAL PIPELINE
# =========================================================
# Module-level pipeline shared by every request; the model itself is only
# built on the first generate() call (via initialize()).
PIPELINE = LanceT2VV2TPipeline()
# =========================================================
# SPACES GPU FUNCTION
# =========================================================
@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    """ZeroGPU entry point bound to the Run button.

    Decorated with ``spaces.GPU(duration=300)`` to request a GPU for the call.
    Every UI value is forwarded by keyword to ``PIPELINE.generate``, and its
    4-tuple ``(video, text, status, logs)`` is returned unchanged.
    """
    request = {
        "task": task,
        "prompt": prompt,
        "input_video": input_video,
        "question": question,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "seed": seed,
        "resolution": resolution,
        "validation_num_timesteps": validation_num_timesteps,
        "validation_timestep_shift": validation_timestep_shift,
        "cfg_text_scale": cfg_text_scale,
    }
    return PIPELINE.generate(**request)
# =========================================================
# UI
# =========================================================
with gr.Blocks(title="Lance ZeroGPU") as demo:
    gr.Markdown("# Lance T2V/V2T ZeroGPU")
    # t2v only needs the prompt; v2t needs the video and question. Initial
    # visibility follows DEFAULT_TASK and is toggled by the Task dropdown.
    _show_t2v_inputs = normalize_task(DEFAULT_TASK) == TASK_T2V
    with gr.Row():
        with gr.Column():
            task = gr.Dropdown(
                label="Task",
                choices=TASK_CHOICES,
                value=DEFAULT_TASK,
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=6,
                visible=_show_t2v_inputs,
            )
            input_video = gr.Video(
                label="Input Video",
                visible=not _show_t2v_inputs,
            )
            question = gr.Textbox(
                label="Question",
                lines=3,
                visible=not _show_t2v_inputs,
            )
            height = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_HEIGHT,
                label="Height",
            )
            width = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_WIDTH,
                label="Width",
            )
            # Upper bound matches max_num_frames=121 used when building the model.
            num_frames = gr.Slider(
                minimum=1,
                maximum=121,
                step=1,
                value=DEFAULT_NUM_FRAMES,
                label="Frames",
            )
            seed = gr.Number(
                label="Seed (-1 = random)",
                value=DEFAULT_BASIC_SEED,
                precision=0,
            )
            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )
            validation_num_timesteps = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=DEFAULT_TIMESTEPS,
                label="Timesteps",
            )
            validation_timestep_shift = gr.Number(
                label="Timestep Shift",
                value=DEFAULT_TIMESTEP_SHIFT,
            )
            cfg_text_scale = gr.Number(
                label="CFG Text Scale",
                value=DEFAULT_CFG_TEXT_SCALE,
            )
            run_button = gr.Button(
                "Run",
                variant="primary",
            )
        with gr.Column():
            output_video = gr.Video(
                label="Generated Video"
            )
            output_text = gr.Textbox(
                label="Text Output",
                lines=8,
            )
            status = gr.Markdown()
            logs = gr.Textbox(
                label="Logs",
                lines=20,
            )

    def _toggle_task_inputs(selected_task):
        """Show the prompt for t2v; show the video and question otherwise."""
        is_t2v = normalize_task(selected_task) == TASK_T2V
        return (
            gr.update(visible=is_t2v),
            gr.update(visible=not is_t2v),
            gr.update(visible=not is_t2v),
        )

    # queue=False: a pure UI toggle must not wait behind GPU jobs in the queue.
    task.change(
        fn=_toggle_task_inputs,
        inputs=[task],
        outputs=[prompt, input_video, question],
        queue=False,
    )
    run_button.click(
        fn=run_task,
        inputs=[
            task,
            prompt,
            input_video,
            question,
            height,
            width,
            num_frames,
            seed,
            resolution,
            validation_num_timesteps,
            validation_timestep_shift,
            cfg_text_scale,
        ],
        outputs=[
            output_video,
            output_text,
            status,
            logs,
        ],
    )
# =========================================================
# LAUNCH
# =========================================================
# Single-worker queue: at most DEFAULT_QUEUE_SIZE waiting requests, and each
# event listener runs one request at a time (one GPU job at once).
demo.queue(
    max_size=DEFAULT_QUEUE_SIZE,
    default_concurrency_limit=1,
).launch(
    server_name="0.0.0.0",
    server_port=7860,
)