# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from sapiens.registry import VISUALIZERS
from torch import nn

from ..datasets.utils import parse_pose_metainfo


@VISUALIZERS.register_module()
class PoseVisualizer(nn.Module):
    """Render pose-estimation training batches to disk for inspection.

    Produces two kinds of images per visualized step:
      * keypoint overlays (skeleton + circles) under ``<output_dir>/kps``
      * per-joint heatmap grids under ``<output_dir>/heatmap``
    """

    def __init__(
        self,
        output_dir: str,
        vis_interval: int = 100,
        vis_max_samples: int = 4,
        vis_image_width: int = 384,
        vis_image_height: int = 512,
        num_keypoints: int = 308,
        scale: int = 4,
        line_width: int = 4,
        radius: int = 4,
    ):
        """
        Args:
            output_dir: root directory for saved visualizations (created
                if missing).
            vis_interval: stored for callers; not consulted inside this class.
            vis_max_samples: max batch items rendered per call.
            vis_image_width / vis_image_height: target size of each rendered
                sample image.
            num_keypoints: keypoint layout; metadata is only loaded for 308.
            scale: ratio between input-image size and heatmap size.
            line_width: skeleton line thickness in pixels.
            radius: keypoint circle radius in pixels.
        """
        super().__init__()
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.vis_max_samples = vis_max_samples
        self.vis_interval = vis_interval
        self.vis_image_width = vis_image_width
        self.vis_image_height = vis_image_height
        self.num_keypoints = num_keypoints
        self.scale = scale
        self.line_width = line_width
        self.radius = radius

        # NOTE(review): color/skeleton attributes are only defined for the
        # 308-keypoint layout; other values of num_keypoints leave
        # self.skeleton / self.kpt_color / self.link_color unset, and
        # add_batch would then raise AttributeError — confirm callers only
        # use 308.
        if self.num_keypoints == 308:
            self.dataset_meta = parse_pose_metainfo(
                dict(from_file="configs/_base_/keypoints308.py")
            )
            self.bbox_color = self.dataset_meta.get("bbox_colors", "green")
            self.kpt_color = self.dataset_meta.get("keypoint_colors")
            self.link_color = self.dataset_meta.get("skeleton_link_colors")
            self.skeleton = self.dataset_meta.get("skeleton_links")

    def add_batch(self, data_batch: dict, logs: dict, step: int):
        """Save keypoint and heatmap visualizations for one training batch.

        Args:
            data_batch: expects ``inputs`` (B x 3 x H x W tensor) and
                ``data_samples`` with ``heatmaps`` (B x K x h x w) and
                ``keypoint_weights``.
            logs: expects ``outputs`` — predicted heatmaps (B x K x h x w).
            step: training step, used in the output file names.
        """
        pred_heatmaps = logs["outputs"]
        pred_heatmaps = pred_heatmaps.detach().cpu()  # B x K x H x W
        gt_heatmaps = (
            data_batch["data_samples"]["heatmaps"].detach().cpu()
        )  # B x K x H x W
        inputs = data_batch["inputs"].detach().cpu()  # B x 3 x H x W

        # bfloat16 tensors cannot be converted to numpy directly.
        # NOTE(review): gt_heatmaps is not up-cast here — presumably it is
        # already float32 from the data pipeline; verify against the loader.
        if pred_heatmaps.dtype == torch.bfloat16:
            inputs = inputs.float()
            pred_heatmaps = pred_heatmaps.float()

        pred_heatmaps = pred_heatmaps.cpu().detach().numpy()  ## B x K x H x W
        gt_heatmaps = gt_heatmaps.cpu().detach().numpy()  ## B x K x H x W
        target_weights = (
            data_batch["data_samples"]["keypoint_weights"].squeeze(dim=1).cpu().numpy()
        )  ## B x K

        # Cap the number of rendered samples.
        batch_size = min(len(inputs), self.vis_max_samples)
        inputs = inputs[:batch_size]
        pred_heatmaps = pred_heatmaps[:batch_size]  ## B x K x H x W
        gt_heatmaps = gt_heatmaps[:batch_size]  ## B x K x H x W
        target_weights = target_weights[:batch_size]  ## B x K

        kps_vis_dir = os.path.join(self.output_dir, "kps")
        heatmap_vis_dir = os.path.join(self.output_dir, "heatmap")
        if not os.path.exists(kps_vis_dir):
            os.makedirs(kps_vis_dir, exist_ok=True)
        if not os.path.exists(heatmap_vis_dir):
            os.makedirs(heatmap_vis_dir, exist_ok=True)

        kps_prefix = os.path.join(kps_vis_dir, "train")
        heatmap_prefix = os.path.join(heatmap_vis_dir, "train")
        suffix = str(step).zfill(6)

        original_image = inputs / 255.0  ## B x 3 x H x W

        ## heatmap vis for only first 17 kps
        self.save_batch_heatmaps(
            original_image,
            gt_heatmaps[:, :17],
            "{}_{}_hm_gt.jpg".format(heatmap_prefix, suffix),
            normalize=False,
            scale=self.scale,
            is_rgb=False,
        )
        self.save_batch_heatmaps(
            original_image,
            pred_heatmaps[:, :17],
            "{}_{}_hm_pred.jpg".format(heatmap_prefix, suffix),
            normalize=False,
            scale=self.scale,
            is_rgb=False,
        )
        self.save_batch_image_with_joints(
            255 * original_image,
            gt_heatmaps,
            target_weights,
            "{}_{}_gt.jpg".format(kps_prefix, suffix),
            scale=self.scale,
            is_rgb=False,
        )
        # Predictions are drawn with all keypoints marked visible so that
        # every predicted joint is rendered regardless of GT weighting.
        self.save_batch_image_with_joints(
            255 * original_image,
            pred_heatmaps,
            np.ones_like(target_weights),
            "{}_{}_pred.jpg".format(kps_prefix, suffix),
            scale=self.scale,
            is_rgb=False,
        )
        return

    def save_batch_heatmaps(
        self,
        batch_image,
        batch_heatmaps,
        file_name,
        normalize=True,
        scale=4,
        is_rgb=True,
        max_num_joints=17,
    ):
        """Write a grid image: one row per sample, columns = [input, joint heatmaps].

        batch_image: [batch_size, channel, height, width]
        batch_heatmaps: [batch_size, num_joints, height, width]
        file_name: saved file name
        """
        ## normalize image to [0, 1] for display
        if normalize:
            batch_image = batch_image.clone()
            min_val = float(batch_image.min())
            max_val = float(batch_image.max())
            batch_image.add_(-min_val).div_(max_val - min_val + 1e-5)

        ## peak extraction needs numpy; keep a tensor view for rendering
        if isinstance(batch_heatmaps, np.ndarray):
            preds, maxvals = get_max_preds(batch_heatmaps)
            batch_heatmaps = torch.from_numpy(batch_heatmaps)
        else:
            preds, maxvals = get_max_preds(batch_heatmaps.detach().cpu().numpy())

        preds = preds * scale  ## scale to original image size

        batch_size = batch_heatmaps.size(0)
        num_joints = batch_heatmaps.size(1)
        heatmap_height = int(batch_heatmaps.size(2) * scale)
        heatmap_width = int(batch_heatmaps.size(3) * scale)

        num_joints = min(max_num_joints, num_joints)

        # One row per sample; first column is the input image, then one
        # column per joint heatmap.
        grid_image = np.zeros(
            (batch_size * heatmap_height, (num_joints + 1) * heatmap_width, 3),
            dtype=np.uint8,
        )
        body_joint_order = range(max_num_joints)

        for i in range(batch_size):
            image = (
                batch_image[i]
                .mul(255)
                .clamp(0, 255)
                .byte()
                .permute(1, 2, 0)
                .cpu()
                .numpy()
            )
            heatmaps = batch_heatmaps[i].mul(255).clamp(0, 255).byte().cpu().numpy()

            # cv2.imwrite expects BGR; convert when input is RGB
            if is_rgb == True:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            resized_image = cv2.resize(image, (int(heatmap_width), int(heatmap_height)))

            height_begin = heatmap_height * i
            height_end = heatmap_height * (i + 1)

            for j in range(num_joints):
                joint_index = body_joint_order[j]

                # mark the peak location on the input thumbnail
                cv2.circle(
                    resized_image,
                    (int(preds[i][joint_index][0]), int(preds[i][joint_index][1])),
                    1,
                    [0, 0, 255],
                    1,
                )
                heatmap = heatmaps[joint_index, :, :]
                colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                colored_heatmap = cv2.resize(
                    colored_heatmap, (int(heatmap_width), int(heatmap_height))
                )
                # blend heatmap over the input image
                masked_image = colored_heatmap * 0.7 + resized_image * 0.3
                cv2.circle(
                    masked_image,
                    (int(preds[i][joint_index][0]), int(preds[i][joint_index][1])),
                    1,
                    [0, 0, 255],
                    1,
                )

                width_begin = heatmap_width * (j + 1)
                width_end = heatmap_width * (j + 2)
                grid_image[height_begin:height_end, width_begin:width_end, :] = (
                    masked_image
                )

            grid_image[height_begin:height_end, 0:heatmap_width, :] = resized_image

        ## resize the full grid to the configured per-sample display size
        target_height = batch_size * self.vis_image_height
        target_width = (num_joints + 1) * self.vis_image_width
        grid_image = cv2.resize(grid_image, (target_width, target_height))

        cv2.imwrite(file_name, grid_image)
        return

    def save_batch_image_with_joints(
        self,
        batch_image,
        batch_heatmaps,
        batch_target_weight,
        file_name,
        is_rgb=True,
        scale=4,
        nrow=8,
        padding=2,
    ):
        """Write a grid of images with keypoints/skeleton drawn on top.

        batch_image: [batch_size, channel, height, width]
        batch_heatmaps: [batch_size, num_joints, height, width]
        batch_target_weight: [batch_size, num_joints] visibility weights
        """
        B, C, H, W = batch_image.size()
        num_joints = batch_heatmaps.shape[1]

        ## decode joint locations from heatmap argmax
        if isinstance(batch_heatmaps, np.ndarray):
            batch_joints, batch_scores = get_max_preds(batch_heatmaps)
        else:
            batch_joints, batch_scores = get_max_preds(
                batch_heatmaps.detach().cpu().numpy()
            )

        batch_joints = (
            batch_joints * scale
        )  ## scale is the ratio of input image to output heatmap

        if isinstance(batch_joints, torch.Tensor):
            batch_joints = batch_joints.cpu().numpy()

        if isinstance(batch_target_weight, torch.Tensor):
            batch_target_weight = batch_target_weight.cpu().numpy()
        batch_target_weight = batch_target_weight.reshape(B, num_joints)  ## B x K

        grid = []
        for i in range(B):
            image = (
                batch_image[i].permute(1, 2, 0).cpu().numpy()
            )  # H x W x C; BGR when is_rgb is False
            image = image.copy()

            kps = batch_joints[i]  ## K x 2
            kps_vis = batch_target_weight[i]
            kps_score = batch_scores[i].reshape(-1)

            # draw_instance_kpts works in RGB
            if is_rgb == False:
                image = cv2.cvtColor(
                    image, cv2.COLOR_BGR2RGB
                )  # convert bgr to rgb image

            kp_vis_image = self.draw_instance_kpts(
                image,
                keypoints=[kps],
                keypoints_visible=[kps_vis],
                keypoint_scores=[kps_score],
                radius=self.radius,
                thickness=self.line_width,
                kpt_thr=0.3,
                skeleton=self.skeleton,
                kpt_color=self.kpt_color,
                link_color=self.link_color,
            )  ## H, W, C, rgb image

            kp_vis_image = cv2.cvtColor(
                kp_vis_image, cv2.COLOR_RGB2BGR
            )  ## convert rgb to bgr image
            kp_vis_image = kp_vis_image.transpose((2, 0, 1)).astype(np.float32)
            kp_vis_image = torch.from_numpy(kp_vis_image.copy())
            grid.append(kp_vis_image)

        grid = torchvision.utils.make_grid(grid, nrow, padding)
        ndarr = grid.byte().permute(1, 2, 0).cpu().numpy()

        ## resize keeping aspect ratio
        target_height = self.vis_image_height
        target_width = ndarr.shape[1] * target_height // ndarr.shape[0]
        ndarr = cv2.resize(ndarr, (target_width, target_height))

        cv2.imwrite(file_name, ndarr)
        return

    def draw_instance_kpts(
        self,
        image: np.ndarray,  # RGB uint8 H,W,3
        keypoints,  # list[(J,2)]
        keypoints_visible,  # list[(J,), {0/1}]
        keypoint_scores,  # list[(J,)]
        *,
        radius: int = 4,
        thickness: int = -1,
        color=(255, 0, 0),
        kpt_thr: float = 0.3,
        skeleton: list | None = None,  # [(i,j)]
        kpt_color: list | tuple | np.ndarray | None = None,
        link_color: list | tuple | np.ndarray | None = None,
        show_kpt_idx: bool = False,
    ) -> np.ndarray:
        """Draw keypoints and skeleton links on a copy of ``image`` (RGB).

        Points/links are skipped when invisible, below ``kpt_thr``, or out
        of the image bounds. Returns the annotated copy.
        """
        img = image.copy()
        H, W = img.shape[:2]

        # defaults
        if skeleton is None:
            skeleton = []  # points only
        if kpt_color is None:
            kpt_color = color
        if link_color is None:
            link_color = (0, 255, 0)

        # robust color normalization: supports tuple, list-of-tuples,
        # np.ndarray (N,3) or (3,)
        def _as_color_list(c, n):
            # torch -> numpy
            if hasattr(c, "detach"):
                c = c.detach().cpu().numpy()
            # numpy -> list of int triplets
            if isinstance(c, np.ndarray):
                if c.ndim == 2 and c.shape[1] == 3:  # (N,3) palette
                    return [tuple(int(v) for v in row) for row in c.tolist()]
                if c.size == 3:  # single (3,)
                    return [tuple(int(v) for v in c.tolist())] * max(1, n)
            # python containers
            if isinstance(c, (list, tuple)):
                if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
                    out = []
                    for cc in c:
                        cc = np.asarray(cc).reshape(-1)
                        assert cc.size == 3, "Each color must be length-3"
                        out.append(tuple(int(v) for v in cc.tolist()))
                    return out
                # single triplet
                c_arr = np.asarray(c).reshape(-1)
                if c_arr.size == 3:
                    return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
            # fallback: red
            return [(255, 0, 0)] * max(1, n)

        J = keypoints[0].shape[0] if keypoints else 0
        kpt_colors = _as_color_list(kpt_color, J)
        link_colors = _as_color_list(link_color, len(skeleton))

        def in_bounds(x, y):
            return 0 <= x < W and 0 <= y < H

        for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
            kpts = np.asarray(kpts, float)
            vis = np.asarray(vis).reshape(-1).astype(bool)
            score = np.asarray(score).reshape(-1)

            # links (draw in RGB; NO channel flip)
            for lk, (i, j) in enumerate(skeleton):
                if i >= len(kpts) or j >= len(kpts):
                    continue
                if not (vis[i] and vis[j]):
                    continue
                if score[i] < kpt_thr or score[j] < kpt_thr:
                    continue
                x1, y1 = map(int, np.round(kpts[i]))
                x2, y2 = map(int, np.round(kpts[j]))
                if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
                    continue
                cv2.line(
                    img,
                    (x1, y1),
                    (x2, y2),
                    link_colors[lk % len(link_colors)],
                    thickness=max(1, self.line_width),
                    lineType=cv2.LINE_AA,
                )

            # points
            for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
                if not v or s < kpt_thr:
                    continue
                x, y = map(int, np.round(xy))
                if not in_bounds(x, y):
                    continue
                c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
                cv2.circle(
                    img, (x, y), radius, c, thickness=thickness, lineType=cv2.LINE_AA
                )
                if show_kpt_idx:
                    cv2.putText(
                        img,
                        str(j_idx),
                        (x + radius, y - radius),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.4,
                        c,
                        1,
                        cv2.LINE_AA,
                    )
        return img


###------------------helpers-----------------------
def batch_unnormalize_image(
    images, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]
):
    """Undo per-channel mean/std normalization in place and return ``images``.

    FIX: the original built ``transforms.Normalize(mean, std)`` but
    ``transforms`` was never imported, so any call raised NameError. The
    normalized object was only used to read back ``.mean``/``.std``, so
    using the arguments directly is equivalent.
    """
    images[:, 0, :, :] = (images[:, 0, :, :] * std[0]) + mean[0]
    images[:, 1, :, :] = (images[:, 1, :, :] * std[1]) + mean[1]
    images[:, 2, :, :] = (images[:, 2, :, :] * std[2]) + mean[2]
    return images


def get_max_preds(batch_heatmaps):
    """
    get predictions from score maps

    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

    Returns:
        preds: (B, K, 2) float32 (x, y) argmax locations in heatmap coords;
            zeroed where the max value is not positive.
        maxvals: (B, K, 1) max heatmap value per joint.
    """
    assert isinstance(batch_heatmaps, np.ndarray), (
        "batch_heatmaps should be numpy.ndarray"
    )
    assert batch_heatmaps.ndim == 4, "batch_images should be 4-ndim"

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]

    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)  ## B x K
    maxvals = np.amax(heatmaps_reshaped, 2)  ## B x K

    maxvals = maxvals.reshape((batch_size, num_joints, 1))  ## B x K x 1
    idx = idx.reshape((batch_size, num_joints, 1))  ## B x K x 1

    preds = np.tile(idx, (1, 1, 2)).astype(
        np.float32
    )  ## B x K x 2, like repeat in pytorch

    # flat index -> (x, y)
    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    # zero out joints whose peak value is not positive
    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals