"""Inference script for multi-shot / multi-view Wan2.2 TI2V video generation.

Loads a fine-tuned DiT checkpoint into a ``WanVideoPipeline`` and renders
videos either for a few fixed dataset samples (``test_video``) or for a
manually specified set of reference images (``specify_video``).  Run-time
configuration is read from the training YAML referenced by ``--train_yaml``.
"""

# --- standard library ---
import argparse
import json
import math
import os
from typing import Callable, List, Optional, Tuple

# --- third party ---
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from einops import rearrange
from modelscope import dataset_snapshot_download
from PIL import Image, ImageOps
from tqdm import tqdm

# --- project ---
from datasets.videodataset import MulltiShot_MultiView_Dataset
from diffsynth import save_video
from diffsynth.models import ModelManager
from diffsynth.models.utils import load_state_dict
from diffsynth.pipelines.wan_video_new import ModelConfig, WanVideoPipeline
from diffsynth.trainers.utils import (
    DiffusionTrainingModule,
    ModelLogger,
    launch_training_task,
    wan_parser,
)

# Shared negative prompt (kept byte-identical to the original inline strings).
NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"


def _checkpoint_path(args):
    """Resolve the fine-tuned weights path from the configured step/epoch ids."""
    return os.path.join(
        args.output_path,
        args.visual_log_project_name,
        f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}",
        "weights.safetensors",
    )


def _build_pipeline(args, checkpoint_path):
    """Create the Wan2.2 TI2V pipeline (T5 encoder + VAE + fine-tuned DiT).

    All components are offloaded to CUDA and VRAM management is enabled so
    large models fit on a single device.
    """
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(
                path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"),
                offload_device="cuda",
            ),
            ModelConfig(
                path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"),
                offload_device="cuda",
            ),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()
    return pipe


def _process_ref_images(ref_images, height, width):
    """Letterbox each reference image to (width, height) on a white background.

    The aspect ratio is preserved: each image is resized to fit inside the
    target box with LANCZOS resampling, then centered with white padding.
    Returns a new list of PIL images; the inputs are not modified in place.
    """
    ref_images_new = []
    for ref_image in ref_images:
        ref_image = ref_image.convert("RGB")
        img_ratio = ref_image.width / ref_image.height
        target_ratio = width / height
        if img_ratio > target_ratio:
            # Image is wider than the target box: clamp width, scale height.
            new_width = width
            new_height = int(new_width / img_ratio)
        else:
            # Image is taller than the target box: clamp height, scale width.
            new_height = height
            new_width = int(new_height * img_ratio)
        ref_image = ref_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        # Center the resized image inside a white canvas of the target size.
        delta_w = width - ref_image.size[0]
        delta_h = height - ref_image.size[1]
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        ref_images_new.append(ImageOps.expand(ref_image, padding, fill=(255, 255, 255)))
    return ref_images_new


def test_video(args):
    """Render videos for a few fixed dataset samples and log their captions.

    Saves reference images to ``<output>/ref_images``, generated videos to
    ``<output>/video`` and the used captions to ``output_log.txt``.
    """
    checkpoint_path = _checkpoint_path(args)
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    pipe = _build_pipeline(args, checkpoint_path)

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=False,
    )

    log_file_name = "output_log.txt"
    # v_indexs = [0, 10, 30, 50, 70, 100, 130, 150, 180, 200]
    v_indexs = [0, 5, 15, 20]
    with open(os.path.join(output_path, log_file_name), "w") as f:
        for v_index in v_indexs:
            metadata = dataset[v_index]
            video, _ = pipe(
                args=args,
                # Single-element lists: the pipeline expects batched inputs
                # (batch size 1 here).
                prompt=[metadata["single_caption"]],
                ref_images=[metadata["ref_images"]],
                negative_prompt=[NEGATIVE_PROMPT],
                seed=42,
                tiled=True,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                cfg_scale_face=5.,
                num_ref_images=metadata["ref_num"],
            )
            for r_index, img in enumerate(metadata["ref_images"]):
                img.save(f"{output_path}/ref_images/{v_index}-{r_index}.png")
            save_video(video, f"{output_path}/video/{v_index}.mp4", fps=15, quality=10)
            f.write(f"{metadata['single_caption']}\n")


def specify_video(args):
    """Render one video from hand-picked reference images and a fixed prompt."""
    checkpoint_path = _checkpoint_path(args)
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    pipe = _build_pipeline(args, checkpoint_path)

    ref_images = [
        Image.open("/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_0.png"),
        # Image.open("/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_2.png"),
        # Image.open("/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_3.png"),
    ]
    ref_images = _process_ref_images(ref_images, args.height, args.width)

    video, _ = pipe(
        args=args,
        # Single-element lists: batch size 1.
        prompt=["An elderly man with short gray hair and glasses stands in a softly lit indoor hallway. The shot begins with a frontal view of his face, his expression calm and attentive as he looks straight ahead. Then, he turns his head to his right, responding to someone standing beside him. His gaze shifts fully toward the other person as his expression becomes more engaged. The movement continues until he reaches a complete side profile, fully turning his face toward the person he is interacting with. Smooth and natural head rotation, warm indoor lighting."],
        ref_images=[ref_images],
        negative_prompt=[NEGATIVE_PROMPT],
        seed=42,
        tiled=True,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        cfg_scale_face=5.,
        num_ref_images=len(ref_images),
    )
    save_video(video, f"{output_path}/video/cl.mp4", fps=15, quality=10)


if __name__ == '__main__':
    # parse_known_args (not parse_args) so extra CLI flags are tolerated and
    # reported instead of aborting the run.
    parser = wan_parser()
    args, unknown = parser.parse_known_args()
    print("❗ Unknown arguments:", unknown)
    # NOTE: if diffsynth was installed with `pip install -e .`, changes inside
    # the diffsynth package require reinstalling before they take effect.

    # Merge the training YAML into args (overrides CLI defaults).
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)  # safe_load: never executes YAML tags
    print(conf_info)
    args.dataset_base_path = conf_info["dataset_args"]["base_path"]
    args.max_checkpoints_to_keep = conf_info["train_args"]["max_checkpoints_to_keep"]
    args.resume_from_checkpoint = conf_info["train_args"]["resume_from_checkpoint"]
    args.visual_log_project_name = conf_info["train_args"]["visual_log_project_name"]
    args.seed = conf_info["train_args"]["seed"]
    args.output_path = conf_info["train_args"]["output_path"]
    args.save_steps = conf_info["train_args"]["save_steps"]
    args.save_epoches = conf_info["train_args"]["save_epoches"]
    args.batch_size = conf_info["train_args"]["batch_size"]
    args.local_model_path = conf_info["train_args"]["local_model_path"]
    args.height = conf_info["dataset_args"]["height"]
    args.width = conf_info["dataset_args"]["width"]
    args.num_frames = conf_info["dataset_args"]["num_frames"]
    args.ref_num = conf_info["dataset_args"]["ref_num"]
    args.infer_step = conf_info["infer_args"]["infer_step"]
    args.epoch_id = conf_info["infer_args"]["epoch_id"]
    args.split_rope = conf_info["train_args"]["split_rope"]
    args.split1 = conf_info["train_args"]["split1"]
    args.split2 = conf_info["train_args"]["split2"]
    args.split3 = conf_info["train_args"]["split3"]

    test_video(args)
    # specify_video(args)