# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2
import numpy as np
def visualize_keypoints(
image: np.ndarray, # RGB uint8 H,W,3
keypoints, # list[(J,2)]
keypoints_visible, # list[(J,), {0/1}]
keypoint_scores, # list[(J,)]
*,
radius: int = 4,
thickness: int = -1,
color=(255, 0, 0),
kpt_thr: float = 0.3,
skeleton: list | None = None, # [(i,j)]
kpt_color: list | tuple | np.ndarray | None = None,
link_color: list | tuple | np.ndarray | None = None,
show_kpt_idx: bool = False,
) -> np.ndarray:
img = image.copy()
H, W = img.shape[:2]
# defaults
if skeleton is None:
skeleton = [] # points only
if kpt_color is None:
kpt_color = color
if link_color is None:
link_color = (0, 255, 0)
# robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,)
def _as_color_list(c, n):
# torch -> numpy
if hasattr(c, "detach"):
c = c.detach().cpu().numpy()
# numpy -> array
if isinstance(c, np.ndarray):
if c.ndim == 2 and c.shape[1] == 3: # (N,3) palette
return [tuple(int(v) for v in row) for row in c.tolist()]
if c.size == 3: # single (3,)
return [tuple(int(v) for v in c.tolist())] * max(1, n)
# python containers
if isinstance(c, (list, tuple)):
if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
out = []
for cc in c:
cc = np.asarray(cc).reshape(-1)
assert cc.size == 3, "Each color must be length-3"
out.append(tuple(int(v) for v in cc.tolist()))
return out
# single triplet
c_arr = np.asarray(c).reshape(-1)
if c_arr.size == 3:
return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
# fallback: red
return [(255, 0, 0)] * max(1, n)
J = keypoints[0].shape[0] if keypoints else 0
kpt_colors = _as_color_list(kpt_color, J)
link_colors = _as_color_list(link_color, len(skeleton))
def in_bounds(x, y):
return 0 <= x < W and 0 <= y < H
for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
kpts = np.asarray(kpts, float)
vis = np.asarray(vis).reshape(-1).astype(bool)
score = np.asarray(score).reshape(-1)
# links (draw in RGB; NO channel flip)
for lk, (i, j) in enumerate(skeleton):
if i >= len(kpts) or j >= len(kpts):
continue
if not (vis[i] and vis[j]):
continue
if score[i] < kpt_thr or score[j] < kpt_thr:
continue
x1, y1 = map(int, np.round(kpts[i]))
x2, y2 = map(int, np.round(kpts[j]))
if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
continue
cv2.line(
img,
(x1, y1),
(x2, y2),
link_colors[lk % len(link_colors)],
thickness=max(1, thickness),
lineType=cv2.LINE_AA,
)
# points
for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
if not v or s < kpt_thr:
continue
x, y = map(int, np.round(xy))
if not in_bounds(x, y):
continue
c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
if show_kpt_idx:
cv2.putText(
img,
str(j_idx),
(x + radius, y - radius),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
c,
1,
cv2.LINE_AA,
)
return img
|