# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import io
import json
import os
from contextlib import redirect_stderr
from copy import deepcopy
from typing import List, Optional

import cv2
import numpy as np
from PIL import Image
from sapiens.registry import DATASETS

from .pose_base_dataset import PoseBaseDataset

# `care` is optional at import time: silence anything it writes to stderr and
# fall back to None so this module stays importable without it. The dataset
# refuses to build its data list when it is missing (see
# `Keypoints308GoliathDataset._register_airstore_handler`), instead of failing
# every sample later with a swallowed NameError.
with open(os.devnull, "w") as f, redirect_stderr(f):
    try:
        from care.data.io import typed
    except Exception:
        typed = None

# Sentinel returned by `Keypoints308GoliathDataset._try_load_sample` when the
# sample is filtered out and a random replacement should be drawn instead.
_RESAMPLE = object()


@DATASETS.register_module()
class Keypoints308GoliathDataset(PoseBaseDataset):
    """Goliath keypoint dataset whose images and keypoints live in airstore.

    The annotation file (``ann_file``) is a JSON list of samples that only
    carry airstore sample ids (and optionally a ``"box-default"``); the image
    and a ``3 x K`` keypoint array (rows: x, y, visibility) are fetched per
    sample in :meth:`get_data_info`. When ``metainfo["remove_teeth"]`` is set,
    the keypoints listed in ``metainfo["teeth_keypoint_ids"]`` are dropped
    (the original comments describe this as 344 -> 308 keypoints).

    Args:
        subsample_factor: Expose ``len(data_list) // subsample_factor``
            samples; each index is mapped to a random sample inside its
            window of ``subsample_factor`` consecutive entries. Must be >= 1.
        **kwargs: Forwarded to ``PoseBaseDataset``.
    """

    METAINFO: dict = dict(
        from_file=os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "configs",
            "_base_",
            "keypoints308.py",
        )
    )

    def __init__(self, subsample_factor: int = 1, **kwargs) -> None:
        # A factor <= 0 would otherwise only fail later inside __len__.
        if subsample_factor < 1:
            raise ValueError(
                f"subsample_factor must be >= 1, got {subsample_factor}"
            )
        self.subsample_factor = subsample_factor
        super().__init__(**kwargs)
        self.remove_teeth = self.metainfo["remove_teeth"]
        if self.remove_teeth:
            self.teeth_ids = self.metainfo["teeth_keypoint_ids"]

    def __len__(self) -> int:
        return len(self.data_list) // self.subsample_factor

    def load_data_list(self) -> List[dict]:
        """Build the lightweight data list from the JSON annotation file.

        Only ids (and the optional default box) are kept here; images and
        keypoints are fetched from airstore lazily in :meth:`get_data_info`.

        Returns:
            One dict per sample with ``"airstore_id"``, ``"img_id"`` (the
            sample's position in the file) and, when present, ``"box"``.
        """
        self._register_airstore_handler()

        with open(self.ann_file, "rb") as f:
            raw_data = json.load(f)  # samples=5,267,269

        data_list = []
        for i, sample in enumerate(raw_data):
            # Fall back to "airstore_id" when "sample_id" is absent.
            if "sample_id" in sample:
                sample_id = sample["sample_id"]
            else:
                sample_id = sample["airstore_id"]
            dp = {"airstore_id": sample_id, "img_id": i}
            if sample.get("box-default") is not None:
                # NOTE(review): "box" is stored but not read by get_data_info.
                dp["box"] = sample["box-default"]
            data_list.append(dp)
        return data_list

    def _register_airstore_handler(self) -> None:
        """Register the airstore fsspec backend and set the dataset path prefix.

        Raises:
            ImportError: If ``care.data.io.typed`` could not be imported, since
                every sample read depends on it.
        """
        if typed is None:
            raise ImportError(
                "care.data.io.typed could not be imported; it is required to "
                "read Goliath samples from airstore."
            )
        from typedio.file_system.airstore_client import register_airstore_in_fsspec

        register_airstore_in_fsspec()
        self.path_template = (
            "airstoreds://rlr_detection_services_ml_datasets_no_user_data"
        )
        self.airstore = True

    def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO:
        """Read ``asset`` (e.g. ``"image"``/``"keypoint"``) of sample ``sid`` into memory."""
        with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f:
            data = io.BytesIO(f.read())
        return data

    def get_data_info(self, idx: int) -> Optional[dict]:
        """Return the pipeline input dict for dataset index ``idx``.

        Samples rejected by the visibility / greyscale filters are replaced by
        a uniformly random sample (which goes through the same subsampling and
        filtering). Resampling is iterative rather than recursive, so a long
        run of rejected samples cannot exhaust the recursion limit or keep
        many decoded images alive at once.

        Returns:
            The sample dict, or ``None`` if the airstore read/decode fails or
            too few body keypoints are visible.
        """
        while True:
            sample = self._try_load_sample(idx)
            if sample is not _RESAMPLE:
                return sample
            idx = np.random.randint(0, len(self.data_list))

    def _try_load_sample(self, idx: int):
        """Load and filter one sample; return ``_RESAMPLE`` to request a replacement."""
        # Map the (subsampled) dataset index to a random entry in its window.
        if self.subsample_factor > 1:
            idx = idx * self.subsample_factor + np.random.randint(
                0, self.subsample_factor
            )
        idx = idx % len(self.data_list)
        entry = self.data_list[idx]
        airstore_id = entry["airstore_id"]

        try:
            with Image.open(self._read_from_airstore("image", airstore_id)) as pil_img:
                # Image.open is lazy: decode here so decode errors are handled
                # below, and normalize greyscale/RGBA/palette images to 3
                # channels so the channel reversal and split below are valid.
                if pil_img.mode != "RGB":
                    pil_img = pil_img.convert("RGB")
                img = np.array(pil_img)
            keypoints_np = np.load(
                self._read_from_airstore("keypoint", airstore_id)
            )  # shape 3 x K (x, y, visibility)
        except Exception as e:
            print(f"Error loading data: {e}")
            return None

        # RGB -> BGR; the model preprocessor converts this back to RGB.
        img = img[:, :, ::-1]
        img_h, img_w = img.shape[:2]

        keypoints = keypoints_np[:2].T.reshape(1, -1, 2)  # 1 x K x 2
        keypoints_visible = np.where(keypoints_np[2] > 0, 1, 0).reshape(1, -1)  # 1 x K

        # Keypoints on or outside the image border are treated as invisible,
        # and invisible keypoints are zeroed.
        xs, ys = keypoints[0, :, 0], keypoints[0, :, 1]
        out_of_bounds = (xs <= 0) | (xs >= img_w) | (ys <= 0) | (ys >= img_h)
        keypoints_visible[0, out_of_bounds] = 0
        keypoints[keypoints_visible == 0] = 0

        if self.remove_teeth:
            keep = np.ones(keypoints.shape[1], dtype=bool)
            keep[self.teeth_ids] = False
            keypoints = keypoints[:, keep, :]
            keypoints_visible = keypoints_visible[:, keep]

        # Filters (order matters: it decides resample vs. None for samples
        # that fail several of them).
        num_keypoints = np.count_nonzero(keypoints_visible)
        if num_keypoints < self.metainfo["min_visible_keypoints"]:
            return _RESAMPLE

        # Require >= 6 visible keypoints among ids 0-20 (assumed to be the
        # body keypoints). NOTE(review): unlike the other filters this returns
        # None rather than resampling; presumably the caller handles None.
        if np.count_nonzero(keypoints_visible[0, :21]) < 6:
            return None

        # Skip greyscale images (all three channels identical).
        b, g, r = cv2.split(img)
        if np.array_equal(b, g) and np.array_equal(b, r):
            return _RESAMPLE

        # Bounding box: tight box around visible keypoints (clipped to the
        # image), or the full image if none are visible.
        bbox = np.array([0, 0, img_w, img_h], dtype=np.float32).reshape(1, 4)
        if np.any(keypoints_visible):
            visible = keypoints[0][keypoints_visible[0] == 1]
            x_min, y_min = np.clip(visible.min(axis=0), [0, 0], [img_w, img_h])
            x_max, y_max = np.clip(visible.max(axis=0), [0, 0], [img_w, img_h])
            bbox = np.array(
                [x_min, y_min, x_max, y_max], dtype=np.float32
            ).reshape(1, 4)

        data_info = {
            "img": img,
            "img_id": entry["img_id"],
            "img_path": "",
            "airstore_id": airstore_id,
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
            "num_keypoints": num_keypoints,
            "keypoints": keypoints,  # 1 x 308 x 2 after teeth removal
            "keypoints_visible": keypoints_visible,  # 1 x 308
            "iscrowd": 0,
            "segmentation": None,
            "id": idx,
            "category_id": 1,
            # idx is already reduced modulo len(data_list), so never negative.
            "sample_idx": idx,
        }

        # Add metainfo items that are required in the pipeline and the model.
        metainfo_keys = [
            "upper_body_ids",
            "lower_body_ids",
            "flip_pairs",
            "dataset_keypoint_weights",
            "flip_indices",
            "skeleton_links",
        ]
        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                "exists in the `data_info`."
            )
            data_info[key] = deepcopy(self.metainfo[key])

        return data_info