# multishot/multi_view/test.py
# Uploaded by PencilHu via huggingface_hub (commit 85752bc, verified).
import torch
from PIL import Image
from einops import rearrange
import numpy as np
from typing import Optional, List, Tuple, Callable
import json
import math
from tqdm import tqdm
import os
import argparse
import torch.distributed as dist
import torch.nn.functional as F
from diffsynth.models import ModelManager
from diffsynth.models.utils import load_state_dict
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
import yaml
import torch, os, json
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
from datasets.videodataset import MulltiShot_MultiView_Dataset
from PIL import Image, ImageOps
def test_video(args, indices: Optional[List[int]] = None):
    """Render validation videos for a few samples of the multi-shot / multi-view dataset.

    Builds a ``WanVideoPipeline`` from the Wan2.2-TI2V-5B T5 encoder and VAE found
    under ``args.local_model_path`` plus the fine-tuned checkpoint
    ``<args.output_path>/<args.visual_log_project_name>/checkpoint-step-<infer_step>-epoch-<epoch_id>/weights.safetensors``.
    For every requested dataset index it writes, under
    ``./output/<args.visual_log_project_name>/``:

    * ``ref_images/<index>-<k>.png`` -- the k-th reference image of the sample,
    * ``video/<index>.mp4`` -- the generated video (15 fps),
    * one line in ``output_log.txt`` -- the caption used as the prompt.

    Args:
        args: Namespace filled from the CLI and train.yaml. Reads ``output_path``,
            ``visual_log_project_name``, ``infer_step``, ``epoch_id``,
            ``local_model_path``, ``dataset_base_path``, ``height``, ``width``,
            ``ref_num`` and ``num_frames``; the whole namespace is also forwarded
            to the pipeline call.
        indices: Dataset indices to render. Defaults to ``[0, 5, 15, 20]``.
    """
    if indices is None:
        indices = [0, 5, 15, 20]

    checkpoint_path = os.path.join(
        args.output_path,
        args.visual_log_project_name,
        f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}",
        "weights.safetensors",
    )
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    # NOTE(review): training=False presumably selects the dataset's inference
    # behaviour -- confirm in datasets/videodataset.py.
    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=False,
    )

    # The stray `pdb.set_trace()` breakpoint and the unused re-read of
    # train.yaml that used to sit here were removed: the breakpoint halted
    # every non-interactive run.
    log_file_name = "output_log.txt"
    # Captions may contain non-ASCII text; do not depend on the locale encoding.
    with open(os.path.join(output_path, log_file_name), "w", encoding="utf-8") as log_file:
        for v_index in indices:
            metadata = dataset[v_index]
            video, _ = pipe(
                args=args,
                # Wrapped in one-element lists: the pipeline expects batch size 1.
                prompt=[metadata["single_caption"]],
                ref_images=[metadata["ref_images"]],
                negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
                seed=42, tiled=True,
                height=args.height, width=args.width,
                num_frames=args.num_frames,
                cfg_scale_face=5.,
                num_ref_images=metadata["ref_num"],
            )
            for r_index, img in enumerate(metadata["ref_images"]):
                img.save(f"{output_path}/ref_images/{v_index}-{r_index}.png")
            save_video(video, f"{output_path}/video/{v_index}.mp4", fps=15, quality=10)
            log_file.write(f"{metadata['single_caption']}\n")
def specify_video(args, ref_image_paths: Optional[List[str]] = None, video_name: str = "cl.mp4"):
    """Generate one video from hand-picked reference images and a fixed prompt.

    Loads the same pipeline as ``test_video``: the Wan2.2-TI2V-5B T5 encoder and
    VAE plus the fine-tuned checkpoint selected by ``args.infer_step`` /
    ``args.epoch_id``. It letterboxes each reference image onto a white
    ``args.width`` x ``args.height`` canvas and saves the result to
    ``./output/<args.visual_log_project_name>/video/<video_name>``.

    Args:
        args: Namespace filled from the CLI and train.yaml (same fields as
            ``test_video`` except the dataset ones); forwarded to the pipeline.
        ref_image_paths: Paths of the reference images. Defaults to the single
            ``cl_0.png`` used originally (``cl_2.png`` / ``cl_3.png`` in the same
            folder were alternative choices).
        video_name: File name of the generated video. Defaults to ``"cl.mp4"``.

    Raises:
        ValueError: If ``ref_image_paths`` is empty.
    """

    def process_ref_images(ref_images, height, width):
        """Letterbox each image onto a white ``width`` x ``height`` RGB canvas.

        Each image is resized with LANCZOS to fit inside the target while keeping
        its aspect ratio. It is then centred, and the remaining border is filled
        with white (255, 255, 255).
        """
        ref_images_new = []
        target_ratio = width / height
        for ref_image in ref_images:
            ref_image = ref_image.convert("RGB")
            img_ratio = ref_image.width / ref_image.height
            if img_ratio > target_ratio:  # wider than the target: fit the width
                new_width = width
                new_height = int(new_width / img_ratio)
            else:  # taller than (or same shape as) the target: fit the height
                new_height = height
                new_width = int(new_height * img_ratio)
            # Extremely thin images would otherwise truncate to 0 px, which PIL rejects.
            new_width, new_height = max(1, new_width), max(1, new_height)
            ref_image = ref_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            delta_w = width - ref_image.size[0]
            delta_h = height - ref_image.size[1]
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            ref_images_new.append(ImageOps.expand(ref_image, padding, fill=(255, 255, 255)))
        return ref_images_new

    if ref_image_paths is None:
        ref_image_paths = [
            "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_0.png",
        ]
    if not ref_image_paths:
        raise ValueError("specify_video needs at least one reference image path")

    checkpoint_path = os.path.join(
        args.output_path,
        args.visual_log_project_name,
        f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}",
        "weights.safetensors",
    )
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    # Decode each image inside `with` so its file handle is closed right away
    # (Image.open is lazy and keeps the file open otherwise).
    ref_images = []
    for path in ref_image_paths:
        with Image.open(path) as img:
            ref_images.append(img.convert("RGB"))
    ref_images = process_ref_images(ref_images, args.height, args.width)

    video, _ = pipe(
        args=args,
        # Wrapped in one-element lists: the pipeline expects batch size 1.
        prompt=["An elderly man with short gray hair and glasses stands in a softly lit indoor hallway. The shot begins with a frontal view of his face, his expression calm and attentive as he looks straight ahead. Then, he turns his head to his right, responding to someone standing beside him. His gaze shifts fully toward the other person as his expression becomes more engaged. The movement continues until he reaches a complete side profile, fully turning his face toward the person he is interacting with. Smooth and natural head rotation, warm indoor lighting."],
        ref_images=[ref_images],
        negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
        seed=42, tiled=True,
        height=args.height, width=args.width,
        num_frames=args.num_frames,
        cfg_scale_face=5.,
        num_ref_images=len(ref_images),
    )
    save_video(video, f"{output_path}/video/{video_name}", fps=15, quality=10)
# (attribute on args, train.yaml section, key in that section) for every value
# this script copies out of train.yaml. Order matches the original assignments.
CONF_ARG_MAP = (
    ("dataset_base_path", "dataset_args", "base_path"),
    ("max_checkpoints_to_keep", "train_args", "max_checkpoints_to_keep"),
    ("resume_from_checkpoint", "train_args", "resume_from_checkpoint"),
    ("visual_log_project_name", "train_args", "visual_log_project_name"),
    ("seed", "train_args", "seed"),
    ("output_path", "train_args", "output_path"),
    ("save_steps", "train_args", "save_steps"),
    ("save_epoches", "train_args", "save_epoches"),
    ("batch_size", "train_args", "batch_size"),
    ("local_model_path", "train_args", "local_model_path"),
    ("height", "dataset_args", "height"),
    ("width", "dataset_args", "width"),
    ("num_frames", "dataset_args", "num_frames"),
    ("ref_num", "dataset_args", "ref_num"),
    ("infer_step", "infer_args", "infer_step"),
    ("epoch_id", "infer_args", "epoch_id"),
    ("split_rope", "train_args", "split_rope"),
    ("split1", "train_args", "split1"),
    ("split2", "train_args", "split2"),
    ("split3", "train_args", "split3"),
)


def apply_conf_to_args(args, conf_info):
    """Copy the values listed in ``CONF_ARG_MAP`` from a parsed train.yaml onto ``args``.

    Args:
        args: Namespace to update in place (other attributes are left untouched).
        conf_info: Mapping produced by ``yaml.safe_load`` on train.yaml.

    Returns:
        The same ``args`` object, for convenience.

    Raises:
        KeyError: If a required section or key is missing from ``conf_info``.
    """
    for attr, section, key in CONF_ARG_MAP:
        setattr(args, attr, conf_info[section][key])
    return args


if __name__ == '__main__':
    parser = wan_parser()
    # Only parse_known_args: calling parse_args() first would abort with an
    # error on any unknown flag, so `unknown` below could never be non-empty.
    args, unknown = parser.parse_known_args()
    print("❗ Unknown arguments:", unknown)
    # Author's note: after `pip install -e .`, reinstall diffsynth when its
    # sources are modified.

    # Load train.yaml and copy the fields used by this script onto args.
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)  # safe_load: never constructs arbitrary objects
    print(conf_info)
    apply_conf_to_args(args, conf_info)

    test_video(args)
    # Call specify_video(args) instead to render the hand-picked reference images.