UniGeo / DiffSynth-Studio /diffsynth /core /data /unified_dataset.py
WensongSong
init UniGeo demo
abd08dc
from .operators import *
import torch, json
def save_video_tensor_as_mp4(video_frames, out_path, fps=8):
# (C,T,H,W) -> (T,H,W,C)
video_np = []
for frame in video_frames:
frame_np = np.array(frame)
video_np.append(frame_np)
video = np.stack(video_np, axis=0)
imageio.mimwrite(
out_path,
video,
fps=fps,
codec="libx264",
quality=8,
)
class UnifiedDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None,
repeat=1,
data_file_keys=tuple(),
main_data_operator=lambda x: x,
):
self.base_path = base_path
self.repeat = repeat
self.data_file_keys = data_file_keys
self.main_data_operator = main_data_operator
self.data = []
self.load_metadata()
@staticmethod
def default_video_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
(("gif",), LoadGIF(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
])),
])
def load_metadata(self):
src_dir = os.path.join(self.base_path, "point_video")
tgt_dir = os.path.join(self.base_path, "videos/train")
video_exts = (".mp4", ".avi", ".mov", ".mkv", ".webm")
for fname in os.listdir(src_dir):
if not fname.lower().endswith(video_exts):
continue
src_path = os.path.join(src_dir, fname)
tgt_path = os.path.join(tgt_dir, fname)
if not os.path.exists(tgt_path) or os.path.getsize(tgt_path) == 0:
print(f"跳过无效文件:{tgt_path}")
continue
if not os.path.exists(src_path) or os.path.getsize(src_path) == 0:
print(f"跳过无效文件:{src_path}")
continue
self.data.append({
"src_video": src_path,
"tgt_video": tgt_path,
"prompt": "Ensure the consistency of the video"
})
print(f"Found {len(self.data)} video pairs")
def __getitem__(self, data_id):
try:
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
data[key] = self.main_data_operator(data[key])
return data
except Exception:
return self.__getitem__(data_id + 1)
def __len__(self):
return len(self.data) * self.repeat