import os
import sys
import json

import torch
import yaml

# Silence the HuggingFace tokenizers fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make this directory and the local DiffSynth-Studio checkout importable.
_THIS_DIR = os.path.dirname(__file__)
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)
_DIFFSYNTH_ROOT = os.path.join(_THIS_DIR, "DiffSynth-Studio-main")
if _DIFFSYNTH_ROOT not in sys.path:
    sys.path.insert(0, _DIFFSYNTH_ROOT)

from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, launch_training_task, wan_parser

from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset


class WanTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
        local_model_path=None,
    ):
        super().__init__()

        # Build model configs from explicit paths and/or "model_id:origin_file_pattern" pairs.
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [
                ModelConfig(
                    local_model_path=local_model_path,
                    model_id=i.split(":")[0],
                    origin_file_pattern=i.split(":")[1],
                )
                for i in model_id_with_origin_paths
            ]
        self.pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device="cpu",
            model_configs=model_configs,
            redirect_common_files=False,
        )
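        # `model_id_with_origin_paths` example (hypothetical IDs/patterns, shown
        # only to illustrate the "model_id:origin_file_pattern" format parsed above):
        #   "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5*.pth"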

        # Use the full 1000-step noise schedule during training.
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze everything except the explicitly trainable models.
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Optionally inject LoRA adapters into one of the pipeline's models.
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank,
            )
            setattr(self.pipe, lora_base_model, model)

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary
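
    # A minimal sketch of what LoRA injection typically looks like with peft;
    # the real logic lives in DiffusionTrainingModule.add_lora_to_model, so the
    # names below are illustrative, not this repo's API:
    #
    #   from peft import LoraConfig, inject_adapter_in_model
    #   config = LoraConfig(r=lora_rank, lora_alpha=lora_rank,
    #                       target_modules=["q", "k", "v", "o"])
    #   model = inject_adapter_in_model(config, model)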

    def forward_preprocess(self, data):
        # Positive conditioning: one per-shot caption per sample in the batch.
        inputs_posi = {"prompt": [d["pre_shot_caption"] for d in data], "global_caption": None}
        inputs_nega = {}

        # Shared inputs for the denoising units. Spatial size and frame count
        # are taken from the first sample, so the dataset must yield
        # fixed-size clips within a batch.
        inputs_shared = {
            "input_video": [d["video"] for d in data],
            "height": data[0]["video"][0].size[1],
            "width": data[0]["video"][0].size[0],
            "num_frames": len(data[0]["video"]),
            "ref_images": [d["ref_images"] for d in data],
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
            "num_ref_images": data[0]["ref_num"],
            "batch_size": len(data),
        }
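
        # Expected `data` schema, inferred from the accesses above: each item is
        # a dict with "video" (list of PIL.Image frames), "pre_shot_caption"
        # (str), "ref_images" (list of PIL.Image), and "ref_num" (int).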

        # Run each pipeline unit (text encoding, VAE encoding, etc.); every unit
        # consumes and updates the shared/positive/negative input dicts.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(
                unit, self.pipe, inputs_shared, inputs_posi, inputs_nega
            )
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, args, inputs=None):
        # Preprocess on the fly unless precomputed inputs are supplied.
        if inputs is None:
            inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(args=args, **models, **inputs)
        return loss
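
    # Note: the preprocessed inputs carry max/min_timestep_boundary, so
    # `training_loss` presumably samples timesteps within that sub-range of the
    # 1000-step schedule; the loss itself is implemented in WanVideoPipeline.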


if __name__ == "__main__":
    parser = wan_parser()
    # Parse known args only, so unexpected launcher flags are reported rather
    # than aborting argument parsing.
    args, unknown = parser.parse_known_args()
    if unknown:
        print("❗ Unknown arguments:", unknown)

    # Load the training YAML and copy its values onto `args`.
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)
    print(conf_info)
    args.dataset_base_path = conf_info["dataset_args"]["base_path"]
    args.max_checkpoints_to_keep = conf_info["train_args"]["max_checkpoints_to_keep"]
    args.resume_from_checkpoint = conf_info["train_args"]["resume_from_checkpoint"]
    args.visual_log_project_name = conf_info["train_args"]["visual_log_project_name"]
    args.seed = conf_info["train_args"]["seed"]
    args.output_path = conf_info["train_args"]["output_path"]
    args.save_steps = conf_info["train_args"]["save_steps"]
    args.save_epoches = conf_info["train_args"]["save_epoches"]
    print("outpath:", args.output_path)
    print("visual_log_project_name:", args.visual_log_project_name)
    if args.output_path is None or args.visual_log_project_name is None:
        raise ValueError(
            f"output_path or visual_log_project_name is None: "
            f"output_path={args.output_path}, visual_log_project_name={args.visual_log_project_name}"
        )
    args.output_path = os.path.join(args.output_path, args.visual_log_project_name)
    args.batch_size = conf_info["train_args"]["batch_size"]
    args.local_model_path = conf_info["train_args"]["local_model_path"]
    if "model_id_with_origin_paths" in conf_info["train_args"]:
        args.model_id_with_origin_paths = conf_info["train_args"]["model_id_with_origin_paths"]
    if "trainable_models" in conf_info["train_args"]:
        args.trainable_models = conf_info["train_args"]["trainable_models"]
    if "learning_rate" in conf_info["train_args"]:
        args.learning_rate = float(conf_info["train_args"]["learning_rate"])
    # Optional debug-inference settings, with defaults.
    args.debug_infer = bool(conf_info["train_args"].get("debug_infer", False))
    args.debug_infer_interval = int(conf_info["train_args"].get("debug_infer_interval", 1))
    args.debug_infer_steps = int(conf_info["train_args"].get("debug_infer_steps", 8))
    args.debug_infer_cfg_scale = float(conf_info["train_args"].get("debug_infer_cfg_scale", 5.0))
    args.debug_infer_cfg_scale_face = float(conf_info["train_args"].get("debug_infer_cfg_scale_face", 5.0))
    args.debug_infer_seed = int(conf_info["train_args"].get("debug_infer_seed", args.seed))
    args.debug_infer_tiled = bool(conf_info["train_args"].get("debug_infer_tiled", True))
    args.debug_infer_use_input_video = bool(conf_info["train_args"].get("debug_infer_use_input_video", True))
    args.debug_infer_negative_prompt = conf_info["train_args"].get("debug_infer_negative_prompt", "")
    args.debug_infer_indices = conf_info["train_args"].get("debug_infer_indices", [0])
    args.zero_face_ratio = conf_info["train_args"]["zero_face_ratio"]
    args.split_rope = conf_info["train_args"]["split_rope"]
    args.split1 = conf_info["train_args"]["split1"]
    args.split2 = conf_info["train_args"]["split2"]
    args.split3 = conf_info["train_args"]["split3"]
    # Heuristic: scale the learning rate by 0.75x per sample in the batch
    # (i.e. (batch_size / 2) * 1.5), capped at 10x the base rate.
    if args.batch_size != 1:
        args.learning_rate = min(args.learning_rate * (args.batch_size * 0.75), args.learning_rate * 10)
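    # For example, batch_size=4 gives a multiplier of 4 * 0.75 = 3.0; the 10x
    # cap only binds once batch_size exceeds 13 (13 * 0.75 = 9.75 < 10).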
    args.height = conf_info["dataset_args"]["height"]
    args.width = conf_info["dataset_args"]["width"]
    args.num_frames = conf_info["dataset_args"]["num_frames"]
    args.ref_num = conf_info["dataset_args"]["ref_num"]

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=True,
    )
    model = WanTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
        local_model_path=args.local_model_path,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    launch_training_task(
        args,
        dataset, model, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        output_path=args.output_path,
        save_steps=args.save_steps,
        save_epoches=args.save_epoches,
        max_checkpoints_to_keep=args.max_checkpoints_to_keep,
        resume_from_checkpoint=args.resume_from_checkpoint,
        seed=args.seed,
        visual_log_project_name=args.visual_log_project_name,
    )
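
    # Typical invocation (hypothetical; assumes `--train_yaml` is registered on
    # the parser and training is launched through accelerate):
    #   accelerate launch this_script.py --train_yaml configs/train.yaml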