"""Inspect samples from a multi-shot multi-view video dataset and dump the
exact training inputs (video, reference images, captions) to disk for manual
review. Writes one directory per sample plus a top-level ``manifest.json``.
"""
import argparse
import json
import random
import time
from pathlib import Path

import imageio.v2 as imageio
import numpy as np
import yaml

from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset


def save_video(frames, path: Path, fps: int = 16) -> None:
    """Encode *frames* into a video file at *path*.

    Frames may be PIL images (anything exposing ``.convert``) or array-likes;
    each is normalized to an RGB uint8 array before being appended.

    Raises:
        ValueError: if *frames* is empty.
    """
    if not frames:
        raise ValueError("No frames to save.")
    writer = imageio.get_writer(str(path), fps=fps)
    try:
        for frame in frames:
            # PIL images expose .convert; flatten them to RGB arrays first.
            if hasattr(frame, "convert"):
                frame = np.asarray(frame.convert("RGB"))
            frame = np.asarray(frame)
            if frame.dtype != np.uint8:
                # Clip before casting so float/overflowing values don't wrap.
                frame = np.clip(frame, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always release the encoder, even if a frame fails mid-stream.
        writer.close()


def ensure_dir(path: Path) -> None:
    """Create *path* (including parents) if it does not already exist."""
    path.mkdir(parents=True, exist_ok=True)


def reseed(base_seed: int, idx: int) -> None:
    """Seed Python's and NumPy's RNGs deterministically for sample *idx*.

    Makes any random augmentation inside the dataset reproducible per index.
    """
    seed = base_seed + idx
    random.seed(seed)
    np.random.seed(seed)


def main() -> int:
    """CLI entry point: load config, sample the dataset, dump inputs.

    Returns:
        Process exit code (0 on success).

    Raises:
        ValueError: if no dataset JSON path is given on the command line or
            found under ``dataset_args.base_path`` in the YAML config.
    """
    parser = argparse.ArgumentParser(description="Inspect dataset samples and dump training inputs.")
    parser.add_argument("--train_yaml", type=str, required=True)
    parser.add_argument("--dataset_json", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="")
    parser.add_argument("--indices", type=int, nargs="+", default=[])
    parser.add_argument("--num_samples", type=int, default=4)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--split", choices=["train", "test", "all"], default="train")
    args = parser.parse_args()

    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf = yaml.safe_load(f)
    dataset_args = conf.get("dataset_args", {})
    # CLI flag wins; fall back to the YAML's base_path.
    dataset_json = args.dataset_json or dataset_args.get("base_path", "")
    if not dataset_json:
        raise ValueError("dataset_json is required (or set dataset_args.base_path in YAML).")

    height = int(dataset_args.get("height", 480))
    width = int(dataset_args.get("width", 832))
    ref_num = int(dataset_args.get("ref_num", 3))

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=dataset_json,
        resolution=(height, width),
        ref_num=ref_num,
        training=args.split != "test",
    )
    if args.split == "all":
        # Expose every record regardless of the dataset's internal split.
        dataset.data_train = dataset.data
        dataset.data_test = dataset.data
        dataset.training = True

    ts = time.strftime("%Y%m%d_%H%M%S")
    output_dir = (
        Path(args.output_dir)
        if args.output_dir
        else Path(__file__).resolve().parent / "logs" / "dataset_check" / ts
    )
    ensure_dir(output_dir)

    if args.indices:
        indices = args.indices
    else:
        # Default: the first num_samples items (capped at dataset size).
        sample_count = min(args.num_samples, len(dataset))
        indices = list(range(sample_count))

    manifest = {
        "train_yaml": args.train_yaml,
        "dataset_json": dataset_json,
        "split": args.split,
        "height": height,
        "width": width,
        "ref_num": ref_num,
        "indices": indices,
        "samples": [],
    }

    for idx in indices:
        reseed(args.seed, idx)  # reproducible augmentation per sample
        sample = dataset[idx]
        video = sample.get("video", [])
        ref_images = sample.get("ref_images", [])
        shot_captions = sample.get("pre_shot_caption", [])

        sample_dir = output_dir / f"sample_{idx}"
        ensure_dir(sample_dir)

        video_path = sample_dir / "input.mp4"
        if video:
            save_video(video, video_path, fps=16)

        refs_dir = sample_dir / "refs"
        ensure_dir(refs_dir)
        # ref_images is grouped per identity: one sub-list of images per ID.
        for id_i, ref_group in enumerate(ref_images):
            for img_i, img in enumerate(ref_group):
                img.save(refs_dir / f"id{id_i}_img{img_i}.png")

        summary = {
            "index": idx,
            "video_path": sample.get("video_path"),
            "num_frames": len(video),
            "shot_num": sample.get("shot_num"),
            "pre_shot_caption": shot_captions,
            "ref_num": sample.get("ref_num"),
            "ID_num": sample.get("ID_num"),
            # Fix: only record a path when a video was actually written;
            # previously the manifest pointed at a nonexistent file for
            # samples with no frames.
            "saved_video": str(video_path) if video else None,
            "saved_refs_dir": str(refs_dir),
        }
        with (sample_dir / "summary.json").open("w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        manifest["samples"].append(summary)

    with (output_dir / "manifest.json").open("w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    print(f"Saved dataset check logs to: {output_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())