| import argparse |
| import json |
| import random |
| import time |
| from pathlib import Path |
|
|
| import imageio.v2 as imageio |
| import numpy as np |
| import yaml |
|
|
| from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset |
|
|
|
|
def save_video(frames, path: Path, fps: int = 16) -> None:
    """Encode a sequence of frames into a video file via imageio.

    Args:
        frames: Sequence of frames; each item may be a PIL image (detected
            via its ``.convert`` attribute) or any array-like convertible
            to a uint8 image array.
        path: Output file path; the container/codec is inferred by imageio
            from the suffix.
        fps: Frames per second for the written video.

    Raises:
        ValueError: If *frames* is empty.
    """
    # Use len() rather than truthiness: `not frames` raises
    # "truth value of an array is ambiguous" when frames is a numpy array,
    # and this helper receives whatever the dataset produced.
    if len(frames) == 0:
        raise ValueError("No frames to save.")
    writer = imageio.get_writer(str(path), fps=fps)
    try:
        for frame in frames:
            # PIL images expose .convert; normalize to RGB before the cast.
            if hasattr(frame, "convert"):
                frame = np.asarray(frame.convert("RGB"))
            frame = np.asarray(frame)
            # imageio expects uint8 — clamp first to avoid integer wrap-around.
            if frame.dtype != np.uint8:
                frame = np.clip(frame, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always finalize the container, even if a frame fails to encode.
        writer.close()
|
|
|
|
def ensure_dir(path: Path) -> None:
    """Ensure *path* exists as a directory, creating missing parents as needed."""
    path.mkdir(exist_ok=True, parents=True)
|
|
|
|
def reseed(base_seed: int, idx: int) -> None:
    """Deterministically re-seed Python's and NumPy's global RNGs per sample.

    The effective seed is ``base_seed + idx`` so each sample index gets a
    distinct, reproducible random state.
    """
    combined = base_seed + idx
    for seeder in (random.seed, np.random.seed):
        seeder(combined)
|
|
|
|
def main() -> int:
    """Dump a handful of dataset samples (video, reference images, metadata)
    to a timestamped output directory for manual inspection.

    Returns 0 on success; raises ValueError if no dataset path can be
    resolved from the CLI args or the YAML config.
    """
    parser = argparse.ArgumentParser(description="Inspect dataset samples and dump training inputs.")
    parser.add_argument("--train_yaml", type=str, required=True)
    parser.add_argument("--dataset_json", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="")
    # Explicit sample indices override --num_samples below.
    parser.add_argument("--indices", type=int, nargs="+", default=[])
    parser.add_argument("--num_samples", type=int, default=4)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--split", choices=["train", "test", "all"], default="train")
    args = parser.parse_args()


    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf = yaml.safe_load(f)


    # CLI --dataset_json wins; otherwise fall back to the YAML's base_path.
    dataset_args = conf.get("dataset_args", {})
    dataset_json = args.dataset_json or dataset_args.get("base_path", "")
    if not dataset_json:
        raise ValueError("dataset_json is required (or set dataset_args.base_path in YAML).")


    height = int(dataset_args.get("height", 480))
    width = int(dataset_args.get("width", 832))
    ref_num = int(dataset_args.get("ref_num", 3))


    # training=True for both "train" and "all" splits; only "test" disables it.
    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=dataset_json,
        resolution=(height, width),
        ref_num=ref_num,
        training=args.split != "test",
    )
    if args.split == "all":
        # NOTE(review): reaches into dataset internals to make every record
        # visible under the training code path — assumes the dataset keeps a
        # full `data` list and splits it into data_train/data_test; confirm
        # against MulltiShot_MultiView_Dataset.
        dataset.data_train = dataset.data
        dataset.data_test = dataset.data
        dataset.training = True


    # Default output: <script dir>/logs/dataset_check/<timestamp>.
    ts = time.strftime("%Y%m%d_%H%M%S")
    output_dir = Path(args.output_dir) if args.output_dir else Path(__file__).resolve().parent / "logs" / "dataset_check" / ts
    ensure_dir(output_dir)


    if args.indices:
        # NOTE(review): user-supplied indices are not range-checked; an
        # out-of-range value will surface as an error inside dataset[idx].
        indices = args.indices
    else:
        sample_count = min(args.num_samples, len(dataset))
        indices = list(range(sample_count))


    # Top-level manifest describing the run; per-sample summaries appended below.
    manifest = {
        "train_yaml": args.train_yaml,
        "dataset_json": dataset_json,
        "split": args.split,
        "height": height,
        "width": width,
        "ref_num": ref_num,
        "indices": indices,
        "samples": [],
    }


    for idx in indices:
        # Seed per index so any randomness inside dataset[idx] is reproducible.
        reseed(args.seed, idx)
        sample = dataset[idx]
        # Sample schema (video / ref_images / pre_shot_caption / shot_num /
        # ref_num / ID_num) is defined by the dataset class — not visible here.
        video = sample.get("video", [])
        ref_images = sample.get("ref_images", [])
        shot_captions = sample.get("pre_shot_caption", [])


        sample_dir = output_dir / f"sample_{idx}"
        ensure_dir(sample_dir)


        video_path = sample_dir / "input.mp4"
        if video:
            save_video(video, video_path, fps=16)


        # ref_images is treated as a nested structure: one group per identity,
        # each group holding PIL-like images (they expose .save).
        refs_dir = sample_dir / "refs"
        ensure_dir(refs_dir)
        for id_i, ref_group in enumerate(ref_images):
            for img_i, img in enumerate(ref_group):
                img.save(refs_dir / f"id{id_i}_img{img_i}.png")


        summary = {
            "index": idx,
            "video_path": sample.get("video_path"),
            "num_frames": len(video),
            "shot_num": sample.get("shot_num"),
            "pre_shot_caption": shot_captions,
            "ref_num": sample.get("ref_num"),
            "ID_num": sample.get("ID_num"),
            "saved_video": str(video_path),
            "saved_refs_dir": str(refs_dir),
        }
        # Per-sample summary written next to the artifacts, and also collected
        # into the run-level manifest.
        with (sample_dir / "summary.json").open("w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        manifest["samples"].append(summary)


    with (output_dir / "manifest.json").open("w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)


    print(f"Saved dataset check logs to: {output_dir}")
    return 0
|
|
|
|
# Script entry point: propagate main()'s integer status to the shell exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
|