Spaces:
Running
Running
File size: 4,441 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from sapiens.engine import Config
def parse_pose_metainfo(metainfo: dict):
    """Parse raw pose-dataset meta information into a normalized dict.

    Args:
        metainfo: Raw meta information. If it contains the key
            ``"from_file"``, that value is passed to ``Config.fromfile`` and
            the ``dataset_info`` attribute of the loaded config is used
            instead. The (resulting) mapping must contain ``"dataset_name"``,
            ``"keypoint_info"`` and ``"skeleton_info"``. ``"joint_weights"``,
            ``"sigmas"`` and ``"stats_info"`` are optional.

    Returns:
        dict: Parsed meta information with the keys

        - ``dataset_name``: copied from the input.
        - ``num_keypoints``: number of entries in ``keypoint_info``.
        - ``keypoint_id2name`` / ``keypoint_name2id``: id <-> name lookups.
        - ``upper_body_ids`` / ``lower_body_ids``: ids of keypoints whose
          ``type`` is ``"upper"`` / ``"lower"``.
        - ``flip_indices``: for each keypoint (in iteration order), the id
          of its ``swap`` partner, or its own id if it has none.
        - ``flip_pairs``: ``(swap_id, keypoint_id)`` tuples, one per keypoint
          that has a swap partner (so both orientations of a symmetric pair
          are present).
        - ``keypoint_colors``: ``np.uint8`` array, default ``[255, 128, 0]``.
        - ``num_skeleton_links``, ``skeleton_links`` (keypoint-id sequences)
          and ``skeleton_link_colors`` (``np.uint8`` array, default
          ``[96, 96, 255]``).
        - ``dataset_keypoint_weights`` / ``sigmas``: ``np.float32`` arrays,
          or ``None`` when the corresponding input key is absent.
        - ``stats_info``: only present if given; values as ``np.float32``
          arrays.
        - any of the optional pass-through keys listed in the body, copied
          unchanged when present.

    Raises:
        AssertionError: If a required key is missing.
        KeyError: If a ``swap`` entry or a skeleton ``link`` references a
            keypoint name that is not defined in ``keypoint_info``.
    """
    if "from_file" in metainfo:
        cfg_file = metainfo["from_file"]
        metainfo = Config.fromfile(cfg_file).dataset_info

    # check data integrity ('joint_weights' and 'sigmas' are optional)
    for required_key in ("dataset_name", "keypoint_info", "skeleton_info"):
        assert required_key in metainfo, (
            f'metainfo is missing the required key "{required_key}"'
        )

    # parse metainfo
    parsed = dict(
        dataset_name=metainfo["dataset_name"],
        num_keypoints=None,
        keypoint_id2name={},
        keypoint_name2id={},
        upper_body_ids=[],
        lower_body_ids=[],
        flip_indices=[],
        flip_pairs=[],
        keypoint_colors=[],
        num_skeleton_links=None,
        skeleton_links=[],
        skeleton_link_colors=[],
        dataset_keypoint_weights=None,
        sigmas=None,
    )

    # Optional entries that are forwarded untouched when present. The order
    # here fixes the insertion order in the returned dict.
    passthrough_keys = (
        "remove_teeth",
        "min_visible_keypoints",
        "teeth_keypoint_ids",
        "coco_wholebody_to_goliath_mapping",
        "coco_wholebody_to_goliath_keypoint_info",
        "idx_to_original_idx_mapping",
    )
    for key in passthrough_keys:
        if key in metainfo:
            parsed[key] = metainfo[key]

    # parse keypoint information
    keypoint_info = metainfo["keypoint_info"]
    parsed["num_keypoints"] = len(keypoint_info)

    for kpt_id, kpt in keypoint_info.items():
        kpt_name = kpt["name"]
        parsed["keypoint_id2name"][kpt_id] = kpt_name
        parsed["keypoint_name2id"][kpt_name] = kpt_id
        parsed["keypoint_colors"].append(kpt.get("color", [255, 128, 0]))

        kpt_type = kpt.get("type", "")
        if kpt_type == "upper":
            parsed["upper_body_ids"].append(kpt_id)
        elif kpt_type == "lower":
            parsed["lower_body_ids"].append(kpt_id)

        # Names are collected here and translated to ids further below,
        # once every keypoint name is known.
        swap_kpt = kpt.get("swap", "")
        if swap_kpt == kpt_name or swap_kpt == "":
            parsed["flip_indices"].append(kpt_name)
        else:
            parsed["flip_indices"].append(swap_kpt)
            pair = (swap_kpt, kpt_name)
            if pair not in parsed["flip_pairs"]:
                parsed["flip_pairs"].append(pair)

    # parse skeleton information
    skeleton_info = metainfo["skeleton_info"]
    parsed["num_skeleton_links"] = len(skeleton_info)
    for sk in skeleton_info.values():
        parsed["skeleton_links"].append(sk["link"])
        parsed["skeleton_link_colors"].append(sk.get("color", [96, 96, 255]))

    # parse extra information
    if "joint_weights" in metainfo:
        parsed["dataset_keypoint_weights"] = np.array(
            metainfo["joint_weights"], dtype=np.float32
        )

    if "sigmas" in metainfo:
        parsed["sigmas"] = np.array(metainfo["sigmas"], dtype=np.float32)

    if "stats_info" in metainfo:
        parsed["stats_info"] = {
            name: np.array(val, dtype=np.float32)
            for name, val in metainfo["stats_info"].items()
        }

    # formatting
    name2id = parsed["keypoint_name2id"]

    def _map(src, field: str):
        """Recursively replace keypoint names in ``src`` by their ids.

        Lists/tuples keep their container type; ``field`` is only used to
        build a helpful error message.
        """
        if isinstance(src, (list, tuple)):
            return type(src)(_map(item, field) for item in src)
        try:
            return name2id[src]
        except KeyError:
            raise KeyError(
                f'{field} references unknown keypoint name "{src}"'
            ) from None

    parsed["flip_pairs"] = _map(parsed["flip_pairs"], "flip_pairs")
    parsed["flip_indices"] = _map(parsed["flip_indices"], "flip_indices")
    parsed["skeleton_links"] = _map(parsed["skeleton_links"], "skeleton_links")

    parsed["keypoint_colors"] = np.array(parsed["keypoint_colors"], dtype=np.uint8)
    parsed["skeleton_link_colors"] = np.array(
        parsed["skeleton_link_colors"], dtype=np.uint8
    )

    return parsed
|