| import os |
| os.environ['ERPC'] = '1' |
|
|
| import torch |
| import cv2 |
| import time |
| import numpy as np |
| import trimesh |
|
|
| import arg_parser |
|
|
| from model import TEHNetWrapper |
| from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH |
|
|
|
|
def pc_normalize(pc):
    """Normalise a point cloud's pixel and time columns, in place.

    Column 0 is divided by ``OUTPUT_WIDTH`` and column 1 by ``OUTPUT_HEIGHT``;
    both are then mapped with ``2 * v - 1`` (so ``[0, 1]`` becomes
    ``[-1, 1]``). Every column from index 2 onwards is min-max normalised,
    per column, to ``[-1, 1]``.

    Args:
        pc: 2-D floating-point torch tensor with at least three columns
            (``.max(0).values`` is used, so a NumPy array will not work).
            It is modified in place; slices passed in are written through to
            their parent tensor.

    Returns:
        The same tensor object that was passed in.
    """
    pc[:, 0] /= OUTPUT_WIDTH
    pc[:, 1] /= OUTPUT_HEIGHT
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]
    t_min = ts.min(0).values
    t_range = ts.max(0).values - t_min

    # If every value in a column is identical the span is zero and the
    # division below would fill that column with NaN. Use a span of 1
    # instead, which maps the whole column to -1.
    t_range = torch.where(t_range > 0, t_range, torch.ones_like(t_range))

    pc[:, 2:] = 2 * ((ts - t_min) / t_range) - 1

    return pc
|
|
|
|
|
|
def process_events(events, n_events=2048):
    """Convert a raw event stream into the inputs used by ``demo``.

    Events are accumulated per pixel (summed timestamp, count of ``p == 1``
    events, count of ``p != 1`` events). ``n_events`` of the active pixels
    are then drawn at random **with replacement** (``np.random.choice``
    default), so the result depends on NumPy's global random state.

    Args:
        events: NumPy array of shape ``(N, 4)`` with columns ``x, y, t, p``.
            At least one row is required (the first timestamp is used as the
            time origin). The array is not modified.
        n_events: Number of pixels to sample. Defaults to 2048, the value
            that was previously hard-coded.

    Returns:
        dict with
            ``'event_frame'``: uint8 tensor ``(OUTPUT_HEIGHT, OUTPUT_WIDTH, 3)``;
                channel 0 holds the fraction of ``p == 1`` events times 255,
                the last channel the fraction of the other events times 255.
            ``'events'``: float32 tensor ``(1, 5, n_events)`` with rows
                ``x, y, t_avg, p_count, n_count``; the first three rows are
                normalised by ``pc_normalize``.
            ``'coordinates'``: float32 tensor ``(n_events, 2)`` of ``(y, x)``
                pixel positions for the sampled points.
    """
    # Work on a copy: shifting the timestamps below must not rewrite the
    # caller's array.
    events = np.array(events)
    events[:, 2] -= events[0, 2]

    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)

    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)

    # np.add.at accumulates correctly when the same pixel occurs many times.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)
    np.add.at(count_grid, (y, x), 1)

    # Keep only pixels that received at least one event.
    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    p_evn = event_grid[yi, xi, 1]
    n_evn = event_grid[yi, xi, 2]

    sampled_indices = np.random.choice(yi.shape[0], n_events)
    yi, xi = yi[sampled_indices], xi[sampled_indices]
    t_avg = t_avg[sampled_indices]
    p_evn = p_evn[sampled_indices]
    n_evn = n_evn[sampled_indices]

    # Polarity image. A sampled pixel has p_evn + n_evn >= 1, so the division
    # is safe; repeated samples of one pixel write identical values.
    total = p_evn + n_evn
    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    event_frame[yi, xi, 0] = (p_evn / total * 255).astype(np.uint8)
    event_frame[yi, xi, -1] = (n_evn / total * 255).astype(np.uint8)

    coordinates = np.stack([yi, xi], axis=1)

    points = torch.tensor(
        np.column_stack([xi, yi, t_avg, p_evn, n_evn]),
        dtype=torch.float32,
    )
    points[:, :3] = pc_normalize(points[:, :3])

    hand_data = {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': points.permute(1, 0).unsqueeze(0),
        'coordinates': torch.tensor(coordinates, dtype=torch.float32),
    }

    return hand_data
|
|
|
|
|
|
def demo(net, device, data):
    """Run ``net`` on one batch of events and unpack the result per sample.

    The network is switched to eval mode, called under ``torch.no_grad()``
    and the wall-clock duration of the forward pass is printed.
    ``outputs['class_logits']`` is replaced in place by the per-point argmax
    class ids (int, on CPU).

    Args:
        net: Model exposing ``eval()`` and callable on the event tensor. Its
            output is indexed as ``outputs['left' | 'right']['vertices' |
            'j3d']`` and ``outputs['class_logits']``.
        device: Device the event tensor is moved to before the forward pass.
        data: Mapping with ``'events'`` (batched event tensor) and
            ``'coordinates'`` (iterable of ``(y, x)`` tensor pairs, one per
            point).

    Returns:
        A list with one dict per batch element holding ``'left'`` and
        ``'right'`` (each with CPU ``'vertices'`` and ``'j3d'``) and
        ``'seg_mask'``, a uint8 array ``(OUTPUT_HEIGHT, OUTPUT_WIDTH, 3)``.
    """
    net.eval()

    batch = data['events'].to(device=device, dtype=torch.float32)

    tic = time.time()
    with torch.no_grad():
        outputs = net(batch)
    toc = time.time()

    batch_size = batch.shape[0]
    print(toc - tic)

    class_ids = outputs['class_logits'].softmax(1).argmax(1).int().cpu()
    outputs['class_logits'] = class_ids

    frames = []
    for sample_idx in range(batch_size):
        frame = {
            side: {
                'vertices': outputs[side]['vertices'][sample_idx].cpu(),
                'j3d': outputs[side]['j3d'][sample_idx].cpu(),
            }
            for side in ('left', 'right')
        }

        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        for point_idx, (row, col) in enumerate(data['coordinates']):
            row, col = row.int(), col.int()
            label = class_ids[sample_idx][point_idx]

            # Class id 3 lights all three channels; any other id lights only
            # its own channel.
            # NOTE(review): the meaning of each class id is not visible here.
            if label != 3:
                seg_mask[row, col, label] = 255
            else:
                seg_mask[row, col] = 255

        frame['seg_mask'] = seg_mask
        frames.append(frame)

    return frames
|
|
|
|
class Ev2Hands:
    """Callable wrapper that turns processed event data into a hand mesh.

    Construction runs ``arg_parser.demo()``, builds a ``TEHNetWrapper`` on the
    CPU and loads its weights from the file named by the ``CHECKPOINT_PATH``
    environment variable.
    """

    def __init__(self) -> None:
        """Build the network and load the checkpoint.

        Raises:
            KeyError: If ``CHECKPOINT_PATH`` is not set in the environment
                after ``arg_parser.demo()`` has run.
        """
        arg_parser.demo()

        self.device = torch.device('cpu')
        self.net = TEHNetWrapper(device=self.device)

        checkpoint_path = os.environ['CHECKPOINT_PATH']
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(checkpoint['state_dict'], strict=True)

        # 180-degree rotation about the x axis, applied to every output mesh.
        self.rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
        self.mano_hands = self.net.hands

    def __call__(self, data):
        """Predict both hands for ``data`` and return them as a single mesh.

        Only the first batch element returned by ``demo`` is used.

        Args:
            data: Mapping accepted by ``demo`` (``'events'`` and
                ``'coordinates'``).

        Returns:
            The left and right meshes concatenated into one trimesh object,
            with all vertices coloured red and ``self.rot`` applied.
        """
        frame = demo(net=self.net, device=self.device, data=data)[0]
        # NOTE(review): the segmentation mask is read but not used below.
        seg_mask = frame['seg_mask']

        hand_meshes = []
        for side in ('left', 'right'):
            # Vertices are scaled by 1000 — presumably metres to millimetres;
            # TODO confirm against the model's output units.
            vertices = frame[side]['vertices'].cpu().numpy() * 1000
            mesh = trimesh.Trimesh(vertices, self.mano_hands[side].faces)
            mesh.visual.vertex_colors = [255, 0, 0]
            hand_meshes.append(mesh)

        combined = trimesh.util.concatenate(hand_meshes)
        combined.apply_transform(self.rot)

        return combined
|
|