sapiens2-pointmap / sapiens /pose /src /datasets /pose_base_dataset.py
Rawal Khirodkar
Initial sapiens2-pointmap Space (HF download at startup, all 4 sizes, 3D viewer)
bff20b3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
from sapiens.engine.datasets import BaseDataset
from sapiens.registry import DATASETS
from .utils import parse_pose_metainfo
@DATASETS.register_module()
class PoseBaseDataset(BaseDataset):
METAINFO: dict = dict(from_file="configs/_base_/keypoints308.py")
def __init__(
self,
ann_file: str = "",
num_samples: int = None,
bbox_file: Optional[str] = None,
**kwargs,
):
self.bbox_file = bbox_file
self.ann_file = ann_file
self.num_samples = num_samples
self.metainfo = parse_pose_metainfo(self.METAINFO)
super().__init__(**kwargs)
if self.num_samples is not None:
self.data_list = self.data_list[:num_samples]
print(
"\033[96mLoaded {} samples for {}, Test mode: {}\033[0m".format(
self.__len__(), self.__class__.__name__, self.test_mode
)
)
return
def prepare_data(self, idx) -> Any:
data_info = self.get_data_info(idx)
transformed_data_info = self.pipeline(data_info)
if transformed_data_info is None:
return None
## pipeline is set to empty when using concatenation of datasets.
if (
self.test_mode == False
and "data_samples" in transformed_data_info
and "gt_instance_labels" in transformed_data_info["data_samples"]
and "keypoints_visible"
in transformed_data_info["data_samples"].gt_instance_labels
):
num_transformed_keypoints = (
transformed_data_info["data_samples"]
.gt_instance_labels["keypoints_visible"]
.sum()
.item()
) ## after cropping
## minimum visible keypoints for coco_wholebody is 8
if self.metainfo["dataset_name"] == "coco_wholebody":
if num_transformed_keypoints < 8:
return None
## absolute minimum visible keypoints is 3
if num_transformed_keypoints < 3:
return None
return transformed_data_info
def get_data_info(self, idx: int) -> dict:
data_info = super().get_data_info(idx)
data_info["img"] = cv2.imread(data_info["img_path"])
# Add metainfo items that are required in the pipeline and the model
metainfo_keys = [
"upper_body_ids",
"lower_body_ids",
"flip_pairs",
"dataset_keypoint_weights",
"flip_indices",
"skeleton_links",
]
for key in metainfo_keys:
assert key not in data_info, (
f'"{key}" is a reserved key for `metainfo`, but already '
"exists in the `data_info`."
)
data_info[key] = deepcopy(self.metainfo[key])
return data_info
def load_data_list(self) -> List[dict]:
if self.bbox_file:
data_list = self._load_detection_results()
else:
instance_list, _ = self._load_annotations()
data_list = self._get_topdown_data_infos(instance_list)
return data_list
def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
from xtcocotools.coco import COCO # lazy: only needed for COCO-format ann files
assert os.path.exists(self.ann_file), "Annotation file does not exist"
self.coco = COCO(self.ann_file)
self.metainfo["CLASSES"] = self.coco.loadCats(self.coco.getCatIds())
instance_list = []
image_list = []
for img_id in self.coco.getImgIds():
img = self.coco.loadImgs(img_id)[0]
img.update(
{
"img_id": img_id,
"img_path": os.path.join(self.data_root, img["file_name"]),
}
)
image_list.append(img)
ann_ids = self.coco.getAnnIds(imgIds=img_id)
for ann in self.coco.loadAnns(ann_ids):
instance_info = self.parse_data_info(
dict(raw_ann_info=ann, raw_img_info=img)
)
# skip invalid instance annotation.
if not instance_info:
continue
instance_list.append(instance_info)
return instance_list, image_list
def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
ann = raw_data_info["raw_ann_info"]
img = raw_data_info["raw_img_info"]
# filter invalid instance
if "bbox" not in ann or "keypoints" not in ann:
return None
img_w, img_h = img["width"], img["height"]
# get bbox in shape [1, 4], formatted as xywh
x, y, w, h = ann["bbox"]
x1 = np.clip(x, 0, img_w - 1)
y1 = np.clip(y, 0, img_h - 1)
x2 = np.clip(x + w, 0, img_w - 1)
y2 = np.clip(y + h, 0, img_h - 1)
bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
# keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
_keypoints = np.array(ann["keypoints"], dtype=np.float32).reshape(1, -1, 3)
keypoints = _keypoints[..., :2]
keypoints_visible = np.minimum(1, _keypoints[..., 2])
if "num_keypoints" in ann:
num_keypoints = ann["num_keypoints"]
else:
num_keypoints = np.count_nonzero(keypoints.max(axis=2))
data_info = {
"img_id": ann["image_id"],
"img_path": img["img_path"],
"bbox": bbox,
"bbox_score": np.ones(1, dtype=np.float32),
"num_keypoints": num_keypoints,
"keypoints": keypoints,
"keypoints_visible": keypoints_visible,
"iscrowd": ann.get("iscrowd", 0),
"segmentation": ann.get("segmentation", None),
"id": ann["id"],
"category_id": ann["category_id"],
"raw_ann_info": copy.deepcopy(ann),
}
if "crowdIndex" in img:
data_info["crowd_index"] = img["crowdIndex"]
return data_info
@staticmethod
def _is_valid_instance(data_info: Dict) -> bool:
# crowd annotation
if "iscrowd" in data_info and data_info["iscrowd"]:
return False
# invalid keypoints
if "num_keypoints" in data_info and data_info["num_keypoints"] == 0:
return False
# invalid bbox
if "bbox" in data_info:
bbox = data_info["bbox"][0]
w, h = bbox[2:4] - bbox[:2]
if w <= 0 or h <= 0:
return False
# invalid keypoints
if "keypoints" in data_info:
if np.max(data_info["keypoints"]) <= 0:
return False
return True
def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
data_list_tp = list(filter(self._is_valid_instance, instance_list))
return data_list_tp
def _load_detection_results(self) -> List[dict]:
raise NotImplementedError