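# File: Lance/benchmarks/image_gen/DPG/sample_DPG.py
# Hugging Face Hub page header (scrape residue, kept as a comment so the file parses):
# uploaded by Nayefleb with "Upload folder using huggingface_hub", commit 8b306b3 (verified).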
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d")
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import os.path as osp
from copy import deepcopy
from typing import Tuple, cast, Optional
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
from safetensors.torch import load_file
from PIL import Image
from torchvision.utils import make_grid
import numpy as np
from tqdm import trange
from data.dataset_base import DataConfig, simple_custom_collate
from data.data_utils import add_special_tokens
from modeling.vae.wan.model import WanVideoVAE
from modeling.lance import LanceConfig, Lance, Qwen2ForCausalLM
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2Config
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
from common.utils.misc import tuple_mul, AutoEncoderParams
from common.utils.logging import get_logger
from common.val.utils import make_padded_latent
from data.datasets_custom import ValidationDataset
from config.config_factory import ModelArguments, DataArguments, TrainingArguments, EvaluationArguments, get_model_path
def init_from_vlm_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments, log_rank0):
    """Initialise ``model`` from the VLM safetensors shards in ``model_args.llm_path``.

    Every ``*.safetensors`` file in the folder is loaded in filename order and
    merged into a single state dict (a later shard overwrites a duplicate key).
    Keys are then renamed to Lance's parameter names:

    * a key containing ``"visual"`` (ViT and connector) has that substring
      replaced by ``"vit_model"``;
    * every other key gets the ``"language_model."`` prefix.

    The renamed dict is loaded with ``strict=False``, so mismatched keys do not
    raise; they are summarised through ``log_rank0`` instead of being dropped.

    Args:
        model: Module whose ``load_state_dict`` receives the renamed weights.
        model_args: Must expose ``llm_path``, the folder holding the shards.
        log_rank0: Logging callable taking one message string. Anything that is
            not callable (e.g. ``None``) is ignored.

    Returns:
        Whatever ``model.load_state_dict(..., strict=False)`` returns.
    """
    # NOTE: the initial load of the VLM weights goes through here.
    shard_dir = model_args.llm_path

    # Only safetensors files are considered; sorting by name keeps the merge order deterministic.
    shard_names = sorted(f for f in os.listdir(shard_dir) if f.endswith(".safetensors"))
    merged = {}
    for shard_name in shard_names:
        merged.update(load_file(osp.join(shard_dir, shard_name)))

    # Rename the parameters so they match Lance's naming.
    state_dict = {
        (key.replace("visual", "vit_model") if "visual" in key else "language_model." + key): tensor
        for key, tensor in merged.items()
    }

    result = model.load_state_dict(state_dict, strict=False)

    if callable(log_rank0):
        # getattr keeps this safe for models whose load_state_dict returns something else.
        missing = getattr(result, "missing_keys", ())
        unexpected = getattr(result, "unexpected_keys", ())
        log_rank0(
            f"init_from_vlm_if_needed: loaded {len(state_dict)} tensors from {shard_dir} "
            f"(missing={len(missing)}, unexpected={len(unexpected)})"
        )

    clean_memory(merged, state_dict)
    return result
def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments):
    """Load a trained Lance checkpoint from ``model_args.model_path`` into ``model``.

    ``model.safetensors`` is preferred; ``ema.safetensors`` is used only when the
    former does not exist. The weights are read onto the CPU and loaded with
    ``strict=False``.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        model_args: Must expose ``model_path``, the checkpoint directory.

    Returns:
        The value returned by ``model.load_state_dict(..., strict=False)``.

    Raises:
        FileNotFoundError: Neither checkpoint file exists in the directory.
    """
    ckpt_dir = model_args.model_path

    # Candidates in priority order; the first one that exists wins.
    candidates = (
        osp.join(ckpt_dir, "model.safetensors"),
        osp.join(ckpt_dir, "ema.safetensors"),
    )
    chosen = next((path for path in candidates if osp.exists(path)), None)
    if not chosen:
        raise FileNotFoundError(
            f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {ckpt_dir}"
        )

    weights = load_file(chosen, device="cpu")

    # NOTE: position embeds are fixed sinusoidal embeddings, so the stored copy can simply
    # be discarded, which makes it easier to adapt to different resolutions.
    weights.pop('latent_pos_embed.pos_embed', None)

    load_msg = model.load_state_dict(weights, strict=False)
    clean_memory(weights)
    return load_msg
def clean_memory(*objects):
    """Run a garbage-collection pass and release PyTorch's cached CUDA memory.

    The positional arguments are accepted so existing call sites keep working,
    but this function cannot free them on the caller's behalf: a name bound in
    the caller stays bound whatever happens here. Callers that want a large
    object reclaimed must drop their own references (``del x`` or let the
    variable go out of scope) in addition to calling this.

    Args:
        *objects: Objects the caller is finished with. Only this frame's
            reference to them is released.

    Returns:
        None.
    """
    import gc

    # The previous `for obj in objects: del obj` only unbound the loop variable while the
    # argument tuple kept every object alive; dropping the tuple is the most this frame can do.
    del objects
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def decode_video_tensor_for_dpg(v_list):
    """Turn decoded VAE outputs into a single uint8 tensor in the DPG save layout.

    Each tensor has its first two axes swapped (``permute(1, 0, 2, 3)``); when
    several tensors are given they are placed side by side along the last axis,
    in list order. Values are clipped to ``[-1, 1]``, mapped linearly to
    ``[0, 255]``, rounded and cast to ``torch.uint8``.

    Assumes each input is laid out (C, T, H, W), giving a (T, C, H, sum of W)
    result — inferred from the permute and from the caller treating axis 0 of
    the result as frames; confirm against ``vae_decode``.

    Args:
        v_list: Non-empty sequence of 4-D tensors. All axes except the last must
            agree between entries.

    Returns:
        A new ``torch.uint8`` tensor; the inputs are not modified.

    Raises:
        ValueError: ``v_list`` is empty.
    """
    if len(v_list) == 0:
        raise ValueError("v_list must contain at least one decoded tensor")

    swapped = [v.permute(1, 0, 2, 3) for v in v_list]
    # Side-by-side layout: equivalent to einops "n t c h w -> t c h (n w)" without the extra dependency.
    merged = swapped[0] if len(swapped) == 1 else torch.cat(swapped, dim=-1)

    # clip() allocates a fresh tensor, so the in-place ops below never touch the caller's data.
    scaled = merged.float().clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    # After the clip the values already lie in [0, 255]; no further clamp is needed.
    return scaled.to(torch.uint8)
def resolve_dpg_paths(
    model_args: ModelArguments,
    data_args: DataArguments,
) -> None:
    """Fill in the paths a DPG run needs, mutating the argument objects in place.

    ``model_args.model_path`` is mandatory. Each of the remaining fields is only
    written when its current value is falsy:

    * ``model_args.llm_path`` falls back to ``model_args.model_path``;
    * ``model_args.vit_path`` falls back to ``get_model_path("vit.qwen2_5_vl")``;
    * ``data_args.val_dataset_config_file`` falls back to
      ``get_model_path("dpg.data")``.

    Raises:
        ValueError: ``model_args.model_path`` is empty or unset.
    """
    if not model_args.model_path:
        raise ValueError("DPG requires --model_path to be provided explicitly.")

    # (owner, attribute, lazy fallback); a fallback is evaluated only when it is needed.
    fallbacks = (
        (model_args, "llm_path", lambda: model_args.model_path),
        (model_args, "vit_path", lambda: get_model_path("vit.qwen2_5_vl")),
        (data_args, "val_dataset_config_file", lambda: get_model_path("dpg.data")),
    )
    for owner, attr, make_default in fallbacks:
        if not getattr(owner, attr):
            setattr(owner, attr, make_default())
def validate_on_fixed_batch(
    fsdp_model: Lance,
    vae_model: Optional[WanVideoVAE],
    tokenizer: Qwen2Tokenizer,
    val_data_cpu: dict,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    data_args: DataArguments,
    inference_args: EvaluationArguments,
    curr_step: int,
    logger,
    new_token_ids,
    image_token_id: int,
    device: int,
    save_source_video: bool = False,
    save_path_gen: str = "",
    save_path_gt: str = "",
    sample_num_per_prompt: int = 1,
):
    """Generate ``sample_num_per_prompt`` images for one batch and save them as one PNG grid.

    The output file is ``{save_path_gen}/{index}.png`` where ``index`` comes from
    the batch. If that file already exists the call returns at once, before any
    latent encoding or generation, which is what lets an interrupted run resume.
    The PNG is written to a temporary sibling file and then renamed, so a run
    killed mid-write cannot leave a truncated image that a later resume would
    mistake for a finished one.

    Args:
        fsdp_model: Model providing ``validation_gen`` / ``validation_gen_KVcache``.
        vae_model: VAE passed to ``make_padded_latent`` and used for ``vae_decode``.
        val_data_cpu: Batch object; despite the ``dict`` annotation it must
            support ``.cuda(device).to_dict()``.
        training_args: Source of the sampling and CFG settings, including
            ``use_KVcache`` and ``validation_noise_seed``.
        model_args: Source of ``cfg_text_scale``.
        new_token_ids: Forwarded to the generation call.
        image_token_id: Forwarded to the generation call.
        device: CUDA device index for the batch and the generation call.
        save_source_video: When true, ``padded_videos`` is forwarded to the
            generation call; otherwise ``None`` is passed.
        save_path_gen: Directory that receives the PNG.
        sample_num_per_prompt: Number of generation passes. Pass ``i`` uses
            noise seed ``validation_noise_seed + i``; the grid has
            ``int(sqrt(sample_num_per_prompt))`` images per row.
        tokenizer, data_args, inference_args, curr_step, logger, save_path_gt:
            Not used here; kept so existing callers do not need to change.

    Returns:
        None, both when the item is skipped and after a successful save.

    Raises:
        NotImplementedError: A decoded sample has more than one frame.
    """
    val_data = val_data_cpu.cuda(device).to_dict()

    # Resume support: skip finished items before spending any time on encoding or sampling.
    save_name = f"{save_path_gen}/{val_data['index']}.png"
    if os.path.exists(save_name):
        return None

    # Only rank 0 (or a non-distributed run) shows a progress bar.
    is_rank0 = dist.get_rank() == 0 if dist.is_initialized() else True
    loop_iterator = trange(sample_num_per_prompt) if is_rank0 else range(sample_num_per_prompt)

    tensor_list_for_grid = []
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
        # Compute padded_latent from the raw videos when the batch carries them.
        if "padded_videos" in val_data:
            val_data["padded_latent"] = make_padded_latent(
                val_data["padded_videos"], val_data["vae_data_mode"], vae_model
            )

        # -------------------- GEN branch --------------------
        # Everything except the noise seed is identical across passes, so build it once.
        shared_params = {
            "val_packed_text_ids": val_data["packed_text_ids"],
            "val_packed_text_indexes": val_data["packed_text_indexes"],
            "val_sample_lens": val_data["sample_lens"],
            "val_packed_position_ids": val_data["packed_position_ids"],
            "val_split_lens": val_data["split_lens"],
            "val_attn_modes": val_data["attn_modes"],
            "val_sample_N_target": val_data["sample_N_target"],
            "val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"],
            "timestep_shift": training_args.validation_timestep_shift,
            "num_timesteps": training_args.validation_num_timesteps,
            "val_mse_loss_indexes": val_data.get("mse_loss_indexes", None),
            "val_padded_latent": val_data["padded_latent"],
            "video_sizes": val_data["video_sizes"],
            "cfg_text_scale": model_args.cfg_text_scale,
            "cfg_interval": training_args.cfg_interval,
            "cfg_renorm_min": training_args.cfg_renorm_min,
            "cfg_renorm_type": training_args.cfg_renorm_type,
            "device": device,
            "dtype": torch.bfloat16,
            "new_token_ids": new_token_ids,
            "max_samples": training_args.validation_max_samples,
            "apply_chat_template": training_args.apply_chat_template,
            "apply_qwen_2_5_vl_pos_emb": training_args.apply_qwen_2_5_vl_pos_emb,
            "image_token_id": image_token_id,
            "val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None),
            "val_packed_vit_tokens": val_data.get("packed_vit_tokens", None),
            "vit_video_grid_thw": val_data.get("vit_video_grid_thw", None),
            "vae_video_grid_thw": val_data["vae_video_grid_thw"],
            "video_grid_thw": val_data.get("video_grid_thw", None),
            "caption": val_data.get("caption", None),
            "sample_task": val_data["sample_task"],
            "sample_modality": val_data["sample_modality"],
            "cfg_type": training_args.cfg_type,
            "cfg_uncond_token_id": training_args.cfg_uncond_token_id,
            "index": val_data["index"],
            "val_padded_videos": val_data["padded_videos"] if save_source_video else None,
        }
        generate = (
            fsdp_model.validation_gen_KVcache if training_args.use_KVcache else fsdp_model.validation_gen
        )

        for sample_idx in loop_iterator:
            # Each pass gets its own noise seed so the samples of one prompt differ.
            params = dict(
                shared_params,
                validation_noise_seed=training_args.validation_noise_seed + sample_idx,
            )
            denoise_latent, _captions, _padded_videos, _index = generate(**params)

            # Decode and collect.
            for latent in denoise_latent:
                v_list = [vae_model.vae_decode([latent_])[0] for latent_ in latent]
                frames = decode_video_tensor_for_dpg(v_list)
                # Only single-frame outputs (images) are supported; take that frame.
                if frames.shape[0] != 1:
                    raise NotImplementedError("需要保存图像")
                tensor_list_for_grid.append(frames.squeeze(0).cpu())

    # Same grid layout as before: no padding, int(sqrt(n)) images per row.
    grid_tensor = make_grid(
        tensor_list_for_grid, nrow=int(np.sqrt(sample_num_per_prompt)), padding=0, pad_value=255
    )
    grid_numpy = grid_tensor.permute(1, 2, 0).numpy()

    # Write to a temp file, then rename, so save_name only ever names a complete PNG.
    tmp_name = f"{save_name}.tmp"
    Image.fromarray(grid_numpy).save(tmp_name, format="PNG")
    os.replace(tmp_name, save_name)
def main():
    """Script entry point: build the Lance model, load its checkpoint and sample the DPG set.

    Steps, in order: set up the (optional) NCCL process group and CUDA device,
    parse the command-line dataclasses, resolve default paths, build the LLM /
    ViT / VAE and wrap them in ``Lance``, load the checkpoint from
    ``model_args.model_path``, register special tokens, build the validation
    dataset and loader, then call ``validate_on_fixed_batch`` once per batch so
    that each item ends up as a PNG in ``inference_args.save_path_gen``.

    NOTE(review): the source this block was taken from had lost its indentation;
    the nesting below is reconstructed from the code's meaning and should be
    confirmed against the upstream file.
    """
    # ========================= Env setup ==============================
    assert torch.cuda.is_available()
    # Join the NCCL group only when the launcher exported RANK and WORLD_SIZE;
    # otherwise run as a single process.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group("nccl")
        GLOBAL_RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()
    else:
        GLOBAL_RANK = 0
        WORLD_SIZE = 1
    # The local device index is derived from the global rank and the visible GPU count.
    LOCAL_RANK = GLOBAL_RANK % torch.cuda.device_count()
    DEVICE = LOCAL_RANK
    torch.cuda.set_device(DEVICE)
    # ========================= Args and logger setup ==============================
    parser = HfArgumentParser((ModelArguments, DataArguments, EvaluationArguments))
    model_args, data_args, inference_args = cast(
        Tuple[ModelArguments, DataArguments, EvaluationArguments],
        parser.parse_args_into_dataclasses(),
    )
    # Alias, not a copy: every write to training_args below also changes inference_args.
    training_args = inference_args
    # ========================= DPG path resolution ==============================
    resolve_dpg_paths(model_args, data_args)
    # NOTE: validation_noise_seed and validation_data_seed are both set to evaluation_seed.
    training_args.validation_noise_seed = inference_args.evaluation_seed
    training_args.validation_data_seed = inference_args.evaluation_seed
    logger = get_logger()
    log_rank0 = print if GLOBAL_RANK == 0 else (lambda *_: None)
    # Set seed: a different value on every rank.
    seed = training_args.global_seed * WORLD_SIZE + GLOBAL_RANK
    set_seed(seed)
    # ========================= LLM model setup ==============================
    llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json"))
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.qk_norm_und = model_args.llm_qk_norm_und
    llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb
    language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
    # ViT branch: built and loaded only when visual understanding is enabled.
    if training_args.visual_und:
        if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"):
            vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
            vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
            vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors"))
            vit_model.load_state_dict(vit_weights, strict=True)
        else:
            raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
        clean_memory(vit_weights)
    # VAE branch: built only when visual generation is enabled.
    if training_args.visual_gen:
        vae_model = WanVideoVAE()
        vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config)
    else:
        vae_model = None
        vae_config = None
    # Lance configuration.
    config = LanceConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_num_frames=model_args.max_num_frames,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model: Lance = Lance(
        language_model=language_model,
        vit_model=vit_model if training_args.visual_und else None,
        vit_type=model_args.vit_type,
        config=config,
        training_args=training_args,
    )
    model = model.to(DEVICE)
    # Setup tokenizer for model:
    tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    # Initialise the MoE before the checkpoint is loaded.
    if training_args.copy_init_moe:
        language_model.init_moe()
    init_from_model_path_if_needed(model, model_args)
    # Resize the embeddings only now, after the checkpoint has been loaded.
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)
    # NOTE(review): as nested here, image_token_id is bound only for vit_type "qwen2_5_vl";
    # "qwen_2_5_vl_original" is accepted above but would reach the validate_on_fixed_batch
    # call below with the name unbound. Confirm the intended nesting upstream.
    if model_args.vit_type.lower() == "qwen2_5_vl":
        from common.model.hacks import hack_qwen2_5_vl_config
        language_model = hack_qwen2_5_vl_config(language_model)
        image_token_id = language_model.config.video_token_id
        new_token_ids.update({"image_token_id": image_token_id})
    model.update_tokenizer(tokenizer=tokenizer)
    # With tied embeddings requested, untie the LM head and copy the new-token rows into it,
    # then record the untied state in both argument objects. Otherwise verify that the input
    # and output embeddings do not share storage.
    if model_args.tie_word_embeddings:
        model.language_model.untie_lm_head()
        model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
        model_args.tie_word_embeddings = False
        llm_config.tie_word_embeddings = False
    else:
        assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_world_embeddings 冲突'
    model = model.to(device=DEVICE, dtype=torch.bfloat16)
    model.eval()
    if vae_model is not None and hasattr(vae_model, "eval"):
        vae_model.eval()
    # Setup packed dataloader - start from a bare DataConfig and fill it in by hand.
    dataset_config = DataConfig(grouped_datasets={})
    # Basic parameters.
    dataset_config.num_frames = inference_args.num_frames
    dataset_config.H = inference_args.video_height
    dataset_config.W = inference_args.video_width
    dataset_config.task = inference_args.task
    dataset_config.resolution = inference_args.resolution
    dataset_config.text_template = inference_args.text_template
    # ViT-related parameters.
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
        dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    # VAE-related parameters.
    if training_args.visual_gen and vae_config:
        assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3"
        # Combine the latent patch size with the VAE's (temporal, spatial, spatial) downsample factors.
        vae_downsample = tuple_mul(
            model_args.latent_patch_size, (vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial)
        )
        dataset_config.latent_patch_size = model_args.latent_patch_size
        dataset_config.vae_downsample = vae_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.max_num_frames = model_args.max_num_frames
    # fix: the dropout probabilities are shared with the model arguments.
    dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
    dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
    dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
    # Create the dataset. The global rank is what gets passed as `local_rank`.
    val_dataset = ValidationDataset(
        jsonl_path= data_args.val_dataset_config_file,
        tokenizer=tokenizer,
        data_args=data_args,
        model_args=model_args,
        training_args=training_args,
        new_token_ids=new_token_ids,
        dataset_config=dataset_config,
        local_rank=GLOBAL_RANK,
        world_size=WORLD_SIZE,
    )
    # One item per batch, loaded in the main process (num_workers=0).
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        pin_memory=True,
        collate_fn=simple_custom_collate,
        drop_last=True,
        prefetch_factor=None,
        persistent_workers=False,
        multiprocessing_context=None,
    )
    val_loader_iter = iter(val_loader)
    if not os.path.exists(inference_args.save_path_gen):
        os.makedirs(inference_args.save_path_gen, exist_ok=True)
    # Main loop.
    from tqdm import tqdm
    import time
    from datetime import datetime, timedelta
    total_batches = len(val_loader)
    # The bar is disabled on every rank except 0.
    pbar = tqdm(total=total_batches, desc="Validating", unit="batch", leave=True, ncols=120, disable=(GLOBAL_RANK != 0))
    start_time = time.time()
    for i in range(total_batches):
        val_data_cpu = next(val_loader_iter)
        validate_on_fixed_batch(
            fsdp_model=model,
            vae_model=vae_model,
            tokenizer=tokenizer,
            val_data_cpu=val_data_cpu,
            training_args=training_args,
            model_args=model_args,
            data_args=data_args,
            inference_args=inference_args,
            curr_step=0,
            logger=logger,
            new_token_ids=new_token_ids,
            image_token_id=image_token_id,
            device=DEVICE,
            save_source_video=False,
            save_path_gen=inference_args.save_path_gen,
            save_path_gt="",
            sample_num_per_prompt=inference_args.sample_num_per_prompt,
        )
        if GLOBAL_RANK == 0:
            # ETA = mean time per finished batch * batches still to go.
            elapsed = time.time() - start_time
            avg_time = elapsed / (i + 1)
            eta_seconds = avg_time * (total_batches - i - 1)
            expected_finish = datetime.now() + timedelta(seconds=eta_seconds)
            finish_str = expected_finish.strftime('%Y-%m-%d %H:%M:%S')
            pbar.set_postfix_str(f"ETA: {timedelta(seconds=int(eta_seconds))} | Finish: {finish_str}")
            pbar.update(1)
    if GLOBAL_RANK == 0:
        pbar.close()
    # Tear down the process group if one was created above.
    if dist.is_initialized():
        dist.destroy_process_group()
# Run the sampler only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()