Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import io | |
| import json | |
| import os | |
| from contextlib import redirect_stderr | |
| from copy import deepcopy | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from sapiens.registry import DATASETS | |
| from .pose_base_dataset import PoseBaseDataset | |
# Optional dependency: the internal `care` typed-I/O package used for airstore
# reads. Its import can emit noisy warnings, so stderr is redirected to
# /dev/null for the duration of the import. If the package is unavailable the
# failure is tolerated silently — `typed` is then simply undefined, and any
# later call to `_read_from_airstore` (which uses `typed.open`) will raise a
# NameError at use time.
with open(os.devnull, "w") as f, redirect_stderr(f):
    try:
        from care.data.io import typed
    except Exception:
        pass
class Keypoints308GoliathDataset(PoseBaseDataset):
    """Goliath pose dataset serving 308-keypoint samples from airstore.

    Annotations (a JSON list of sample records) are read from ``ann_file``;
    images and raw 344-point keypoint arrays are fetched per sample from
    airstore via the ``care`` typed-I/O client. Raw samples carry 344
    keypoints; when ``metainfo["remove_teeth"]`` is set, the teeth keypoints
    are dropped, leaving 308 (hence the class name).

    NOTE(review): the skeleton metainfo is loaded from the shared
    ``keypoints308.py`` config (see ``METAINFO`` below); the mapping from the
    344 raw points to that layout is assumed to be handled by the teeth
    removal plus the config — confirm against the config file.
    """

    # Path to the shared keypoint-skeleton config, resolved relative to this
    # file (three directories up, then configs/_base_/keypoints308.py).
    # `PoseBaseDataset` presumably parses this into `self.metainfo` — the keys
    # read below (remove_teeth, teeth_keypoint_ids, min_visible_keypoints,
    # flip_pairs, ...) must all be defined there.
    METAINFO: dict = dict(
        from_file=os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "configs",
            "_base_",
            "keypoints308.py",
        )
    )

    def __init__(self, subsample_factor: int = 1, **kwargs) -> None:
        """Initialize the dataset.

        Args:
            subsample_factor: Report ``len(data_list) // subsample_factor`` as
                the dataset length; each index is then jittered inside its
                stride window in ``get_data_info`` so the full data list is
                still sampled over time.
            **kwargs: Forwarded to ``PoseBaseDataset`` (e.g. ``ann_file``,
                ``metainfo``, pipeline settings).
        """
        # Must be set before super().__init__ because the base class may call
        # load_data_list()/__len__ during construction.
        self.subsample_factor = subsample_factor
        super().__init__(**kwargs)
        # Whether to drop the teeth keypoints (344 -> 308 points per sample).
        self.remove_teeth = self.metainfo["remove_teeth"]
        if self.remove_teeth:
            # Indices (into the raw 344-point layout) to remove.
            self.teeth_ids = self.metainfo["teeth_keypoint_ids"]
        return

    def __len__(self) -> int:
        # Subsampled length; the remainder (len % subsample_factor) is still
        # reachable via the random jitter in get_data_info.
        return len(self.data_list) // self.subsample_factor

    def load_data_list(self) -> List[dict]:
        """Load data list from 344 body points."""
        # Register the airstore fsspec handler before any per-sample reads.
        self._register_airstore_handler()
        with open(self.ann_file, "rb") as f:
            raw = f.read()
        raw_data = json.loads(raw)  # samples=5,267,269
        data_list = []
        for i, sample in enumerate(raw_data):
            # Older records key the id as "airstore_id"; normalize to
            # "sample_id" so the lookup below is uniform.
            if "sample_id" not in sample:
                sample["sample_id"] = sample["airstore_id"]
            # Keep the per-sample record minimal; image/keypoints are fetched
            # lazily from airstore in get_data_info.
            dp = {
                "airstore_id": sample["sample_id"],
                "img_id": i,
            }
            # Optional precomputed bounding box (unused in get_data_info,
            # which recomputes a bbox from visible keypoints).
            if sample.get("box-default") is not None:
                dp["box"] = sample["box-default"]
            data_list.append(dp)
        return data_list

    def _register_airstore_handler(self) -> None:
        """Register the airstore protocol with fsspec and store the base URI."""
        from typedio.file_system.airstore_client import register_airstore_in_fsspec

        register_airstore_in_fsspec()
        # Base URI template; per-asset paths are appended in
        # _read_from_airstore.
        self.path_template = (
            "airstoreds://rlr_detection_services_ml_datasets_no_user_data"
        )
        self.airstore = True

    def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO:
        """Fetch one asset ("image" or "keypoint") for sample ``sid``.

        Returns the full payload wrapped in a BytesIO so callers (PIL,
        np.load) can seek freely after the remote handle is closed.
        """
        with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f:
            data = io.BytesIO(f.read())
        return data

    def get_data_info(self, idx):
        """Build the full training sample for ``idx``.

        Fetches image + keypoints from airstore, normalizes visibility,
        optionally strips teeth points, derives a bbox from the visible
        keypoints, and rejects unusable samples.

        Returns:
            dict with image, keypoints, bbox and metainfo fields, or ``None``
            when loading fails or too few body keypoints are visible.

        NOTE(review): rejection handling is inconsistent — too few total
        keypoints or a greyscale image triggers a resample of a random index,
        while too few *body* keypoints returns None. Confirm the caller
        tolerates None here.
        """
        if self.subsample_factor > 1:
            # Jitter within the stride window so every underlying sample is
            # eventually visited despite the subsampled __len__.
            idx = idx * self.subsample_factor + np.random.randint(
                0, self.subsample_factor
            )
            # Needed because the resample path below draws indices from the
            # full (un-subsampled) data list.
            idx = idx % len(self.data_list)
        data_info = copy.deepcopy(self.data_list[idx])
        try:
            img = Image.open(
                self._read_from_airstore("image", data_info["airstore_id"])
            )  ## pillow image
            keypoints_np = np.load(
                self._read_from_airstore("keypoint", data_info["airstore_id"])
            )  # shape 3 x 344
        except Exception as e:
            # Best-effort: a failed remote read rejects the sample rather
            # than crashing the worker.
            print(f"Error loading data: {e}")
            return None
        img = np.array(img)  ## RGB image
        img = img[
            :, :, ::-1
        ]  # Convert RGB to BGR, the model preprocessor will convert this to rgb again
        img_w, img_h = img.shape[1], img.shape[0]

        # process keypoints: row 0/1 are x/y, row 2 is a confidence/visibility
        # score that is binarized below.
        keypoints = keypoints_np[:2].T.reshape(1, -1, 2)  # shape 1 x 344 x 2
        keypoints_visible = np.where(keypoints_np[2].T > 0, 1, 0).reshape(
            1, -1
        )  # shape 1 x 344

        # Identify keypoints that are out of bounds for x (width) and y (height)
        out_of_bounds_w = np.logical_or(
            keypoints[0, :, 0] <= 0, keypoints[0, :, 0] >= img_w
        )
        out_of_bounds_h = np.logical_or(
            keypoints[0, :, 1] <= 0, keypoints[0, :, 1] >= img_h
        )

        # Update keypoints_visible based on the out-of-bounds keypoints, then
        # zero the coordinates of every invisible keypoint.
        keypoints_visible[0, out_of_bounds_w | out_of_bounds_h] = 0
        keypoints[keypoints_visible == 0] = 0

        ## remove teeth keypoints (344 -> 308); must happen before the bbox
        ## and count logic so indices like [:21] refer to the final layout.
        if self.remove_teeth:
            # Use numpy's boolean indexing to remove keypoints
            mask = np.ones(keypoints.shape[1], dtype=bool)
            mask[self.teeth_ids] = False
            keypoints = keypoints[:, mask, :]
            keypoints_visible = keypoints_visible[:, mask]

        # Default bounding box to the full image size
        bbox = np.array([0, 0, img_w, img_h], dtype=np.float32).reshape(1, 4)
        if np.any(keypoints_visible):  # If any keypoints are visible
            visible_keypoints = keypoints[0][
                keypoints_visible[0] == 1
            ]  # Filter out the invisible keypoints
            # Get the bounding box encompassing the keypoints, clamped to the
            # image extent.
            x_min, y_min = np.clip(
                np.min(visible_keypoints, axis=0), [0, 0], [img_w, img_h]
            )
            x_max, y_max = np.clip(
                np.max(visible_keypoints, axis=0), [0, 0], [img_w, img_h]
            )
            bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32).reshape(
                1, 4
            )

        num_keypoints = np.count_nonzero(keypoints_visible)

        ## atleast min_visible_keypoints visible keypoints, else resample a
        ## random index (recursive; terminates in practice once a valid
        ## sample is drawn).
        if num_keypoints < self.metainfo["min_visible_keypoints"]:
            random_idx = np.random.randint(0, len(self.data_list))
            return self.get_data_info(random_idx)

        ## check body keypoints additionally — assumes the first 21 indices
        ## are the body points (TODO confirm against keypoints308 config)
        num_body_keypoints = np.count_nonzero(keypoints_visible[0, :21])
        if num_body_keypoints < 6:
            return None

        ## ignore greyscale images for training: all three channels equal
        ## means the image carries no color information -> resample.
        B, G, R = cv2.split(img)
        if np.array_equal(B, G) and np.array_equal(B, R):
            random_idx = np.random.randint(0, len(self.data_list))
            return self.get_data_info(random_idx)

        data_info = {
            "img": img,
            "img_id": data_info["img_id"],
            "img_path": "",  # no on-disk path; image came from airstore
            "airstore_id": data_info["airstore_id"],
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
            "num_keypoints": num_keypoints,
            "keypoints": keypoints,  ## 1 x 308 x 2 (1 x 344 x 2 if teeth kept)
            "keypoints_visible": keypoints_visible,  ## 1 x 308 (1 x 344 if teeth kept)
            "iscrowd": 0,
            "segmentation": None,
            "id": idx,
            "category_id": 1,
        }
        # Support negative indexing for sample_idx bookkeeping.
        if idx >= 0:
            data_info["sample_idx"] = idx
        else:
            data_info["sample_idx"] = len(self) + idx

        # Add metainfo items that are required in the pipeline and the model
        metainfo_keys = [
            "upper_body_ids",
            "lower_body_ids",
            "flip_pairs",
            "dataset_keypoint_weights",
            "flip_indices",
            "skeleton_links",
        ]
        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                "exists in the `data_info`."
            )
            # deepcopy so downstream pipeline steps can mutate safely.
            data_info[key] = deepcopy(self.metainfo[key])

        return data_info