Spaces:
Running on Zero
Running on Zero
| from .operators import * | |
| import torch, json | |
def save_video_tensor_as_mp4(video_frames, out_path, fps=8):
    """Write a sequence of frames to ``out_path`` as an H.264 encoded video.

    Every element yielded by ``video_frames`` is converted with ``np.array``,
    the converted frames are stacked along a new leading axis, and the
    stacked array is handed to ``imageio.mimwrite``.

    Args:
        video_frames: Iterable of frames. Each frame must be convertible by
            ``np.array`` and all frames must share one shape so that they
            can be stacked.
            NOTE(review): the function name suggests a tensor input, but the
            frames are only iterated, never permuted -- confirm the layout
            that callers actually pass.
        out_path: Destination path forwarded to ``imageio.mimwrite``.
        fps: Frame rate forwarded to the writer.
    """
    # One array per frame, stacked frame-first.
    stacked_frames = np.stack(
        [np.array(single_frame) for single_frame in video_frames],
        axis=0,
    )
    imageio.mimwrite(out_path, stacked_frames, fps=fps, codec="libx264", quality=8)
class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset of paired source/target videos discovered under ``base_path``.

    On construction the directory ``<base_path>/point_video`` is scanned.
    Every video file there that has a non-empty file of the same name in
    ``<base_path>/videos/train`` becomes one sample: a dict with the keys
    ``"src_video"``, ``"tgt_video"`` and ``"prompt"``.
    """

    def __init__(
        self,
        base_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
    ):
        """Store the configuration and scan ``base_path`` for video pairs.

        Args:
            base_path: Root directory containing ``point_video`` and
                ``videos/train``. Must be a real path: the default ``None``
                makes ``os.path.join`` raise ``TypeError`` during the scan.
            repeat: Multiplier applied to the reported dataset length.
            data_file_keys: Sample keys whose values are passed through
                ``main_data_operator`` when a sample is fetched.
            main_data_operator: Callable applied to each selected value;
                the default returns the value unchanged.
        """
        self.base_path = base_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.data = []
        self.load_metadata()

    # Declared static because the function takes no ``self``; without the
    # decorator, calling it on an instance would bind the instance to
    # ``base_path``. Calls through the class behave exactly as before.
    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        """Build the default operator pipeline for image/GIF/video paths.

        The returned ``RouteByType`` handles ``str`` inputs only: the value
        goes through ``ToAbsolutePath(base_path)`` and is then dispatched on
        its file extension to ``LoadImage``, ``LoadGIF`` or ``LoadVideo``.
        Each branch applies ``ImageCropAndResize`` built from ``height``,
        ``width``, ``max_pixels`` and the two division factors; the image
        branch additionally ends with ``ToList()``.

        ``num_frames``, ``time_division_factor`` and
        ``time_division_remainder`` are forwarded positionally to ``LoadGIF``
        and ``LoadVideo``.
        NOTE(review): the exact meaning of these operator arguments is
        defined in ``.operators``, which is not visible here -- confirm
        there before relying on specific resizing or frame-count behavior.
        """
        # A fresh ImageCropAndResize is created per branch, as before, so no
        # operator instance is shared between the three loaders.
        def make_frame_processor():
            return ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)

        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> make_frame_processor() >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=make_frame_processor(),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=make_frame_processor(),
                )),
            ])),
        ])

    @staticmethod
    def _is_usable_file(path):
        """Return True when ``path`` exists and is not an empty file."""
        return os.path.exists(path) and os.path.getsize(path) != 0

    def load_metadata(self):
        """Scan ``base_path`` and append one sample per valid video pair.

        A pair is valid when a file with a video extension exists in
        ``point_video`` and a file of the same name exists in
        ``videos/train``, and both are non-empty. Invalid files are reported
        with ``print`` and skipped. Samples are appended to ``self.data``.

        Raises:
            TypeError: If ``self.base_path`` is ``None``.
            OSError: If ``<base_path>/point_video`` cannot be listed.
        """
        src_dir = os.path.join(self.base_path, "point_video")
        tgt_dir = os.path.join(self.base_path, "videos/train")
        video_exts = (".mp4", ".avi", ".mov", ".mkv", ".webm")

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for fname in sorted(os.listdir(src_dir)):
            if not fname.lower().endswith(video_exts):
                continue

            src_path = os.path.join(src_dir, fname)
            tgt_path = os.path.join(tgt_dir, fname)

            # The target is checked first, so it is the one reported when
            # both files of a pair are unusable.
            if not self._is_usable_file(tgt_path):
                print(f"跳过无效文件:{tgt_path}")
                continue
            if not self._is_usable_file(src_path):
                print(f"跳过无效文件:{src_path}")
                continue

            self.data.append({
                "src_video": src_path,
                "tgt_video": tgt_path,
                "prompt": "Ensure the consistency of the video"
            })

        print(f"Found {len(self.data)} video pairs")

    def __getitem__(self, data_id):
        """Return the sample for ``data_id``, skipping samples that fail.

        ``data_id`` wraps around the number of stored samples. The values
        under ``data_file_keys`` are replaced by the result of
        ``main_data_operator``. If processing a sample raises, the failure
        is printed and the next sample is tried, for at most one full pass
        over the dataset.

        Raises:
            IndexError: If the dataset holds no samples.
            RuntimeError: If every sample fails to load; the last failure is
                attached as the cause.
        """
        total = len(self.data)
        if total == 0:
            raise IndexError("UnifiedDataset is empty: no valid video pairs were found")

        last_error = None
        # Bounded sweep instead of unbounded recursion: each sample is tried
        # at most once, so a dataset of broken files ends with a clear error
        # rather than a RecursionError.
        for offset in range(total):
            index = (data_id + offset) % total
            try:
                # Shallow copy so the stored metadata keeps the raw paths.
                sample = self.data[index].copy()
                for key in self.data_file_keys:
                    if key in sample:
                        sample[key] = self.main_data_operator(sample[key])
                return sample
            except Exception as error:
                last_error = error
                print(f"Failed to load sample {index}: {error!r}; trying the next sample")

        raise RuntimeError(
            f"All {total} samples failed to load (started at index {data_id % total})"
        ) from last_error

    def __len__(self):
        """Return the number of stored samples multiplied by ``repeat``."""
        return len(self.data) * self.repeat