# multishot/check_dataset.py
# NOTE: this header was scraped from a Hugging Face file page
# ("PencilHu's picture / Upload folder using huggingface_hub / 85752bc verified")
# and is preserved here as a comment so the module remains importable.
import argparse
import json
import random
import time
from pathlib import Path
import imageio.v2 as imageio
import numpy as np
import yaml
from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset
def save_video(frames, path: Path, fps: int = 16) -> None:
    """Encode a sequence of frames into a video file at *path*.

    Frames may be PIL images (anything exposing ``.convert``) or array-likes;
    each is normalized to a uint8 RGB ndarray before being appended.

    Args:
        frames: Sequence of frames to encode.
        path: Destination video file (e.g. an ``.mp4`` path).
        fps: Frames per second of the output video.

    Raises:
        ValueError: If *frames* is None or empty.
    """
    # Use an explicit length check instead of `if not frames:` — truthiness
    # of a numpy array of frames raises "truth value is ambiguous".
    if frames is None or len(frames) == 0:
        raise ValueError("No frames to save.")
    writer = imageio.get_writer(str(path), fps=fps)
    try:
        for frame in frames:
            # PIL images expose .convert; normalize to an RGB ndarray first.
            if hasattr(frame, "convert"):
                frame = np.asarray(frame.convert("RGB"))
            frame = np.asarray(frame)
            if frame.dtype != np.uint8:
                # Clamp to the displayable range before the lossy cast.
                frame = np.clip(frame, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always release the encoder, even if appending a frame fails.
        writer.close()
def ensure_dir(path: Path) -> None:
    """Idempotently create directory *path*, including any missing parents."""
    if not path.is_dir():
        # exist_ok still guards against a race where the directory
        # appears between the check above and this call.
        path.mkdir(parents=True, exist_ok=True)
def reseed(base_seed: int, idx: int) -> None:
    """Deterministically seed Python's and NumPy's RNGs for sample *idx*.

    The effective seed is ``base_seed + idx``, so each sample index gets a
    reproducible but distinct random state.
    """
    combined = base_seed + idx
    for seeder in (random.seed, np.random.seed):
        seeder(combined)
def main() -> int:
    """Dump a few dataset samples (video, reference images, metadata) to disk.

    Reads dataset-construction arguments from the training YAML (optionally
    overridden on the command line), materializes the selected samples, and
    writes each sample's video, reference images, and a JSON summary under
    the output directory, plus a top-level ``manifest.json``.

    Returns:
        Process exit code (0 on success).

    Raises:
        ValueError: If no dataset path is available, or requested --indices
            are out of range for the dataset.
    """
    parser = argparse.ArgumentParser(description="Inspect dataset samples and dump training inputs.")
    parser.add_argument("--train_yaml", type=str, required=True)
    parser.add_argument("--dataset_json", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="")
    parser.add_argument("--indices", type=int, nargs="+", default=[])
    parser.add_argument("--num_samples", type=int, default=4)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--split", choices=["train", "test", "all"], default="train")
    args = parser.parse_args()

    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf = yaml.safe_load(f)
    dataset_args = conf.get("dataset_args", {})

    # CLI flag wins; otherwise fall back to the YAML's base_path.
    dataset_json = args.dataset_json or dataset_args.get("base_path", "")
    if not dataset_json:
        raise ValueError("dataset_json is required (or set dataset_args.base_path in YAML).")

    height = int(dataset_args.get("height", 480))
    width = int(dataset_args.get("width", 832))
    ref_num = int(dataset_args.get("ref_num", 3))

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=dataset_json,
        resolution=(height, width),
        ref_num=ref_num,
        training=args.split != "test",
    )
    if args.split == "all":
        # Inspect every record regardless of the train/test partition.
        # NOTE(review): assumes the dataset exposes data/data_train/data_test
        # and a `training` flag — confirm against the dataset class.
        dataset.data_train = dataset.data
        dataset.data_test = dataset.data
        dataset.training = True

    ts = time.strftime("%Y%m%d_%H%M%S")
    output_dir = Path(args.output_dir) if args.output_dir else Path(__file__).resolve().parent / "logs" / "dataset_check" / ts
    ensure_dir(output_dir)

    if args.indices:
        indices = args.indices
        # Fail fast with a clear message instead of an IndexError mid-loop.
        out_of_range = [i for i in indices if not 0 <= i < len(dataset)]
        if out_of_range:
            raise ValueError(f"Indices out of range for dataset of size {len(dataset)}: {out_of_range}")
    else:
        indices = list(range(min(args.num_samples, len(dataset))))

    manifest = {
        "train_yaml": args.train_yaml,
        "dataset_json": dataset_json,
        "split": args.split,
        "height": height,
        "width": width,
        "ref_num": ref_num,
        "indices": indices,
        "samples": [],
    }

    for idx in indices:
        reseed(args.seed, idx)  # per-sample determinism for random augmentation
        sample = dataset[idx]
        video = sample.get("video", [])
        ref_images = sample.get("ref_images", [])
        shot_captions = sample.get("pre_shot_caption", [])

        sample_dir = output_dir / f"sample_{idx}"
        ensure_dir(sample_dir)

        video_path = sample_dir / "input.mp4"
        # len() guard: `if video:` is ambiguous for numpy arrays / tensors.
        has_video = video is not None and len(video) > 0
        if has_video:
            save_video(video, video_path, fps=16)

        refs_dir = sample_dir / "refs"
        ensure_dir(refs_dir)
        for id_i, ref_group in enumerate(ref_images):
            for img_i, img in enumerate(ref_group):
                img.save(refs_dir / f"id{id_i}_img{img_i}.png")

        summary = {
            "index": idx,
            "video_path": sample.get("video_path"),
            "num_frames": len(video) if video is not None else 0,
            "shot_num": sample.get("shot_num"),
            "pre_shot_caption": shot_captions,
            "ref_num": sample.get("ref_num"),
            "ID_num": sample.get("ID_num"),
            # Only claim a saved video when one was actually written.
            "saved_video": str(video_path) if has_video else None,
            "saved_refs_dir": str(refs_dir),
        }
        with (sample_dir / "summary.json").open("w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        manifest["samples"].append(summary)

    with (output_dir / "manifest.json").open("w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    print(f"Saved dataset check logs to: {output_dir}")
    return 0
if __name__ == "__main__":
    # SystemExit propagates main()'s return value as the process exit code.
    raise SystemExit(main())