# Lance / appold4
# Author: Nayefleb
# Last commit: "Rename app.py to appold4" (bfec211, verified)
# =========================================================
# ZERO GPU PATCHED + ALL TASKS ENABLED
# Hugging Face Spaces Compatible
# =========================================================
from __future__ import annotations
import gc
import json
import os
import random
import threading
import time
import traceback
from copy import deepcopy
from pathlib import Path
import gradio as gr
import spaces
import torch
# =========================================================
# ENV
# =========================================================
# Process-wide switches. They are applied here, before the transformers and
# project imports further down in this file, and are written in this order.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "TRANSFORMERS_NO_FLASH_ATTN": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TF_ENABLE_ONEDNN_OPTS": "0",
    }
)
# =========================================================
# LOGIN
# =========================================================
from huggingface_hub import login, snapshot_download
# Optional Hub token (e.g. a Space secret). It is reused for snapshot_download
# below. login() is only called when a non-empty token is present.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# =========================================================
# IMPORTS
# =========================================================
from safetensors.torch import load_file
from transformers import set_seed
from transformers.utils import is_flash_attn_2_available
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig,
)
from common.utils.logging import get_logger
from config.config_factory import (
DataArguments,
InferenceArguments,
ModelArguments,
)
from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, simple_custom_collate
from data.datasets_custom import ValidationDataset
from inference_lance import (
apply_inference_defaults,
clean_memory,
init_from_model_path_if_needed,
validate_on_fixed_batch,
)
from modeling.lance import (
Lance,
LanceConfig,
Qwen2ForCausalLM,
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2Config
from modeling.vae.wan.model import WanVideoVAE
from modeling.vit.qwen2_5_vl_vit import (
Qwen2_5_VisionTransformerPretrainedModel,
)
# Startup diagnostic only: the result is printed, not used elsewhere in this script.
print(f"FlashAttention2 available: {is_flash_attn_2_available()}")
# =========================================================
# PERFORMANCE
# =========================================================
# Permit TF32 math for both cuBLAS matmuls and cuDNN kernels. This trades some
# float32 precision for speed on GPUs that support TF32.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend
# =========================================================
# PATHS
# =========================================================
# Scratch space next to this script. The request prompt file is written under
# tmps/inputs, and generated media under tmps/results (one sub-folder per request).
REPO_ROOT = Path(__file__).resolve().parent
TMP_ROOT = REPO_ROOT / "tmps"
TMP_INPUT_DIR = TMP_ROOT / "inputs"
RESULTS_ROOT = TMP_ROOT / "results"
# Create parents before children; both calls are idempotent.
for _scratch_dir in (TMP_ROOT, TMP_INPUT_DIR, RESULTS_ROOT):
    _scratch_dir.mkdir(parents=True, exist_ok=True)
del _scratch_dir
# =========================================================
# MODEL DOWNLOAD
# =========================================================
MODEL_REPO = "bytedance-research/Lance"
# Sub-folder of MODEL_REPO holding the checkpoint this Space serves. It is used
# both for the download filter and for the resulting local model path.
MODEL_SUBDIR = "Lance_3B_Video"
MODEL_CACHE_DIR = REPO_ROOT / "downloads"
# Fetch only the checkpoint sub-folder into a plain local directory.
# `local_dir_use_symlinks` is intentionally not passed: recent huggingface_hub
# releases deprecate it and ignore it, writing files directly into local_dir.
snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=[f"{MODEL_SUBDIR}/*"],
    local_dir=str(MODEL_CACHE_DIR),
    token=HF_TOKEN,
)
DEFAULT_MODEL_PATH = str(MODEL_CACHE_DIR / MODEL_SUBDIR)
print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
# =========================================================
# DEFAULTS
# =========================================================
# Sampling schedule / guidance.
DEFAULT_TIMESTEPS: int = 30
DEFAULT_TIMESTEP_SHIFT: float = 3.5
DEFAULT_CFG_TEXT_SCALE: float = 4.0

# Output geometry.
DEFAULT_HEIGHT: int = 480
DEFAULT_WIDTH: int = 848
DEFAULT_NUM_FRAMES: int = 50
DEFAULT_RESOLUTION: str = "video_480p"

# Request / model selection. A seed of -1 asks normalize_seed for a random seed.
DEFAULT_TASK: str = "t2v"
DEFAULT_SEED: int = -1
DEFAULT_VIT_TYPE: str = "qwen_2_5_vl_original"
# =========================================================
# TASKS
# =========================================================
# Generation from text.
TASK_T2V, TASK_T2I = "t2v", "t2i"
# Editing of a supplied image / video.
TASK_IMAGE_EDIT, TASK_VIDEO_EDIT = "image_edit", "video_edit"
# Image / video to text ("understanding") tasks.
TASK_X2T_IMAGE, TASK_X2T_VIDEO = "x2t_image", "x2t_video"

# Order shown in the task dropdown.
TASK_CHOICES = [
    TASK_T2V,
    TASK_T2I,
    TASK_IMAGE_EDIT,
    TASK_VIDEO_EDIT,
    TASK_X2T_IMAGE,
    TASK_X2T_VIDEO,
]

# "video_192p", "video_360p", "video_480p"
VIDEO_RESOLUTION_CHOICES = [f"video_{px}p" for px in (192, 360, 480)]
# =========================================================
# HELPERS
# =========================================================
def normalize_seed(seed: int):
if seed == -1:
return random.randint(0, 2**31 - 1)
return seed
def normalize_task(task: str):
    """Canonicalise a task name, falling back to DEFAULT_TASK.

    The value is stripped and lower-cased. Empty, ``None`` or unrecognised
    names map to DEFAULT_TASK. Valid names are exactly TASK_CHOICES (the
    dropdown's options), so this check cannot drift from the UI.
    """
    task = str(task or DEFAULT_TASK).strip().lower()
    return task if task in TASK_CHOICES else DEFAULT_TASK
# =========================================================
# PIPELINE
# =========================================================
class LancePipeline:
    """Lazily built Lance model plus per-request inference for every UI task.

    ``initialize`` constructs the model, VAE and tokenizer once, under a lock.
    ``generate`` runs one request and returns the 5-tuple wired to the Gradio
    outputs: ``(video_path, image_path, text, status, logs)``.
    """

    def __init__(self):
        self.initialized = False
        # Guards initialize() so concurrent first calls build the model only once.
        self.lock = threading.Lock()
        self.logger = get_logger("lance")

    def initialize(self):
        """Build model, VAE and tokenizer on CUDA; return immediately if already done.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        with self.lock:
            if self.initialized:
                return
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA unavailable")

            model_args = ModelArguments(
                model_path=DEFAULT_MODEL_PATH,
                vit_type=DEFAULT_VIT_TYPE,
                llm_qk_norm=True,
                llm_qk_norm_und=True,
                llm_qk_norm_gen=True,
                tie_word_embeddings=False,
                max_num_frames=121,
                max_latent_size=64,
                latent_patch_size=[1, 1, 1],
            )
            data_args = DataArguments()
            # Baseline arguments. generate() deep-copies these and overrides
            # the per-request fields (size, frames, timesteps, task, ...).
            inference_args = InferenceArguments(
                validation_num_timesteps=DEFAULT_TIMESTEPS,
                validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
                copy_init_moe=True,
                visual_und=True,
                visual_gen=True,
                vae_model_type="wan",
                apply_qwen_2_5_vl_pos_emb=True,
                apply_chat_template=False,
                cfg_type=0,
                validation_data_seed=42,
                video_height=DEFAULT_HEIGHT,
                video_width=DEFAULT_WIDTH,
                num_frames=DEFAULT_NUM_FRAMES,
                task=DEFAULT_TASK,
                save_path_gen=str(RESULTS_ROOT),
                resolution=DEFAULT_RESOLUTION,
                text_template=True,
                use_KVcache=True,
            )
            apply_inference_defaults(
                model_args,
                data_args,
                inference_args,
            )
            set_seed(42)

            # Language model, built from llm_config.json in the checkpoint folder.
            llm_config = Qwen2Config.from_json_file(
                str(Path(model_args.model_path) / "llm_config.json")
            )
            language_model = Qwen2ForCausalLM(llm_config)

            # Vision encoder with eager attention and weights from vit.safetensors.
            # NOTE(review): model_args.vit_path is not set explicitly above;
            # presumably filled by ModelArguments defaults or
            # apply_inference_defaults -- confirm.
            vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
                model_args.vit_path
            )
            vit_config._attn_implementation = "eager"
            vit_model = Qwen2_5_VisionTransformerPretrainedModel(
                vit_config
            )
            vit_weights = load_file(
                str(Path(model_args.vit_path) / "vit.safetensors")
            )
            vit_model.load_state_dict(vit_weights, strict=True)
            clean_memory(vit_weights)

            vae_model = WanVideoVAE()
            vae_config = deepcopy(vae_model.vae_config)

            config = LanceConfig(
                visual_gen=True,
                visual_und=True,
                llm_config=llm_config,
                vit_config=vit_config,
                vae_config=vae_config,
                latent_patch_size=model_args.latent_patch_size,
                max_num_frames=model_args.max_num_frames,
                max_latent_size=model_args.max_latent_size,
                vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
                connector_act=model_args.connector_act,
                interpolate_pos=model_args.interpolate_pos,
                timestep_shift=inference_args.timestep_shift,
            )
            model = Lance(
                language_model=language_model,
                vit_model=vit_model,
                vit_type=model_args.vit_type,
                config=config,
                training_args=inference_args,
            )
            model = model.to(
                device="cuda",
                dtype=torch.bfloat16,
            )

            tokenizer = Qwen2Tokenizer.from_pretrained(
                model_args.model_path
            )
            tokenizer, new_token_ids, _ = add_special_tokens(
                tokenizer
            )

            init_from_model_path_if_needed(
                model,
                model_args,
            )

            # NOTE(review): stored as image_token_id but read from
            # config.video_token_id -- confirm this is intended.
            image_token_id = model.language_model.config.video_token_id
            model.eval()

            self.model = model
            self.vae_model = vae_model
            self.vae_config = vae_config
            self.tokenizer = tokenizer
            self.new_token_ids = new_token_ids
            self.image_token_id = image_token_id
            self.base_model_args = model_args
            self.base_data_args = data_args
            self.base_inference_args = inference_args
            self.initialized = True
            print("Lance initialized successfully")

    @staticmethod
    def _missing_input_message(task, input_image, input_video):
        """Return a status message if ``task`` lacks its required media, else None."""
        if task in (TASK_IMAGE_EDIT, TASK_X2T_IMAGE) and not input_image:
            return f"Task '{task}' requires an input image"
        if task in (TASK_VIDEO_EDIT, TASK_X2T_VIDEO) and not input_video:
            return f"Task '{task}' requires an input video"
        return None

    @staticmethod
    def _build_payload(task, prompt, input_image, input_video, question):
        """Return the prompt-file payload for ``task``, or None for an unknown task."""
        # Pure generation: the key names the output file, the value is the prompt.
        if task == TASK_T2V:
            return {"000000.mp4": prompt}
        if task == TASK_T2I:
            return {"000000.png": prompt}

        if task in (TASK_IMAGE_EDIT, TASK_X2T_IMAGE):
            media, media_type = input_image, "image"
        elif task in (TASK_VIDEO_EDIT, TASK_X2T_VIDEO):
            media, media_type = input_video, "video"
        else:
            return None

        if task in (TASK_IMAGE_EDIT, TASK_VIDEO_EDIT):
            text_item = [prompt, ""]
        else:
            text_item = [f"Describe the {media_type}", question, ""]

        # Media first (not a target), then the text element (target).
        return {
            "000000": {
                "interleave_array": [media, text_item],
                "element_dtype_array": [media_type, "text"],
                "istarget_in_interleave": [0, 1],
            }
        }

    @staticmethod
    def _collect_outputs(task, save_dir):
        """Map files written into ``save_dir`` to the 5-tuple of UI outputs."""
        # Sorted so the returned file does not depend on directory listing order.
        videos = sorted(save_dir.glob("*.mp4"))
        if videos:
            return (str(videos[0]), None, "", "Success", "")
        images = sorted(save_dir.glob("*.png"))
        if images:
            return (None, str(images[0]), "", "Success", "")
        if task in (TASK_X2T_IMAGE, TASK_X2T_VIDEO):
            return (None, None, "Understanding complete", "Success", "")
        return (None, None, "", "No output generated", "")

    def generate(
        self,
        task,
        prompt,
        input_image,
        input_video,
        question,
        height,
        width,
        num_frames,
        seed,
        resolution,
        validation_num_timesteps,
        validation_timestep_shift,
        cfg_text_scale,
    ):
        """Run one request; call initialize() first.

        Returns:
            ``(video_path, image_path, text, status, logs)``. Missing media for
            edit/understanding tasks is reported in ``status``, not raised.
        """
        task = normalize_task(task)
        actual_seed = normalize_seed(int(seed) if seed is not None else -1)
        set_seed(actual_seed)

        missing = self._missing_input_message(task, input_image, input_video)
        if missing is not None:
            return (None, None, "", missing, "")

        payload = self._build_payload(
            task, prompt, input_image, input_video, question
        )
        if payload is None:
            return (None, None, "", "Invalid task", "")

        inference_args = deepcopy(self.base_inference_args)
        inference_args.video_height = int(height)
        inference_args.video_width = int(width)
        inference_args.num_frames = int(num_frames)
        inference_args.validation_num_timesteps = (
            DEFAULT_TIMESTEPS
            if validation_num_timesteps is None
            else int(validation_num_timesteps)
        )
        inference_args.validation_timestep_shift = (
            DEFAULT_TIMESTEP_SHIFT
            if validation_timestep_shift is None
            else float(validation_timestep_shift)
        )
        inference_args.task = task
        # The UI's resolution choice was previously ignored. NOTE(review): how
        # it interacts with video_height/video_width is decided in project
        # code -- confirm.
        if resolution:
            inference_args.resolution = resolution
        # NOTE(review): only applied if InferenceArguments defines a
        # `cfg_text_scale` field; the real field name is not visible here --
        # confirm.
        if cfg_text_scale is not None and hasattr(inference_args, "cfg_text_scale"):
            inference_args.cfg_text_scale = float(cfg_text_scale)

        save_dir = RESULTS_ROOT / str(time.time_ns())
        save_dir.mkdir(parents=True, exist_ok=True)
        # Single shared prompt file: safe only while requests are serialised
        # (the queue is configured with default_concurrency_limit=1).
        prompt_file = TMP_INPUT_DIR / "prompt.json"

        try:
            with open(prompt_file, "w", encoding="utf-8") as f:
                json.dump(payload, f)

            dataset_config = DataConfig.from_yaml(str(prompt_file))
            val_dataset = ValidationDataset(
                jsonl_path=str(prompt_file),
                tokenizer=self.tokenizer,
                data_args=self.base_data_args,
                model_args=self.base_model_args,
                training_args=inference_args,
                new_token_ids=self.new_token_ids,
                dataset_config=dataset_config,
                local_rank=0,
                world_size=1,
            )
            val_data_cpu = simple_custom_collate([val_dataset[0]])

            validate_on_fixed_batch(
                fsdp_model=self.model,
                vae_model=self.vae_model,
                tokenizer=self.tokenizer,
                val_data_cpu=val_data_cpu,
                training_args=inference_args,
                model_args=self.base_model_args,
                inference_args=inference_args,
                new_token_ids=self.new_token_ids,
                image_token_id=self.image_token_id,
                device="cuda",
                save_source_video=False,
                save_path_gen=str(save_dir),
                save_path_gt="",
            )
        finally:
            # Release memory even when inference fails (e.g. CUDA OOM).
            clean_memory()
            gc.collect()
            torch.cuda.empty_cache()

        return self._collect_outputs(task, save_dir)
# =========================================================
# GLOBAL
# =========================================================
# Single shared pipeline instance. Construction only sets flags; the model is
# built by PIPELINE.initialize(), which run_task calls on every request (it
# returns early once initialized).
PIPELINE = LancePipeline()
# =========================================================
# GPU RUN
# =========================================================
@spaces.GPU(duration=300)
def run_task(
    task,
    prompt,
    input_image,
    input_video,
    question,
    height,
    width,
    num_frames,
    seed,
    resolution,
    validation_num_timesteps,
    validation_timestep_shift,
    cfg_text_scale,
):
    """Gradio click handler, executed under ZeroGPU (requested duration: 300 s).

    Builds the shared pipeline on first use and runs one request.

    Returns:
        ``(video_path, image_path, text, status, logs)``. On failure, ``status``
        carries the error message and ``logs`` the full traceback, rather than
        the exception escaping to Gradio.
    """
    try:
        PIPELINE.initialize()
        return PIPELINE.generate(
            task=task,
            prompt=prompt,
            input_image=input_image,
            input_video=input_video,
            question=question,
            height=height,
            width=width,
            num_frames=num_frames,
            seed=seed,
            resolution=resolution,
            validation_num_timesteps=validation_num_timesteps,
            validation_timestep_shift=validation_timestep_shift,
            cfg_text_scale=cfg_text_scale,
        )
    except Exception as exc:  # UI boundary: report instead of crashing the request.
        details = traceback.format_exc()
        # Also print so the failure appears in the Space's container logs.
        print(details)
        return (None, None, "", f"Error: {type(exc).__name__}: {exc}", details)
# =========================================================
# UI
# =========================================================
with gr.Blocks(title="Lance ZeroGPU") as demo:
    gr.Markdown("# Lance Multi-Task ZeroGPU")
    with gr.Row():
        # Left column: task selection, inputs and sampling controls.
        with gr.Column():
            task = gr.Dropdown(
                label="Task",
                choices=TASK_CHOICES,
                value=DEFAULT_TASK,
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=5,
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )
            input_video = gr.Video(
                label="Input Video",
            )
            question = gr.Textbox(
                label="Question",
                lines=3,
            )
            height = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_HEIGHT,
                label="Height",
            )
            width = gr.Slider(
                minimum=192,
                maximum=1024,
                step=16,
                value=DEFAULT_WIDTH,
                label="Width",
            )
            num_frames = gr.Slider(
                minimum=1,
                maximum=121,
                step=1,
                value=DEFAULT_NUM_FRAMES,
                label="Frames",
            )
            # Uses the shared DEFAULT_SEED constant instead of a hard-coded -1.
            seed = gr.Number(
                label="Seed",
                value=DEFAULT_SEED,
                precision=0,
                info="-1 picks a random seed",
            )
            resolution = gr.Dropdown(
                label="Resolution",
                choices=VIDEO_RESOLUTION_CHOICES,
                value=DEFAULT_RESOLUTION,
            )
            validation_num_timesteps = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=DEFAULT_TIMESTEPS,
                label="Timesteps",
            )
            validation_timestep_shift = gr.Number(
                label="Timestep Shift",
                value=DEFAULT_TIMESTEP_SHIFT,
            )
            cfg_text_scale = gr.Number(
                label="CFG Text Scale",
                value=DEFAULT_CFG_TEXT_SCALE,
            )
            run_button = gr.Button(
                "Run",
                variant="primary",
            )
        # Right column: results, status and logs.
        with gr.Column():
            output_video = gr.Video(
                label="Generated Video"
            )
            output_image = gr.Image(
                label="Generated Image"
            )
            output_text = gr.Textbox(
                label="Text Output",
                lines=8,
            )
            status = gr.Markdown()
            logs = gr.Textbox(
                label="Logs",
                lines=15,
            )

    # `inputs` order must match run_task's positional parameters; `outputs`
    # order must match the 5-tuple run_task returns.
    run_button.click(
        fn=run_task,
        inputs=[
            task,
            prompt,
            input_image,
            input_video,
            question,
            height,
            width,
            num_frames,
            seed,
            resolution,
            validation_num_timesteps,
            validation_timestep_shift,
            cfg_text_scale,
        ],
        outputs=[
            output_video,
            output_image,
            output_text,
            status,
            logs,
        ],
    )
# =========================================================
# LAUNCH
# =========================================================
# Only start the server when run as a script (as `python app.py` does), so
# importing this module, e.g. from Gradio's reload mode, does not launch twice.
if __name__ == "__main__":
    # One request executes at a time and at most 4 may wait in the queue.
    demo.queue(
        max_size=4,
        default_concurrency_limit=1,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )