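# Lance / app.py
# Hugging Face file-viewer residue kept for provenance:
#   uploader avatar "Nayefleb's picture", commit "Update app.py",
#   revision bc4cd2c (verified)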
# =========================================================
# ZERO GPU PATCHED + ALL TASKS ENABLED
# Qwen2.5-VL FIXED VERSION
# Hugging Face Spaces Compatible
# =========================================================
from __future__ import annotations
import gc
import json
import os
import random
import threading
import time
import traceback
from copy import deepcopy
from pathlib import Path
import gradio as gr
import spaces
import torch
# =========================================================
# ENV
# =========================================================
# Process-wide switches for the tokenizer / PyTorch / TensorFlow stacks.
# NOTE(review): these are assigned after `import torch` and `import gradio`
# above -- confirm none of them has to be set before those imports to take
# effect.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "TRANSFORMERS_NO_FLASH_ATTN": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TF_ENABLE_ONEDNN_OPTS": "0",
    }
)
# =========================================================
# LOGIN
# =========================================================
from huggingface_hub import (
login,
snapshot_download,
hf_hub_download,
)
# Optional Hub credential; when present it is used to log in once at import
# time and is also passed explicitly to the download calls further down.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# =========================================================
# IMPORTS
# =========================================================
from safetensors.torch import load_file
from transformers import (
set_seed,
AutoConfig,
)
from transformers.utils import is_flash_attn_2_available
from common.utils.logging import get_logger
from config.config_factory import (
DataArguments,
InferenceArguments,
ModelArguments,
)
from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, simple_custom_collate
from data.datasets_custom import ValidationDataset
from inference_lance import (
apply_inference_defaults,
clean_memory,
init_from_model_path_if_needed,
validate_on_fixed_batch,
)
from modeling.lance import (
Lance,
LanceConfig,
Qwen2ForCausalLM,
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2Config
from modeling.vae.wan.model import WanVideoVAE
from modeling.vit.qwen2_5_vl_vit import (
Qwen2_5_VisionTransformerPretrainedModel,
)
# Report at startup whether transformers detects a usable FlashAttention-2.
print("FlashAttention2 available:", is_flash_attn_2_available())
# =========================================================
# PERFORMANCE
# =========================================================
# Allow TF32 for CUDA matmuls and cuDNN ops (process-wide PyTorch setting).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# =========================================================
# PATHS
# =========================================================
# Scratch layout, all rooted next to this file:
#   tmps/          parent of everything the app writes
#   tmps/inputs/   prompt payload handed to the dataset loader
#   tmps/results/  one sub-directory per generation request
REPO_ROOT = Path(__file__).resolve().parent
TMP_ROOT = REPO_ROOT / "tmps"
TMP_INPUT_DIR = TMP_ROOT / "inputs"
RESULTS_ROOT = TMP_ROOT / "results"
for _scratch_dir in (TMP_ROOT, TMP_INPUT_DIR, RESULTS_ROOT):
    _scratch_dir.mkdir(parents=True, exist_ok=True)
# =========================================================
# MODEL DOWNLOAD
# =========================================================
MODEL_REPO = "bytedance-research/Lance"
MODEL_CACHE_DIR = REPO_ROOT / "downloads"
# Fetch only the Lance_3B_Video sub-folder of the model repo into
# <repo>/downloads at import time.
# `local_dir_use_symlinks` was dropped from this call: huggingface_hub has
# deprecated and ignored it since the local_dir redesign (files are always
# materialised in `local_dir`), so passing it only produced a warning.
# NOTE(review): confirm the Space is not pinned to a hub release older than
# that redesign, where big files would then be symlinked into the cache.
snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=["Lance_3B_Video/*"],
    local_dir=str(MODEL_CACHE_DIR),
    token=HF_TOKEN,
)
DEFAULT_MODEL_PATH = str(MODEL_CACHE_DIR / "Lance_3B_Video")
print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
# =========================================================
# QWEN VL
# =========================================================
# Hub repository the pipeline reads the Qwen2.5-VL config and weights from.
QWEN_VL_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
# =========================================================
# DEFAULTS
# =========================================================
# Initial values shared by the UI controls and the base InferenceArguments.
DEFAULT_TIMESTEPS = 30
DEFAULT_TIMESTEP_SHIFT = 3.5
DEFAULT_CFG_TEXT_SCALE = 4.0
DEFAULT_HEIGHT = 480
DEFAULT_WIDTH = 848
DEFAULT_NUM_FRAMES = 50
DEFAULT_RESOLUTION = "video_480p"
DEFAULT_TASK = "t2v"
# -1 is the value normalize_seed() replaces with a random seed.
DEFAULT_SEED = -1
# Passed to ModelArguments(vit_type=...) and to the Lance constructor.
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
# =========================================================
# TASKS
# =========================================================
# Task identifiers understood by LancePipeline.generate().
TASK_T2V = "t2v"
TASK_T2I = "t2i"
TASK_IMAGE_EDIT = "image_edit"
TASK_VIDEO_EDIT = "video_edit"
TASK_X2T_IMAGE = "x2t_image"
TASK_X2T_VIDEO = "x2t_video"

# Order here is the order shown in the "Task" dropdown.
TASK_CHOICES = [
    TASK_T2V,
    TASK_T2I,
    TASK_IMAGE_EDIT,
    TASK_VIDEO_EDIT,
    TASK_X2T_IMAGE,
    TASK_X2T_VIDEO,
]

# Entries offered by the "Resolution" dropdown.
VIDEO_RESOLUTION_CHOICES = ["video_192p", "video_360p", "video_480p"]
# =========================================================
# HELPERS
# =========================================================
def normalize_seed(seed: int) -> int:
    """Return a usable RNG seed, drawing a random one when none was requested.

    Args:
        seed: Requested seed. ``-1`` is the UI sentinel for "pick one for me";
            ``None`` (cleared number field) and any other negative value are
            treated the same way, because the result is handed to
            ``set_seed``, whose NumPy seeding step rejects negative values.

    Returns:
        ``seed`` unchanged when it is non-negative, otherwise a random integer
        in ``[0, 2**31 - 1]``.
    """
    if seed is None or seed < 0:
        return random.randint(0, 2**31 - 1)
    return seed
def normalize_task(task: str):
    """Map a user-supplied task name onto one of the supported task ids.

    The name is stripped and lower-cased; an empty/``None`` value or a name
    that is not one of the ``TASK_*`` constants yields ``DEFAULT_TASK``.
    """
    known_tasks = (
        TASK_T2V,
        TASK_T2I,
        TASK_IMAGE_EDIT,
        TASK_VIDEO_EDIT,
        TASK_X2T_IMAGE,
        TASK_X2T_VIDEO,
    )
    cleaned = (task or DEFAULT_TASK).strip().lower()
    return cleaned if cleaned in known_tasks else DEFAULT_TASK
# =========================================================
# PIPELINE
# =========================================================
class LancePipeline:
    """Lazily built Lance model plus a single-request ``generate`` entry point.

    ``initialize`` constructs the model once (guarded by a lock) and stores its
    components on the instance. ``generate`` turns one UI request into a
    one-sample validation batch and runs ``validate_on_fixed_batch`` on it.
    """

    # Index file that maps tensor names to shard files in a sharded checkpoint.
    _VIT_INDEX_FILE = "model.safetensors.index.json"
    # Prefix of the vision-tower tensors inside the Qwen2.5-VL checkpoint.
    # NOTE(review): assumes the hub keys look like "visual.blocks.0..." while
    # the standalone ViT class expects them without that prefix -- confirm via
    # the "Missing keys" / "Unexpected keys" counts printed by initialize().
    _VIT_KEY_PREFIX = "visual."

    def __init__(self):
        # Nothing heavy happens here; the model is built by initialize().
        self.initialized = False
        self.lock = threading.Lock()
        self.logger = get_logger("lance")

    @classmethod
    def _load_vit_state_dict(cls):
        """Collect the vision-tower tensors from the Qwen2.5-VL checkpoint.

        The checkpoint in ``QWEN_VL_REPO`` is published as several safetensors
        shards plus an index file, so a single ``model.safetensors`` cannot be
        downloaded. Only shards that hold vision-tower tensors are fetched,
        and the prefix is stripped so the keys can match the standalone ViT.

        Returns:
            dict mapping un-prefixed parameter names to tensors (empty when no
            key in the index carries the expected prefix).
        """
        prefix = cls._VIT_KEY_PREFIX
        index_path = hf_hub_download(
            repo_id=QWEN_VL_REPO,
            filename=cls._VIT_INDEX_FILE,
            token=HF_TOKEN,
        )
        with open(index_path, encoding="utf-8") as index_file:
            weight_map = json.load(index_file)["weight_map"]
        shard_names = sorted(
            {
                shard
                for key, shard in weight_map.items()
                if key.startswith(prefix)
            }
        )
        if not shard_names:
            print(f"WARNING: no '{prefix}*' tensors listed in {cls._VIT_INDEX_FILE}")

        vit_state = {}
        for shard_name in shard_names:
            shard_path = hf_hub_download(
                repo_id=QWEN_VL_REPO,
                filename=shard_name,
                token=HF_TOKEN,
            )
            shard = load_file(shard_path)
            vit_state.update(
                {
                    key[len(prefix):]: tensor
                    for key, tensor in shard.items()
                    if key.startswith(prefix)
                }
            )
            # Drop the language-model tensors of this shard right away.
            del shard
        return vit_state

    def initialize(self):
        """Build the Lance model once and cache every component on ``self``.

        Subsequent calls return immediately. The whole build runs under
        ``self.lock`` so concurrent callers cannot construct the model twice.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        with self.lock:
            if self.initialized:
                return
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")
            print("Initializing Lance...")

            model_args = ModelArguments(
                model_path=DEFAULT_MODEL_PATH,
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )
            # =================================================
            # IMPORTANT FIX
            # =================================================
            # Point the ViT path at the hub repo id.
            model_args.vit_path = QWEN_VL_REPO

            data_args = DataArguments()
            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=True,
                use_KVcache=True,
            )
            apply_inference_defaults(
                model_args,
                data_args,
                inference_args,
            )
            set_seed(42)

            # =================================================
            # LLM
            # =================================================
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )
            language_model = Qwen2ForCausalLM(llm_config)

            # =================================================
            # FIXED QWEN2.5-VL LOADING
            # =================================================
            print("Loading Qwen2.5-VL config...")
            full_qwen_config = AutoConfig.from_pretrained(
                QWEN_VL_REPO,
                token=HF_TOKEN,
                trust_remote_code=True,
            )
            vit_config = full_qwen_config.vision_config
            vit_config._attn_implementation = "eager"
            print("Creating vision transformer...")
            vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)

            # =================================================
            # LOAD WEIGHTS
            # =================================================
            print("Downloading Qwen weights...")
            vit_weights = self._load_vit_state_dict()
            print("Loading VIT weights...")
            # strict=False keeps this best-effort; the counts below show how
            # well the checkpoint keys matched the model.
            missing, unexpected = vit_model.load_state_dict(
                vit_weights,
                strict=False,
            )
            print("Missing keys:", len(missing))
            print("Unexpected keys:", len(unexpected))
            clean_memory(vit_weights)

            # =================================================
            # VAE
            # =================================================
            vae_model = WanVideoVAE()
            vae_config = deepcopy(vae_model.vae_config)

            # =================================================
            # CONFIG
            # =================================================
            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )
            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )
            print("Moving model to CUDA...")
            model = model.to(
                device="cuda",
                dtype=torch.bfloat16,
            )

            tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
            tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
            init_from_model_path_if_needed(
                model,
                model_args,
            )
            # NOTE(review): the value stored as "image_token_id" is read from
            # the config's `video_token_id` -- confirm this is intentional.
            image_token_id = model.language_model.config.video_token_id
            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config
            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id
            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args
            # Set last so a failed build is retried on the next call.
            self.initialized = True
            print("Lance initialized successfully")

    @staticmethod
    def _missing_input_message(task, input_image, input_video):
        """Return an error string when ``task`` lacks its media input, else None."""
        if task in (TASK_IMAGE_EDIT, TASK_X2T_IMAGE) and not input_image:
            return f"ERROR: task '{task}' requires an input image"
        if task in (TASK_VIDEO_EDIT, TASK_X2T_VIDEO) and not input_video:
            return f"ERROR: task '{task}' requires an input video"
        return None

    @staticmethod
    def _build_payload(task, prompt, input_image, input_video, question):
        """Build the JSON payload written to the prompt file for ``task``.

        Text-to-video / text-to-image map an output file name to the prompt.
        The other tasks describe an interleaved [media, text] sample.
        Returns ``None`` for a task id that is not recognised.
        """
        if task == TASK_T2V:
            return {"000000.mp4": prompt}
        if task == TASK_T2I:
            return {"000000.png": prompt}

        if task in (TASK_IMAGE_EDIT, TASK_X2T_IMAGE):
            media, media_kind = input_image, "image"
        elif task in (TASK_VIDEO_EDIT, TASK_X2T_VIDEO):
            media, media_kind = input_video, "video"
        else:
            return None

        if task in (TASK_IMAGE_EDIT, TASK_VIDEO_EDIT):
            text_items = [prompt, ""]
        else:
            text_items = [f"Describe the {media_kind}", question, ""]

        return {
            "000000": {
                "interleave_array": [media, text_items],
                "element_dtype_array": [media_kind, "text"],
                "istarget_in_interleave": [0, 1],
            }
        }

    # =========================================================
    # GENERATE
    # =========================================================
    def generate(
        self,
        task,
        prompt,
        input_image,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Run one request and report the result as a 5-tuple for the UI.

        Args:
            task: Task id; normalised with ``normalize_task``.
            prompt: Text prompt (generation and edit tasks).
            input_image: Image file path (image edit / image understanding).
            input_video: Video file path (video edit / video understanding).
            question: Question text for the understanding tasks.
            height, width, num_frames: Output geometry, cast to ``int``.
            seed: Integer seed; ``-1`` or ``None`` selects a random seed.
            resolution: Resolution preset name copied onto the inference args.
            validation_num_timesteps: Number of sampling steps.
            validation_timestep_shift: Timestep shift value.
            cfg_text_scale: Text guidance scale (applied only when the
                inference args expose an attribute of that name).

        Returns:
            ``(video_path, image_path, text, status, logs)``. Exceptions are
            not propagated: they produce ``status`` starting with ``"ERROR"``
            and the traceback in ``logs``.
        """
        try:
            task = normalize_task(task)

            # Reject a missing media input up front instead of writing a null
            # path into the payload and failing deep inside the data loader.
            problem = self._missing_input_message(task, input_image, input_video)
            if problem is not None:
                return (None, None, "", problem, "")

            payload = self._build_payload(
                task, prompt, input_image, input_video, question
            )
            if payload is None:
                return (None, None, "", "Invalid task", "")

            # A cleared number field arrives as None; treat it like -1.
            actual_seed = normalize_seed(-1 if seed is None else int(seed))
            set_seed(actual_seed)

            save_dir = RESULTS_ROOT / str(time.time())
            save_dir.mkdir(parents=True, exist_ok=True)

            inference_args = deepcopy(self.base_inference_args)
            inference_args.video_height = int(height)
            inference_args.video_width = int(width)
            inference_args.num_frames = int(num_frames)
            inference_args.validation_num_timesteps = int(validation_num_timesteps)
            inference_args.validation_timestep_shift = float(
                validation_timestep_shift
            )
            inference_args.task = task
            # These two UI values used to be accepted and silently dropped.
            if resolution:
                inference_args.resolution = resolution
            # NOTE(review): the attribute name is assumed; nothing is set when
            # InferenceArguments has no `cfg_text_scale` field -- confirm.
            if cfg_text_scale is not None and hasattr(
                inference_args, "cfg_text_scale"
            ):
                inference_args.cfg_text_scale = float(cfg_text_scale)

            # NOTE(review): fixed file name shared by all requests; safe only
            # while requests are processed one at a time.
            prompt_file = TMP_INPUT_DIR / "prompt.json"
            with open(prompt_file, "w", encoding="utf-8") as f:
                json.dump(payload, f)

            dataset_config = DataConfig.from_yaml(str(prompt_file))
            val_dataset = ValidationDataset(
                jsonl_path=str(prompt_file),
                tokenizer=self.tokenizer,
                data_args=self.base_data_args,
                model_args=self.base_model_args,
                training_args=inference_args,
                new_token_ids=self.new_token_ids,
                dataset_config=dataset_config,
                local_rank=0,
                world_size=1,
            )
            val_data_cpu = simple_custom_collate([val_dataset[0]])

            try:
                validate_on_fixed_batch(
                    fsdp_model=self.model,
                    vae_model=self.vae_model,
                    tokenizer=self.tokenizer,
                    val_data_cpu=val_data_cpu,
                    training_args=inference_args,
                    model_args=self.base_model_args,
                    inference_args=inference_args,
                    new_token_ids=self.new_token_ids,
                    image_token_id=self.image_token_id,
                    device="cuda",
                    save_source_video=False,
                    save_path_gen=str(save_dir),
                    save_path_gt="",
                )
            finally:
                # Release memory even when inference raised, so a failed
                # request does not leave allocations behind for the next one.
                clean_memory()
                gc.collect()
                torch.cuda.empty_cache()

            # Surface the seed so a randomly seeded run can be reproduced.
            run_log = f"seed={actual_seed}"
            videos = sorted(save_dir.glob("*.mp4"))
            if videos:
                return (str(videos[0]), None, "", "Success", run_log)
            images = sorted(save_dir.glob("*.png"))
            if images:
                return (None, str(images[0]), "", "Success", run_log)
            if task in (TASK_X2T_IMAGE, TASK_X2T_VIDEO):
                # NOTE(review): the generated text itself is not collected
                # here; only a placeholder is returned.
                return (None, None, "Understanding complete", "Success", run_log)
            return (None, None, "", "No output generated", run_log)
        except Exception as e:
            # UI boundary: report the failure instead of raising into Gradio.
            traceback.print_exc()
            return (
                None,
                None,
                "",
                f"ERROR: {str(e)}",
                traceback.format_exc(),
            )
# =========================================================
# GLOBAL
# =========================================================
# Single shared pipeline instance; constructing it is cheap because the model
# is only built on the first PIPELINE.initialize() call.
PIPELINE = LancePipeline()
# =========================================================
# GPU RUN
# =========================================================
@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_image,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    """Gradio click handler: ensure the model is built, then run one request.

    Parameters arrive in the order of the click handler's ``inputs`` list and
    are forwarded unchanged to ``PIPELINE.generate``, whose 5-tuple result is
    returned as-is.
    """
    PIPELINE.initialize()
    request = {
        "task": task,
        "prompt": prompt,
        "input_image": input_image,
        "input_video": input_video,
        "question": question,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "seed": seed,
        "resolution": resolution,
        "validation_num_timesteps": validation_num_timesteps,
        "validation_timestep_shift": validation_timestep_shift,
        "cfg_text_scale": cfg_text_scale,
    }
    return PIPELINE.generate(**request)
# =========================================================
# UI
# =========================================================
# Two-column layout: request controls on the left, results on the right.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost -- confirm it matches the intended layout.
with gr.Blocks(title="Lance ZeroGPU") as demo:
    gr.Markdown("# Lance Multi-Task ZeroGPU")

    with gr.Row():
        # ---- inputs ----
        with gr.Column():
            task = gr.Dropdown(
                label="Task",
                choices=TASK_CHOICES,
                value=DEFAULT_TASK,
            )
            prompt = gr.Textbox(label="Prompt", lines=5)
            input_image = gr.Image(label="Input Image", type="filepath")
            input_video = gr.Video(label="Input Video")
            question = gr.Textbox(label="Question", lines=3)
            height = gr.Slider(
                label="Height",
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_HEIGHT,
            )
            width = gr.Slider(
                label="Width",
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_WIDTH,
            )
            num_frames = gr.Slider(
                label="Frames",
                minimum=1,
                maximum=121,
                step=1,
                value=DEFAULT_NUM_FRAMES,
            )
            seed = gr.Number(
                label="Seed",
                value=DEFAULT_SEED,
                precision=0,
            )
            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )
            validation_num_timesteps = gr.Slider(
                label="Timesteps",
                minimum=1,
                maximum=100,
                step=1,
                value=DEFAULT_TIMESTEPS,
            )
            validation_timestep_shift = gr.Number(
                label="Timestep Shift",
                value=DEFAULT_TIMESTEP_SHIFT,
            )
            cfg_text_scale = gr.Number(
                label="CFG Text Scale",
                value=DEFAULT_CFG_TEXT_SCALE,
            )
            run_button = gr.Button("Run", variant="primary")

        # ---- outputs ----
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_image = gr.Image(label="Generated Image")
            output_text = gr.Textbox(label="Text Output", lines=8)
            status = gr.Markdown()
            logs = gr.Textbox(label="Logs", lines=15)

    # Order must match run_task's positional parameters.
    event_inputs = [
        task,
        prompt,
        input_image,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ]
    # Order must match the 5-tuple returned by run_task.
    event_outputs = [
        output_video,
        output_image,
        output_text,
        status,
        logs,
    ]
    run_button.click(
        fn=run_task,
        inputs=event_inputs,
        outputs=event_outputs,
    )
# =========================================================
# LAUNCH
# =========================================================
# Queue at most 4 waiting requests and run event handlers one at a time, then
# serve on all interfaces at port 7860.
demo.queue(max_size=4, default_concurrency_limit=1)
demo.launch(server_name="0.0.0.0", server_port=7860)