File size: 7,417 Bytes
bff20b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
from sapiens.engine.datasets import BaseDataset
from sapiens.registry import DATASETS

from .utils import parse_pose_metainfo


@DATASETS.register_module()
class PoseBaseDataset(BaseDataset):
    """Base dataset for top-down pose estimation.

    Loads COCO-format annotations (via ``xtcocotools``), attaches pose
    metainfo (flip pairs, keypoint weights, skeleton links, ...) to every
    sample, and filters out invalid instances.

    Args:
        ann_file (str): Path to the COCO-style annotation file.
        num_samples (int, optional): If given, truncate the loaded data
            list to the first ``num_samples`` entries (debugging aid).
            Defaults to ``None`` (keep all samples).
        bbox_file (str, optional): Path to precomputed detection results.
            When set, ``load_data_list`` loads detections instead of
            ground-truth annotations. Defaults to ``None``.
        **kwargs: Forwarded to :class:`BaseDataset` (e.g. ``data_root``,
            ``pipeline``, ``test_mode``).
    """

    METAINFO: dict = dict(from_file="configs/_base_/keypoints308.py")

    def __init__(
        self,
        ann_file: str = "",
        num_samples: Optional[int] = None,
        bbox_file: Optional[str] = None,
        **kwargs,
    ):
        self.bbox_file = bbox_file
        self.ann_file = ann_file
        self.num_samples = num_samples

        # Parse metainfo before BaseDataset.__init__, which triggers
        # load_data_list() and may rely on self.metainfo being populated.
        self.metainfo = parse_pose_metainfo(self.METAINFO)
        super().__init__(**kwargs)

        if self.num_samples is not None:
            # Debugging aid: keep only the first `num_samples` entries.
            self.data_list = self.data_list[: self.num_samples]

        print(
            "\033[96mLoaded {} samples for {}, Test mode: {}\033[0m".format(
                self.__len__(), self.__class__.__name__, self.test_mode
            )
        )

    def prepare_data(self, idx: int) -> Any:
        """Fetch sample ``idx``, run it through the pipeline, and reject
        training samples with too few visible keypoints after cropping.

        Returns:
            The transformed sample, or ``None`` if the pipeline dropped it
            or too few keypoints remain visible (caller is expected to
            resample on ``None``).
        """
        data_info = self.get_data_info(idx)
        transformed_data_info = self.pipeline(data_info)

        if transformed_data_info is None:
            return None

        ## pipeline is set to empty when using concatenation of datasets.
        if (
            not self.test_mode
            and "data_samples" in transformed_data_info
            and "gt_instance_labels" in transformed_data_info["data_samples"]
            and "keypoints_visible"
            in transformed_data_info["data_samples"].gt_instance_labels
        ):
            num_transformed_keypoints = (
                transformed_data_info["data_samples"]
                .gt_instance_labels["keypoints_visible"]
                .sum()
                .item()
            )  ## after cropping

            ## minimum visible keypoints for coco_wholebody is 8
            if self.metainfo["dataset_name"] == "coco_wholebody":
                if num_transformed_keypoints < 8:
                    return None

            ## absolute minimum visible keypoints is 3
            if num_transformed_keypoints < 3:
                return None

        return transformed_data_info

    def get_data_info(self, idx: int) -> dict:
        """Return the raw data info for sample ``idx`` with the decoded
        image and the metainfo items required by the pipeline/model."""
        data_info = super().get_data_info(idx)
        # NOTE(review): cv2.imread returns None for missing/unreadable
        # paths; the pipeline would then fail downstream — confirm inputs
        # are validated upstream.
        data_info["img"] = cv2.imread(data_info["img_path"])

        # Add metainfo items that are required in the pipeline and the model
        metainfo_keys = [
            "upper_body_ids",
            "lower_body_ids",
            "flip_pairs",
            "dataset_keypoint_weights",
            "flip_indices",
            "skeleton_links",
        ]

        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                "exists in the `data_info`."
            )

            # Deep-copy so per-sample mutation cannot corrupt shared metainfo.
            data_info[key] = deepcopy(self.metainfo[key])

        return data_info

    def load_data_list(self) -> List[dict]:
        """Load the full sample list, from detection results when
        ``bbox_file`` is set, otherwise from ground-truth annotations."""
        if self.bbox_file:
            data_list = self._load_detection_results()
        else:
            instance_list, _ = self._load_annotations()
            data_list = self._get_topdown_data_infos(instance_list)

        return data_list

    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
        """Parse ``self.ann_file`` with the COCO API.

        Returns:
            Tuple of (instance_list, image_list): one entry per valid
            instance annotation, and one per image.
        """
        from xtcocotools.coco import COCO  # lazy: only needed for COCO-format ann files

        assert os.path.exists(self.ann_file), "Annotation file does not exist"
        self.coco = COCO(self.ann_file)
        self.metainfo["CLASSES"] = self.coco.loadCats(self.coco.getCatIds())

        instance_list = []
        image_list = []

        for img_id in self.coco.getImgIds():
            img = self.coco.loadImgs(img_id)[0]
            img.update(
                {
                    "img_id": img_id,
                    "img_path": os.path.join(self.data_root, img["file_name"]),
                }
            )
            image_list.append(img)
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            for ann in self.coco.loadAnns(ann_ids):
                instance_info = self.parse_data_info(
                    dict(raw_ann_info=ann, raw_img_info=img)
                )

                # skip invalid instance annotation.
                if not instance_info:
                    continue

                instance_list.append(instance_info)
        return instance_list, image_list

    def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
        """Convert one raw COCO annotation into the internal sample dict.

        Args:
            raw_data_info (dict): ``{"raw_ann_info": ..., "raw_img_info": ...}``.

        Returns:
            The parsed sample dict, or ``None`` when the annotation lacks
            a bbox or keypoints.
        """
        ann = raw_data_info["raw_ann_info"]
        img = raw_data_info["raw_img_info"]

        # filter invalid instance
        if "bbox" not in ann or "keypoints" not in ann:
            return None

        img_w, img_h = img["width"], img["height"]

        # get bbox in shape [1, 4], formatted as xyxy (clipped to the image)
        x, y, w, h = ann["bbox"]
        x1 = np.clip(x, 0, img_w - 1)
        y1 = np.clip(y, 0, img_h - 1)
        x2 = np.clip(x + w, 0, img_w - 1)
        y2 = np.clip(y + h, 0, img_h - 1)

        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

        # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
        _keypoints = np.array(ann["keypoints"], dtype=np.float32).reshape(1, -1, 3)
        keypoints = _keypoints[..., :2]
        # COCO visibility flag is {0, 1, 2}; collapse to {0, 1}.
        keypoints_visible = np.minimum(1, _keypoints[..., 2])

        if "num_keypoints" in ann:
            num_keypoints = ann["num_keypoints"]
        else:
            # Fallback: count keypoints whose max(x, y) coordinate is nonzero.
            num_keypoints = np.count_nonzero(keypoints.max(axis=2))

        data_info = {
            "img_id": ann["image_id"],
            "img_path": img["img_path"],
            "bbox": bbox,
            "bbox_score": np.ones(1, dtype=np.float32),
            "num_keypoints": num_keypoints,
            "keypoints": keypoints,
            "keypoints_visible": keypoints_visible,
            "iscrowd": ann.get("iscrowd", 0),
            "segmentation": ann.get("segmentation", None),
            "id": ann["id"],
            "category_id": ann["category_id"],
            # Deep-copy: the raw dict is shared with the COCO index.
            "raw_ann_info": copy.deepcopy(ann),
        }

        if "crowdIndex" in img:
            data_info["crowd_index"] = img["crowdIndex"]

        return data_info

    @staticmethod
    def _is_valid_instance(data_info: Dict) -> bool:
        """Return whether an instance is usable for top-down training."""
        # crowd annotation
        if "iscrowd" in data_info and data_info["iscrowd"]:
            return False
        # invalid keypoints
        if "num_keypoints" in data_info and data_info["num_keypoints"] == 0:
            return False
        # invalid bbox
        if "bbox" in data_info:
            bbox = data_info["bbox"][0]
            w, h = bbox[2:4] - bbox[:2]
            if w <= 0 or h <= 0:
                return False
        # invalid keypoints
        if "keypoints" in data_info:
            if np.max(data_info["keypoints"]) <= 0:
                return False
        return True

    def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
        """Keep only instances that pass ``_is_valid_instance``."""
        data_list_tp = list(filter(self._is_valid_instance, instance_list))
        return data_list_tp

    def _load_detection_results(self) -> List[dict]:
        """Load precomputed detections from ``bbox_file`` (subclass hook)."""
        raise NotImplementedError