Spaces:
Running on Zero
Running on Zero
File size: 7,417 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
from sapiens.engine.datasets import BaseDataset
from sapiens.registry import DATASETS
from .utils import parse_pose_metainfo
@DATASETS.register_module()
class PoseBaseDataset(BaseDataset):
    """Pose dataset that yields one sample per annotated instance.

    Instances are read from a COCO-format annotation file (``ann_file``),
    converted to a flat ``dict`` by :meth:`parse_data_info`, and filtered by
    :meth:`_is_valid_instance`. Loading from a detection-result file
    (``bbox_file``) is left to subclasses.
    """

    METAINFO: dict = dict(from_file="configs/_base_/keypoints308.py")

    def __init__(
        self,
        ann_file: str = "",
        num_samples: Optional[int] = None,
        bbox_file: Optional[str] = None,
        **kwargs,
    ):
        """Set up dataset state and optionally truncate the sample list.

        Args:
            ann_file: Path to the COCO-format annotation file.
            num_samples: If not ``None``, keep only the first ``num_samples``
                entries of ``self.data_list``.
            bbox_file: Optional detection-result file. When truthy,
                :meth:`load_data_list` calls :meth:`_load_detection_results`
                instead of reading ``ann_file``.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        self.bbox_file = bbox_file
        self.ann_file = ann_file
        self.num_samples = num_samples
        # Assigned before the base initialiser runs; presumably the base class
        # triggers `load_data_list`, which writes to `self.metainfo` -- TODO
        # confirm against BaseDataset.
        self.metainfo = parse_pose_metainfo(self.METAINFO)
        super().__init__(**kwargs)

        if self.num_samples is not None:
            self.data_list = self.data_list[: self.num_samples]

        print(
            f"\033[96mLoaded {len(self)} samples for "
            f"{self.__class__.__name__}, Test mode: {self.test_mode}\033[0m"
        )

    def prepare_data(self, idx) -> Any:
        """Run the pipeline on sample ``idx`` and reject degenerate results.

        Returns:
            The pipeline output, or ``None`` when the pipeline itself returns
            ``None`` or (outside test mode) too few keypoints are still marked
            visible after the transforms.
        """
        data_info = self.get_data_info(idx)
        transformed = self.pipeline(data_info)
        if transformed is None:
            return None

        # The pipeline is empty when datasets are concatenated, so the
        # expected keys may be absent; only filter when they are present.
        if (
            not self.test_mode
            and "data_samples" in transformed
            and "gt_instance_labels" in transformed["data_samples"]
            and "keypoints_visible"
            in transformed["data_samples"].gt_instance_labels
        ):
            # Visible-keypoint count after the transforms (e.g. cropping).
            num_visible = (
                transformed["data_samples"]
                .gt_instance_labels["keypoints_visible"]
                .sum()
                .item()
            )
            # coco_wholebody needs at least 8 visible keypoints; every other
            # dataset needs at least 3.
            is_wholebody = self.metainfo["dataset_name"] == "coco_wholebody"
            min_visible = 8 if is_wholebody else 3
            if num_visible < min_visible:
                return None

        return transformed

    def get_data_info(self, idx: int) -> dict:
        """Return the sample dict for ``idx`` with image and metainfo attached.

        The image is read with ``cv2.imread`` and stored under ``"img"``;
        selected ``metainfo`` entries are deep-copied into the dict.
        """
        data_info = super().get_data_info(idx)
        # NOTE(review): cv2.imread returns None for a missing/unreadable file
        # instead of raising; the None is passed on to the pipeline as-is.
        data_info["img"] = cv2.imread(data_info["img_path"])

        # Metainfo items that are required in the pipeline and the model.
        metainfo_keys = [
            "upper_body_ids",
            "lower_body_ids",
            "flip_pairs",
            "dataset_keypoint_weights",
            "flip_indices",
            "skeleton_links",
        ]
        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                "exists in the `data_info`."
            )
            # Deep copy so per-sample edits cannot leak into shared metainfo.
            data_info[key] = deepcopy(self.metainfo[key])
        return data_info

    def load_data_list(self) -> List[dict]:
        """Build the sample list from detections or from annotations."""
        if self.bbox_file:
            return self._load_detection_results()
        instance_list, _ = self._load_annotations()
        return self._get_topdown_data_infos(instance_list)

    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
        """Parse ``self.ann_file`` into per-instance and per-image records.

        Also stores the COCO handle on ``self.coco`` and the category records
        on ``self.metainfo["CLASSES"]``.

        Returns:
            ``(instance_list, image_list)``: parsed instance dicts (invalid
            annotations skipped) and the image records, each extended with
            ``"img_id"`` and ``"img_path"``.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        # Lazy import: only needed for COCO-format annotation files.
        from xtcocotools.coco import COCO

        # Explicit raise rather than `assert`, which is removed under `-O`.
        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(
                f"Annotation file does not exist: {self.ann_file!r}"
            )

        self.coco = COCO(self.ann_file)
        self.metainfo["CLASSES"] = self.coco.loadCats(self.coco.getCatIds())

        instance_list = []
        image_list = []
        for img_id in self.coco.getImgIds():
            img = self.coco.loadImgs(img_id)[0]
            img.update(
                {
                    "img_id": img_id,
                    "img_path": os.path.join(self.data_root, img["file_name"]),
                }
            )
            image_list.append(img)

            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            for ann in self.coco.loadAnns(ann_ids):
                instance_info = self.parse_data_info(
                    dict(raw_ann_info=ann, raw_img_info=img)
                )
                # Skip invalid instance annotations.
                if not instance_info:
                    continue
                instance_list.append(instance_info)
        return instance_list, image_list

    def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
        """Convert one raw annotation/image pair into an instance dict.

        Args:
            raw_data_info: Dict with ``"raw_ann_info"`` (the annotation) and
                ``"raw_img_info"`` (its image record).

        Returns:
            The instance dict, or ``None`` when the annotation has no
            ``"bbox"`` or no ``"keypoints"`` entry.
        """
        ann = raw_data_info["raw_ann_info"]
        img = raw_data_info["raw_img_info"]

        # Filter invalid instances.
        if "bbox" not in ann or "keypoints" not in ann:
            return None

        img_w, img_h = img["width"], img["height"]

        # The annotation bbox is (x, y, w, h); convert it to corner form
        # (x1, y1, x2, y2), clip to the image, and store with shape [1, 4].
        x, y, w, h = ann["bbox"]
        x1 = np.clip(x, 0, img_w - 1)
        y1 = np.clip(y, 0, img_h - 1)
        x2 = np.clip(x + w, 0, img_w - 1)
        y2 = np.clip(y + h, 0, img_h - 1)
        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

        # Keypoints in shape [1, K, 2] and keypoints_visible in [1, K]; the
        # third value of each triplet is the visibility flag, capped at 1.
        _keypoints = np.array(ann["keypoints"], dtype=np.float32).reshape(1, -1, 3)
        keypoints = _keypoints[..., :2]
        keypoints_visible = np.minimum(1, _keypoints[..., 2])

        if "num_keypoints" in ann:
            num_keypoints = ann["num_keypoints"]
        else:
            # Fallback: count keypoints whose larger coordinate is non-zero.
            num_keypoints = np.count_nonzero(keypoints.max(axis=2))

        data_info = {
            "img_id": ann["image_id"],
            "img_path": img["img_path"],
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
            "num_keypoints": num_keypoints,
            "keypoints": keypoints,
            "keypoints_visible": keypoints_visible,
            "iscrowd": ann.get("iscrowd", 0),
            "segmentation": ann.get("segmentation", None),
            "id": ann["id"],
            "category_id": ann["category_id"],
            "raw_ann_info": deepcopy(ann),
        }
        if "crowdIndex" in img:
            data_info["crowd_index"] = img["crowdIndex"]
        return data_info

    @staticmethod
    def _is_valid_instance(data_info: Dict) -> bool:
        """Return ``True`` when ``data_info`` passes all sanity checks.

        An instance is rejected when it is flagged as crowd, reports zero
        keypoints, has a bbox of non-positive width or height, or has
        keypoints that are empty or have no positive coordinate.
        """
        # Crowd annotation.
        if "iscrowd" in data_info and data_info["iscrowd"]:
            return False

        # Annotation reports no keypoints.
        if "num_keypoints" in data_info and data_info["num_keypoints"] == 0:
            return False

        # Degenerate bbox.
        if "bbox" in data_info:
            bbox = data_info["bbox"][0]
            w, h = bbox[2:4] - bbox[:2]
            if w <= 0 or h <= 0:
                return False

        # Keypoints with no positive coordinate. The size check comes first
        # because np.max raises ValueError on a zero-size array.
        if "keypoints" in data_info:
            keypoints = np.asarray(data_info["keypoints"])
            if keypoints.size == 0 or np.max(keypoints) <= 0:
                return False

        return True

    def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
        """Return the instances that pass :meth:`_is_valid_instance`."""
        return [info for info in instance_list if self._is_valid_instance(info)]

    def _load_detection_results(self) -> List[dict]:
        """Load samples from ``self.bbox_file``; not implemented here."""
        raise NotImplementedError
|