Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import json | |
| import os | |
| from copy import deepcopy | |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union | |
| import cv2 | |
| import numpy as np | |
| from sapiens.engine.datasets import BaseDataset | |
| from sapiens.registry import DATASETS | |
| from .utils import parse_pose_metainfo | |
class PoseBaseDataset(BaseDataset):
    """Top-down pose dataset backed by a COCO-format annotation file.

    Each entry of the data list describes one annotated instance (bbox,
    keypoints and their visibility flags) as produced by
    :meth:`parse_data_info`.  Pose metainfo (flip pairs, skeleton links, ...)
    is parsed from :attr:`METAINFO` and copied into every sample by
    :meth:`get_data_info`.

    Args:
        ann_file: Path to the COCO-format annotation file.
        num_samples: If not ``None``, keep only the first ``num_samples``
            entries of the loaded data list.
        bbox_file: If set (truthy), samples are built by
            :meth:`_load_detection_results` instead of from the annotations.
        **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
    """

    METAINFO: dict = dict(from_file="configs/_base_/keypoints308.py")

    def __init__(
        self,
        ann_file: str = "",
        num_samples: Optional[int] = None,
        bbox_file: Optional[str] = None,
        **kwargs,
    ):
        # These attributes are assigned *before* ``super().__init__`` on
        # purpose: the loading hooks below read them, and ``self.data_list``
        # is already populated once the base initializer returns
        # (presumably the base class calls ``load_data_list`` -- confirm
        # against ``BaseDataset``).
        self.bbox_file = bbox_file
        self.ann_file = ann_file
        self.num_samples = num_samples
        self.metainfo = parse_pose_metainfo(self.METAINFO)

        super().__init__(**kwargs)

        if self.num_samples is not None:
            self.data_list = self.data_list[:num_samples]

        print(
            "\033[96mLoaded {} samples for {}, Test mode: {}\033[0m".format(
                len(self), self.__class__.__name__, self.test_mode
            )
        )

    def prepare_data(self, idx) -> Any:
        """Run the pipeline on sample ``idx`` and reject unusable results.

        Returns:
            The transformed sample, or ``None`` when the pipeline itself
            returned ``None`` or, outside test mode, when too few keypoints
            remain visible after the transforms.
        """
        data_info = self.get_data_info(idx)
        transformed_data_info = self.pipeline(data_info)
        if transformed_data_info is None:
            return None

        # The pipeline is set to empty when using concatenation of datasets,
        # in which case the keys probed below are absent and the visibility
        # filter is skipped.
        if (
            not self.test_mode
            and "data_samples" in transformed_data_info
            and "gt_instance_labels" in transformed_data_info["data_samples"]
            and "keypoints_visible"
            in transformed_data_info["data_samples"].gt_instance_labels
        ):
            # Number of keypoints still visible after the transforms
            # (e.g. after cropping).
            num_transformed_keypoints = (
                transformed_data_info["data_samples"]
                .gt_instance_labels["keypoints_visible"]
                .sum()
                .item()
            )

            # Minimum visible keypoints for coco_wholebody is 8.
            if self.metainfo["dataset_name"] == "coco_wholebody":
                if num_transformed_keypoints < 8:
                    return None

            # Absolute minimum visible keypoints is 3.
            if num_transformed_keypoints < 3:
                return None

        return transformed_data_info

    def get_data_info(self, idx: int) -> dict:
        """Return the raw sample ``idx`` with its image and pose metainfo.

        The image is read from ``data_info["img_path"]`` and stored under
        ``"img"``; selected metainfo entries are deep-copied into the sample
        so that the pipeline and the model can consume them.

        Raises:
            AssertionError: If the sample already contains one of the
                reserved metainfo keys.
        """
        data_info = super().get_data_info(idx)

        # NOTE: ``cv2.imread`` returns ``None`` (it does not raise) when the
        # file is missing or cannot be decoded; that value is passed on to
        # the pipeline unchanged.
        data_info["img"] = cv2.imread(data_info["img_path"])

        # Add metainfo items that are required in the pipeline and the model.
        metainfo_keys = [
            "upper_body_ids",
            "lower_body_ids",
            "flip_pairs",
            "dataset_keypoint_weights",
            "flip_indices",
            "skeleton_links",
        ]
        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                "exists in the `data_info`."
            )
            # Deep copy so that per-sample mutation cannot leak back into
            # the shared metainfo.
            data_info[key] = deepcopy(self.metainfo[key])

        return data_info

    def load_data_list(self) -> List[dict]:
        """Build the list of samples.

        Uses :meth:`_load_detection_results` when ``bbox_file`` is set,
        otherwise the valid instances parsed from the annotation file.
        """
        if self.bbox_file:
            data_list = self._load_detection_results()
        else:
            instance_list, _ = self._load_annotations()
            data_list = self._get_topdown_data_infos(instance_list)
        return data_list

    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
        """Parse ``self.ann_file`` with the COCO API.

        Side effects: sets ``self.coco`` and ``self.metainfo["CLASSES"]``.

        Returns:
            ``(instance_list, image_list)``: one dict per parsable instance
            annotation, and one dict per image (the COCO image record
            extended with ``img_id`` and ``img_path``).

        Raises:
            AssertionError: If the annotation file does not exist.
        """
        from xtcocotools.coco import COCO  # lazy: only needed for COCO-format ann files

        assert os.path.exists(self.ann_file), "Annotation file does not exist"

        self.coco = COCO(self.ann_file)
        self.metainfo["CLASSES"] = self.coco.loadCats(self.coco.getCatIds())

        instance_list = []
        image_list = []

        for img_id in self.coco.getImgIds():
            img = self.coco.loadImgs(img_id)[0]
            img.update(
                {
                    "img_id": img_id,
                    "img_path": os.path.join(self.data_root, img["file_name"]),
                }
            )
            image_list.append(img)

            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            for ann in self.coco.loadAnns(ann_ids):
                instance_info = self.parse_data_info(
                    dict(raw_ann_info=ann, raw_img_info=img)
                )

                # Skip invalid instance annotation.
                if not instance_info:
                    continue

                instance_list.append(instance_info)

        return instance_list, image_list

    def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
        """Convert one raw COCO annotation into a sample dict.

        Args:
            raw_data_info: Dict with ``"raw_ann_info"`` (the annotation) and
                ``"raw_img_info"`` (its image record, which must provide
                ``width``, ``height`` and ``img_path``).

        Returns:
            The parsed instance, or ``None`` when the annotation has no
            ``"bbox"`` or no ``"keypoints"`` entry.
        """
        ann = raw_data_info["raw_ann_info"]
        img = raw_data_info["raw_img_info"]

        # Filter invalid instance.
        if "bbox" not in ann or "keypoints" not in ann:
            return None

        img_w, img_h = img["width"], img["height"]

        # The annotation stores the bbox as (x, y, w, h); convert it to
        # corner format (x1, y1, x2, y2), clipped to the image, shape [1, 4].
        x, y, w, h = ann["bbox"]
        x1 = np.clip(x, 0, img_w - 1)
        y1 = np.clip(y, 0, img_h - 1)
        x2 = np.clip(x + w, 0, img_w - 1)
        y2 = np.clip(y + h, 0, img_h - 1)
        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

        # Keypoints in shape [1, K, 2] and keypoints_visible in [1, K].
        _keypoints = np.array(ann["keypoints"], dtype=np.float32).reshape(1, -1, 3)
        keypoints = _keypoints[..., :2]
        # Cap the visibility flag at 1.
        keypoints_visible = np.minimum(1, _keypoints[..., 2])

        if "num_keypoints" in ann:
            num_keypoints = ann["num_keypoints"]
        else:
            # Fallback: count keypoints whose larger coordinate is non-zero.
            num_keypoints = np.count_nonzero(keypoints.max(axis=2))

        data_info = {
            "img_id": ann["image_id"],
            "img_path": img["img_path"],
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
            "num_keypoints": num_keypoints,
            "keypoints": keypoints,
            "keypoints_visible": keypoints_visible,
            "iscrowd": ann.get("iscrowd", 0),
            "segmentation": ann.get("segmentation", None),
            "id": ann["id"],
            "category_id": ann["category_id"],
            "raw_ann_info": copy.deepcopy(ann),
        }

        if "crowdIndex" in img:
            data_info["crowd_index"] = img["crowdIndex"]

        return data_info

    # Declared static: the function takes no ``self`` and is called through
    # an instance (``self._is_valid_instance``) as a per-item predicate.
    # Without ``@staticmethod`` the bound call would receive the instance as
    # an extra positional argument and raise ``TypeError``.
    @staticmethod
    def _is_valid_instance(data_info: Dict) -> bool:
        """Return ``True`` if ``data_info`` is usable as a top-down sample.

        An instance is rejected when it is a crowd annotation, reports zero
        keypoints, has a bbox of non-positive width or height, or has no
        keypoint coordinate greater than zero.  Checks on absent keys are
        skipped.
        """
        # Crowd annotation.
        if "iscrowd" in data_info and data_info["iscrowd"]:
            return False

        # No annotated keypoints.
        if "num_keypoints" in data_info and data_info["num_keypoints"] == 0:
            return False

        # Degenerate bbox (corner format, first row of a [1, 4] array).
        if "bbox" in data_info:
            bbox = data_info["bbox"][0]
            w, h = bbox[2:4] - bbox[:2]
            if w <= 0 or h <= 0:
                return False

        # No keypoint coordinate is positive.
        if "keypoints" in data_info:
            if np.max(data_info["keypoints"]) <= 0:
                return False

        return True

    def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
        """Return the entries of ``instance_list`` that pass validation."""
        return [info for info in instance_list if self._is_valid_instance(info)]

    def _load_detection_results(self) -> List[dict]:
        """Load samples from ``bbox_file``; not implemented in this class."""
        raise NotImplementedError