Spaces:
Runtime error
Runtime error
Rawal Khirodkar
Initial sapiens2-pointmap Space (HF download at startup, all 4 sizes, 3D viewer)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import contextlib | |
| import copy | |
| import os | |
| import sys | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| from sapiens.engine.datasets import BaseDataset | |
| from sapiens.registry import DATASETS | |
@contextlib.contextmanager
def suppress_stderr():
    """Temporarily redirect the stderr file descriptor to ``os.devnull``.

    The redirection is done at the descriptor level with ``os.dup2``, so
    anything writing to that descriptor while the ``with`` body runs is
    discarded. The original descriptor is restored on exit, including when
    the body raises.

    Yields:
        None.
    """
    stream = sys.stderr
    try:
        stderr_fd = stream.fileno()
    except (AttributeError, OSError, ValueError):
        # ``sys.stderr`` may be None or an in-memory replacement without a
        # real descriptor; fall back to the conventional stderr descriptor.
        stderr_fd = 2
    if stream is not None:
        # Flush pending Python-level output so it is not lost or reordered.
        stream.flush()

    saved_stderr_fd = os.dup(stderr_fd)
    try:
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull_fd, stderr_fd)
        finally:
            # The duplicate installed on ``stderr_fd`` keeps devnull open.
            os.close(devnull_fd)
        yield
    finally:
        try:
            os.dup2(saved_stderr_fd, stderr_fd)
        finally:
            os.close(saved_stderr_fd)
| ##----------------------------------------------------------------------- | |
class PointmapBaseDataset(BaseDataset):
    """Dataset of RGB images with per-pixel depth, masks and camera matrices.

    Samples are read from five sibling directories under ``self.data_root``:
    ``rgb`` (``.png``), ``mask`` (``.png``), ``depth`` (``.npy``),
    ``camera_intrinsics`` (``.txt``) and ``camera_extrinsics`` (``.txt``).
    Only names present in all five directories are used.
    """

    # NOTE(review): ``DATASETS`` is imported at the top of the file but no
    # registry decorator is visible on this class here -- confirm whether the
    # class is expected to be registered.

    def __init__(self, num_samples=None, **kwargs) -> None:
        """Store the optional sample cap and defer the rest to the base class.

        Args:
            num_samples: If not None, keep only the first ``num_samples``
                entries of the sorted sample list.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        # Set before the base constructor runs, as in the original ordering.
        self.num_samples = num_samples
        super().__init__(**kwargs)

    @staticmethod
    def _names_as_png(directory: str, suffix: str) -> set:
        """Return names in *directory* ending in *suffix*, re-suffixed ``.png``.

        Only the trailing suffix is swapped, so names that contain the
        suffix text elsewhere are left intact.
        """
        return {
            entry[: -len(suffix)] + ".png"
            for entry in os.listdir(directory)
            if entry.endswith(suffix)
        }

    def load_data_list(self) -> List[dict]:
        """Build the list of per-sample path dictionaries.

        Returns:
            A list of dicts with keys ``rgb_path``, ``mask_path``,
            ``depth_path``, ``K_path`` and ``M_path``, sorted by RGB file
            name and truncated to ``self.num_samples`` when that is set.
        """
        self.rgb_dir = os.path.join(self.data_root, "rgb")
        self.mask_dir = os.path.join(self.data_root, "mask")
        self.depth_dir = os.path.join(self.data_root, "depth")
        self.K_dir = os.path.join(self.data_root, "camera_intrinsics")
        self.M_dir = os.path.join(
            self.data_root, "camera_extrinsics"
        )  ## cv camera extrinsics

        print(f"\033[92mLoading {self.__class__.__name__}!\033[0m")

        # Normalise every directory listing to ".png" names so they can be
        # intersected directly.
        rgb_files = self._names_as_png(self.rgb_dir, ".png")
        mask_files = self._names_as_png(self.mask_dir, ".png")
        depth_files = self._names_as_png(self.depth_dir, ".npy")
        K_files = self._names_as_png(self.K_dir, ".txt")
        M_files = self._names_as_png(self.M_dir, ".txt")

        # Keep only samples available in all five directories.
        common_names = rgb_files & mask_files & depth_files & K_files & M_files

        data_list = []
        for name in sorted(common_names):
            stem = name[: -len(".png")]
            data_list.append(
                {
                    "rgb_path": os.path.join(self.rgb_dir, name),
                    "mask_path": os.path.join(self.mask_dir, name),
                    "depth_path": os.path.join(self.depth_dir, stem + ".npy"),
                    "K_path": os.path.join(self.K_dir, stem + ".txt"),
                    "M_path": os.path.join(self.M_dir, stem + ".txt"),
                }
            )

        if self.num_samples is not None:
            data_list = data_list[: self.num_samples]

        print(
            "\033[92mDone! {}. Loaded total samples: {}. Test mode: {}\033[0m".format(
                self.__class__.__name__, len(data_list), self.test_mode
            )
        )
        return data_list

    def get_data_info(self, idx):
        """Load one sample from disk and assemble its data dictionary.

        Args:
            idx: Index into ``self.data_list``.

        Returns:
            A dict with the image, depth, camera matrices, mask and the
            mask's bounding box, or ``None`` when a file cannot be loaded or
            the mask is (nearly) empty.
        """
        paths = copy.deepcopy(self.data_list[idx])

        # Best-effort loading: any failure marks the sample as unusable
        # (``None``) instead of aborting the caller.
        try:
            with suppress_stderr():
                image = cv2.imread(paths["rgb_path"])  ## bgr image is default
                mask = cv2.imread(paths["mask_path"])
                depth = np.load(paths["depth_path"])  ## H x W, not in 0 to 1
                K = np.loadtxt(paths["K_path"])  ## intrinsics, 3 x 3
                M = np.loadtxt(paths["M_path"])  ## extrinsics, 4 x 4
        except Exception:
            return None

        # ``cv2.imread`` signals failure by returning None rather than
        # raising, so check before indexing into the arrays.
        if image is None or mask is None or depth is None:
            return None

        mask = mask[:, :, 0]  ## keep a single channel

        ## remove any nan depth from valid pixels
        nan_depth = np.isnan(depth)
        if np.any(nan_depth):
            mask[nan_depth] = 0

        # NOTE(review): this compares the sum of mask values, not a pixel
        # count -- confirm the threshold is intended for the stored mask range.
        if mask.sum() < 10:
            return None

        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        # First and last non-empty row/column give the bounding box extents.
        y1, y2 = np.where(rows)[0][[0, -1]]
        x1, x2 = np.where(cols)[0][[0, -1]]
        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

        return {
            "img": image,
            "id": idx,
            "orig_img_height": image.shape[0],
            "orig_img_width": image.shape[1],
            "img_id": os.path.basename(paths["rgb_path"]),
            "img_path": paths["rgb_path"],
            "gt_depth": depth,
            "K": K,
            "M": M,
            "mask": mask,
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
        }