"""Overfit-inference debug script for a Wan multi-view video pipeline.

Loads a training YAML, resolves the latest (or an explicit) checkpoint,
rebuilds the dataset, runs the diffusion pipeline on selected sample
indices, and saves reference images / input / output videos plus a log.
"""

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import yaml

# Make the vendored DiffSynth-Studio checkout importable before importing it.
_THIS_DIR = Path(__file__).resolve().parent
_DIFFSYNTH_ROOT = _THIS_DIR / "multi_view" / "DiffSynth-Studio-main"
if _DIFFSYNTH_ROOT.exists():
    sys.path.insert(0, str(_DIFFSYNTH_ROOT))

from diffsynth.pipelines.wan_video_new import WanVideoPipeline
from diffsynth.utils import ModelConfig
from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset


def save_video(frames, path: str, fps: int = 16, **writer_kwargs) -> None:
    """Encode a sequence of frames to a video file via imageio.

    Args:
        frames: Iterable of frames; each may be a ``torch.Tensor``, a PIL
            image (anything with ``.convert``), or an array-like. Values
            outside [0, 255] are clipped and cast to uint8.
        path: Output file path (container/codec inferred by imageio).
        fps: Frames per second of the output video.
        **writer_kwargs: Extra options forwarded to ``imageio.get_writer``
            (e.g. ``quality=8``). Previously these were silently discarded,
            so callers passing ``quality`` had no effect — fixed here.

    Raises:
        ValueError: If ``frames`` is None or empty.
    """
    import imageio.v2 as imageio

    # Length-based emptiness check: `not frames` raises "truth value is
    # ambiguous" when frames is a numpy array / stacked tensor.
    if frames is None or len(frames) == 0:
        raise ValueError("No frames to save.")
    writer = imageio.get_writer(path, fps=fps, **writer_kwargs)
    try:
        for frame in frames:
            if isinstance(frame, torch.Tensor):
                frame = frame.detach().cpu().numpy()
            if hasattr(frame, "convert"):
                # PIL image: normalize to RGB before converting to an array.
                frame = np.asarray(frame.convert("RGB"))
            frame = np.asarray(frame)
            if frame.dtype != np.uint8:
                frame = np.clip(frame, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always finalize the container, even if a frame fails mid-stream.
        writer.close()


def load_config(train_yaml: str) -> dict:
    """Parse the training YAML file and return it as a dict."""
    with open(train_yaml, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_runtime_args(conf_info: dict) -> SimpleNamespace:
    """Extract the pipeline runtime flags from ``train_args`` in the config.

    Missing keys fall back to disabled/zero so the script works with older
    YAML files that predate these options.
    """
    train_args = conf_info.get("train_args", {})
    return SimpleNamespace(
        zero_face_ratio=float(train_args.get("zero_face_ratio", 0.0)),
        shot_rope=bool(train_args.get("shot_rope", False)),
        split_rope=bool(train_args.get("split_rope", False)),
        split1=bool(train_args.get("split1", False)),
        split2=bool(train_args.get("split2", False)),
        split3=bool(train_args.get("split3", False)),
    )


def resolve_checkpoint(checkpoint_path: str, output_root: Path) -> Path:
    """Resolve the checkpoint file to load.

    If ``checkpoint_path`` is given, use it (appending
    ``weights.safetensors`` when it is a directory). Otherwise scan
    ``output_root`` for ``checkpoint-step-*`` directories and pick the one
    with the highest step number that actually contains weights.

    Raises:
        FileNotFoundError: If the explicit path is missing or no checkpoint
            exists under ``output_root``.
    """
    if checkpoint_path:
        ckpt = Path(checkpoint_path)
        if ckpt.is_dir():
            ckpt = ckpt / "weights.safetensors"
        if not ckpt.exists():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
        return ckpt
    # Directory names look like "checkpoint-step-<N>-epoch-<M>"; only the
    # step number is used for ordering.
    pattern = re.compile(r"checkpoint-step-(\d+)-epoch-(\d+)")
    latest_step = -1
    latest_ckpt = None
    for path in output_root.glob("checkpoint-step-*"):
        match = pattern.search(path.name)
        if not match:
            continue
        step = int(match.group(1))
        if step > latest_step:
            candidate = path / "weights.safetensors"
            # Skip directories whose weights file is missing (e.g. a
            # checkpoint that is still being written).
            if candidate.exists():
                latest_step = step
                latest_ckpt = candidate
    if latest_ckpt is None:
        raise FileNotFoundError(f"No checkpoint found under {output_root}")
    return latest_ckpt


def maybe_convert_checkpoint(checkpoint: Path, output_dir: Path) -> Path:
    """Strip the ``pipe.dit.`` prefix from a final-model checkpoint.

    Training exports named ``model.safetensors`` store DiT weights under a
    ``pipe.dit.`` prefix; the pipeline loader expects bare keys. Any other
    filename (e.g. ``weights.safetensors``) is returned unchanged, as is a
    file with no prefixed keys. The converted copy is written to
    ``output_dir / "converted_weights.safetensors"``.

    Raises:
        ImportError: If ``safetensors`` is not installed.
    """
    if checkpoint.name != "model.safetensors":
        return checkpoint
    try:
        from safetensors.torch import load_file, save_file
    except ImportError as exc:
        raise ImportError("safetensors is required to convert final_model checkpoints.") from exc
    state_dict = load_file(str(checkpoint), device="cpu")
    if not state_dict:
        return checkpoint
    prefix = "pipe.dit."
    if not any(key.startswith(prefix) for key in state_dict):
        return checkpoint
    stripped = {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}
    if not stripped:
        return checkpoint
    converted_path = output_dir / "converted_weights.safetensors"
    save_file(stripped, str(converted_path))
    return converted_path


def load_dataset_meta(dataset_json: str) -> dict:
    """Load the dataset JSON and index its entries by ``disk_path``.

    NOTE(review): entries missing ``disk_path`` all collapse onto the
    ``None`` key, and duplicate paths keep only the last entry — presumably
    paths are unique in practice; verify against the dataset builder.
    """
    with open(dataset_json, "r", encoding="utf-8") as f:
        meta = json.load(f)
    return {value.get("disk_path"): value for value in meta.values() if isinstance(value, dict)}


def ensure_dir(path: Path) -> None:
    """Create ``path`` (and parents) if it does not already exist."""
    path.mkdir(parents=True, exist_ok=True)


def log_device_stats(log, label: str) -> None:
    """Log current CUDA memory usage (no-op without CUDA).

    Args:
        log: Callable taking a single message string.
        label: Tag identifying the measurement point (e.g. "before_infer").
    """
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / (1024 ** 3)
    reserved = torch.cuda.memory_reserved() / (1024 ** 3)
    log(f"[GPU] {label} allocated={allocated:.2f}GB reserved={reserved:.2f}GB")


def main() -> int:
    """Run debug inference over selected dataset indices; return exit code 0."""
    parser = argparse.ArgumentParser(description="Overfit inference debug script.")
    parser.add_argument("--train_yaml", type=str, required=True)
    parser.add_argument("--dataset_json", type=str, default="")
    parser.add_argument("--checkpoint_path", type=str, default="")
    parser.add_argument("--indices", type=int, nargs="+", default=[0])
    parser.add_argument("--output_dir", type=str, default="")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--cfg_scale", type=float, default=5.0)
    parser.add_argument("--cfg_scale_face", type=float, default=5.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--split",
        choices=["train", "test", "all"],
        default="all",
        help="Which dataset split to use. Default is all for overfit checks.",
    )
    parser.add_argument("--tiled", action="store_true")
    parser.add_argument("--use_input_video", action="store_true")
    parser.add_argument(
        "--no_input_video",
        action="store_true",
        help="Deprecated: input video is off by default.",
    )
    parser.add_argument("--save_input_video", action="store_true")
    parser.add_argument("--negative_prompt", type=str, default="")
    args = parser.parse_args()

    conf_info = load_config(args.train_yaml)
    dataset_args = conf_info.get("dataset_args", {})
    train_args = conf_info.get("train_args", {})

    # CLI flag wins over the YAML's dataset path.
    dataset_json = args.dataset_json or dataset_args.get("base_path", "")
    if not dataset_json:
        raise ValueError("dataset_json is required (or set dataset_args.base_path in YAML).")

    output_root = Path(train_args.get("output_path", "./ckpts")) / train_args.get("visual_log_project_name", "debug")
    output_dir = Path(args.output_dir) if args.output_dir else (output_root / "debug_infer")
    ensure_dir(output_dir)
    log_path = output_dir / "infer_debug.log"

    def log(message: str) -> None:
        # Echo to stdout and append to the run log; reopening per message
        # keeps the file consistent even if the script is killed mid-run.
        print(message)
        with log_path.open("a", encoding="utf-8") as f:
            f.write(message + "\n")

    log(f"Train YAML: {args.train_yaml}")
    log(f"Dataset JSON: {dataset_json}")
    log(f"Output root: {output_root}")
    log(f"Output dir: {output_dir}")

    checkpoint = resolve_checkpoint(args.checkpoint_path, output_root)
    checkpoint = maybe_convert_checkpoint(checkpoint, output_dir)
    log(f"Checkpoint: {checkpoint}")

    runtime_args = build_runtime_args(conf_info)
    height = int(dataset_args.get("height", 480))
    width = int(dataset_args.get("width", 832))
    ref_num = int(dataset_args.get("ref_num", 3))
    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=dataset_json,
        resolution=(height, width),
        ref_num=ref_num,
        training=args.split != "test",
    )
    if args.split == "all":
        # Overfit check: point both splits at the full data and force
        # training-mode sampling.
        dataset.data_train = dataset.data
        dataset.data_test = dataset.data
        dataset.training = True
    meta_map = load_dataset_meta(dataset_json)
    log(f"Dataset split: {args.split}")
    log(f"Dataset size: {len(dataset)}")

    local_model_path = train_args.get("local_model_path", "")
    model_id = "Wan2.2-TI2V-5B"
    model_configs = [
        ModelConfig(path=os.path.join(local_model_path, model_id, "models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
        ModelConfig(path=os.path.join(local_model_path, model_id, "Wan2.2_VAE.pth"), offload_device="cuda"),
        ModelConfig(path=str(checkpoint), offload_device="cuda"),
    ]
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=model_configs,
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    # --no_input_video is deprecated but still honored as an override.
    use_input_video = bool(args.use_input_video) and not args.no_input_video

    for idx in args.indices:
        if idx < 0 or idx >= len(dataset):
            log(f"[Skip] index {idx} out of range.")
            continue
        log("=" * 80)
        log(f"[Sample {idx}]")
        sample = dataset[idx]
        video_path = sample.get("video_path")
        meta = meta_map.get(video_path, {})
        text = meta.get("text", "").strip()
        # Prefer the caption from the dataset JSON; fall back to the
        # sample's own pre-shot caption (or a placeholder).
        shot_caption = [text] if text else sample.get("pre_shot_caption", ["xxx"])
        log(f"video_path: {video_path}")
        log(f"text: {text if text else '(empty)'}")
        log(f"shot_caption: {shot_caption}")
        log(f"num_frames: {len(sample.get('video', []))}")
        log(f"ref_num: {sample.get('ref_num')}, ID_num: {sample.get('ID_num')}")

        # Dump the reference images (one group per identity) for inspection.
        ref_images = sample.get("ref_images", [])
        ref_dir = output_dir / f"ref_images_{idx}"
        ensure_dir(ref_dir)
        for id_index, image_group in enumerate(ref_images):
            for img_index, img in enumerate(image_group):
                img_path = ref_dir / f"id{id_index}_img{img_index}.png"
                img.save(img_path)
        log(f"saved ref images: {ref_dir}")

        input_video = sample.get("video", [])
        if input_video and (use_input_video or args.save_input_video):
            input_path = output_dir / f"input_{idx}.mp4"
            save_video(input_video, str(input_path), fps=16)
            log(f"saved input video: {input_path}")

        log_device_stats(log, "before_infer")
        start_time = time.time()
        # NOTE(review): num_frames is taken from the sample's video length
        # even when input_video is not fed to the pipeline — assumes every
        # sample carries a non-empty video; confirm against the dataset.
        video, _ = pipe(
            args=runtime_args,
            prompt=[shot_caption],
            negative_prompt=[args.negative_prompt],
            input_video=[input_video] if use_input_video else None,
            ref_images=[ref_images],
            seed=args.seed,
            tiled=args.tiled,
            height=height,
            width=width,
            num_frames=len(input_video),
            cfg_scale=args.cfg_scale,
            cfg_scale_face=args.cfg_scale_face,
            num_inference_steps=args.num_inference_steps,
            num_ref_images=sample.get("ref_num"),
        )
        duration = time.time() - start_time
        log_device_stats(log, "after_infer")
        log(f"inference_time: {duration:.2f}s")

        output_video_path = output_dir / f"output_{idx}.mp4"
        save_video(video, str(output_video_path), fps=16, quality=8)
        log(f"saved output video: {output_video_path}")

    log("Done.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())