File size: 4,090 Bytes
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2
import numpy as np


def visualize_keypoints(
    image: np.ndarray,  # RGB uint8 H,W,3
    keypoints,  # list[(J,2)]
    keypoints_visible,  # list[(J,), {0/1}]
    keypoint_scores,  # list[(J,)]
    *,
    radius: int = 4,
    thickness: int = -1,
    color=(255, 0, 0),
    kpt_thr: float = 0.3,
    skeleton: list | None = None,  # [(i,j)]
    kpt_color: list | tuple | np.ndarray | None = None,
    link_color: list | tuple | np.ndarray | None = None,
    show_kpt_idx: bool = False,
) -> np.ndarray:
    img = image.copy()
    H, W = img.shape[:2]

    # defaults
    if skeleton is None:
        skeleton = []  # points only
    if kpt_color is None:
        kpt_color = color
    if link_color is None:
        link_color = (0, 255, 0)

    # robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,)
    def _as_color_list(c, n):
        # torch -> numpy
        if hasattr(c, "detach"):
            c = c.detach().cpu().numpy()
        # numpy -> array
        if isinstance(c, np.ndarray):
            if c.ndim == 2 and c.shape[1] == 3:  # (N,3) palette
                return [tuple(int(v) for v in row) for row in c.tolist()]
            if c.size == 3:  # single (3,)
                return [tuple(int(v) for v in c.tolist())] * max(1, n)
        # python containers
        if isinstance(c, (list, tuple)):
            if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
                out = []
                for cc in c:
                    cc = np.asarray(cc).reshape(-1)
                    assert cc.size == 3, "Each color must be length-3"
                    out.append(tuple(int(v) for v in cc.tolist()))
                return out
            # single triplet
            c_arr = np.asarray(c).reshape(-1)
            if c_arr.size == 3:
                return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
        # fallback: red
        return [(255, 0, 0)] * max(1, n)

    J = keypoints[0].shape[0] if keypoints else 0
    kpt_colors = _as_color_list(kpt_color, J)
    link_colors = _as_color_list(link_color, len(skeleton))

    def in_bounds(x, y):
        return 0 <= x < W and 0 <= y < H

    for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
        kpts = np.asarray(kpts, float)
        vis = np.asarray(vis).reshape(-1).astype(bool)
        score = np.asarray(score).reshape(-1)

        # links (draw in RGB; NO channel flip)
        for lk, (i, j) in enumerate(skeleton):
            if i >= len(kpts) or j >= len(kpts):
                continue
            if not (vis[i] and vis[j]):
                continue
            if score[i] < kpt_thr or score[j] < kpt_thr:
                continue

            x1, y1 = map(int, np.round(kpts[i]))
            x2, y2 = map(int, np.round(kpts[j]))
            if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
                continue

            cv2.line(
                img,
                (x1, y1),
                (x2, y2),
                link_colors[lk % len(link_colors)],
                thickness=max(1, thickness),
                lineType=cv2.LINE_AA,
            )

        # points
        for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
            if not v or s < kpt_thr:
                continue
            x, y = map(int, np.round(xy))
            if not in_bounds(x, y):
                continue

            c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
            cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
            if show_kpt_idx:
                cv2.putText(
                    img,
                    str(j_idx),
                    (x + radius, y - radius),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.4,
                    c,
                    1,
                    cv2.LINE_AA,
                )
    return img