import os
import argparse

import torch
import yaml
from PIL import Image, ImageOps

from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import wan_parser
from datasets.videodataset import MulltiShot_MultiView_Dataset
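# Overview (descriptive comments only; names follow the code below):
# - test_video(args): loads a fine-tuned checkpoint into WanVideoPipeline, samples
#   a few items from MulltiShot_MultiView_Dataset, and renders one video per item.
# - specify_video(args): same pipeline, but driven by user-supplied reference
#   images instead of dataset samples.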
|
|
|
|
def build_pipeline(args, checkpoint_path):
    # Shared setup for both entry points: the base T5 text encoder, the Wan2.2 VAE,
    # and the fine-tuned DiT weights saved at checkpoint_path.
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()
    return pipe


def test_video(args):
    checkpoint_path = os.path.join(args.output_path, args.visual_log_project_name, f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}", "weights.safetensors")
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)
    pipe = build_pipeline(args, checkpoint_path)
|
|
    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=False,
    )
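    # A minimal sketch of what each dataset item is expected to carry, inferred from
    # how metadata is consumed below (the dataset class itself lives in
    # datasets/videodataset.py and is not shown here):
    #   metadata["single_caption"]  - one text prompt for the sample
    #   metadata["ref_images"]      - list of PIL reference images
    #   metadata["ref_num"]         - number of reference images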
|
|
    log_file_name = "output_log.txt"

    # Render a few fixed dataset samples for a qualitative check.
    v_indices = [0, 5, 15, 20]
    with open(os.path.join(output_path, log_file_name), "w") as f:
        for v_index in v_indices:
            metadata = dataset[v_index]
            video, _ = pipe(
                args=args,
                prompt=[metadata["single_caption"]],
                ref_images=[metadata["ref_images"]],
                # Standard Chinese negative prompt for Wan models (roughly: "garish
                # colors, overexposed, static, blurry details, subtitles, style,
                # artwork, painting, still frame, overall gray cast, worst quality,
                # low quality, JPEG artifacts, ugly, mutilated, extra fingers, poorly
                # drawn hands and face, deformed, disfigured, malformed limbs, fused
                # fingers, motionless frame, cluttered background, three legs,
                # crowded background, walking backwards").
                negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
                seed=42, tiled=True,
                height=args.height, width=args.width,
                num_frames=args.num_frames,
                cfg_scale_face=5.0,
                num_ref_images=metadata["ref_num"],
            )
            # Save the reference images alongside the generated clip for inspection.
            for r_index, img in enumerate(metadata["ref_images"]):
                img.save(f"{output_path}/ref_images/{v_index}-{r_index}.png")
            save_video(video, f"{output_path}/video/{v_index}.mp4", fps=15, quality=10)
            f.write(f"{metadata['single_caption']}\n")


def specify_video(args):
    def process_ref_images(ref_images, height, width):
        # Letterbox each reference image to (width, height): scale to fit while
        # preserving the aspect ratio, then pad the remainder with white.
        ref_images_new = []
        for ref_image in ref_images:
            h = height
            w = width
            ref_image = ref_image.convert("RGB")

            img_ratio = ref_image.width / ref_image.height
            target_ratio = w / h

            if img_ratio > target_ratio:
                # Wider than the target: pin the width, shrink the height.
                new_width = w
                new_height = int(new_width / img_ratio)
            else:
                # Taller than the target: pin the height, shrink the width.
                new_height = h
                new_width = int(new_height * img_ratio)

            ref_image = ref_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Center the resized image and fill the borders with white.
            delta_w = w - ref_image.size[0]
            delta_h = h - ref_image.size[1]
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            new_img = ImageOps.expand(ref_image, padding, fill=(255, 255, 255))
            ref_images_new.append(new_img)
        return ref_images_new
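    # Worked example with illustrative numbers (not taken from any config): padding
    # a 1280x720 reference into a 1280x704 target gives img_ratio = 1280/720 ≈ 1.778
    # and target_ratio = 1280/704 ≈ 1.818, so the height is pinned: new_height = 704,
    # new_width = int(704 * 1.778) = 1251. The 29 leftover columns become white bands
    # of 14 px on the left and 15 px on the right.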
|
|
    checkpoint_path = os.path.join(args.output_path, args.visual_log_project_name, f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}", "weights.safetensors")
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)
    pipe = build_pipeline(args, checkpoint_path)

    # Machine-specific example reference image; replace with your own file.
    ref_images = [
        Image.open("/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_0.png"),
    ]
    ref_images = process_ref_images(ref_images, args.height, args.width)

    video, _ = pipe(
        args=args,
        prompt=["An elderly man with short gray hair and glasses stands in a softly lit indoor hallway. The shot begins with a frontal view of his face, his expression calm and attentive as he looks straight ahead. Then, he turns his head to his right, responding to someone standing beside him. His gaze shifts fully toward the other person as his expression becomes more engaged. The movement continues until he reaches a complete side profile, fully turning his face toward the person he is interacting with. Smooth and natural head rotation, warm indoor lighting."],
        ref_images=[ref_images],
        # Same standard Chinese negative prompt as in test_video above.
        negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
        seed=42, tiled=True,
        height=args.height, width=args.width,
        num_frames=args.num_frames,
        cfg_scale_face=5.0,
        num_ref_images=len(ref_images),
    )
    save_video(video, f"{output_path}/video/cl.mp4", fps=15, quality=10)
|
if __name__ == '__main__':
    # Continuous multi-shot generation script for long videos.
    parser = wan_parser()
    args, unknown = parser.parse_known_args()
    if unknown:
        print("❗ Unknown arguments:", unknown)

    # Override the CLI defaults with the values recorded in the training YAML so
    # that inference matches the training configuration.
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)
    print(conf_info)
    args.dataset_base_path = conf_info["dataset_args"]["base_path"]
    args.max_checkpoints_to_keep = conf_info["train_args"]["max_checkpoints_to_keep"]
    args.resume_from_checkpoint = conf_info["train_args"]["resume_from_checkpoint"]
    args.visual_log_project_name = conf_info["train_args"]["visual_log_project_name"]
    args.seed = conf_info["train_args"]["seed"]
    args.output_path = conf_info["train_args"]["output_path"]
    args.save_steps = conf_info["train_args"]["save_steps"]
    args.save_epoches = conf_info["train_args"]["save_epoches"]
    args.batch_size = conf_info["train_args"]["batch_size"]
    args.local_model_path = conf_info["train_args"]["local_model_path"]
    args.height = conf_info["dataset_args"]["height"]
    args.width = conf_info["dataset_args"]["width"]
    args.num_frames = conf_info["dataset_args"]["num_frames"]
    args.ref_num = conf_info["dataset_args"]["ref_num"]
    args.infer_step = conf_info["infer_args"]["infer_step"]
    args.epoch_id = conf_info["infer_args"]["epoch_id"]
    args.split_rope = conf_info["train_args"]["split_rope"]
    args.split1 = conf_info["train_args"]["split1"]
    args.split2 = conf_info["train_args"]["split2"]
    args.split3 = conf_info["train_args"]["split3"]

    # Render dataset samples; call specify_video(args) instead to drive the
    # pipeline with custom reference images.
    test_video(args)
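# Example invocation (a sketch: the --train_yaml flag name is assumed from the
# args.train_yaml attribute read above, and both the script name and the YAML
# path are hypothetical placeholders):
#   python infer_multiview.py --train_yaml configs/train.yaml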