# Source: github_sync/A8/pose_estimator.py
# Author: Amol Kaushik
# Commit 6678ad8 — "A8: Add MoveNet pose estimator module (#33)"
"""
MoveNet Pose Estimator Module
=============================
A Python module for human pose estimation using TensorFlow's MoveNet model.
This module provides functionality to:
- Load and run MoveNet pose estimation model
- Process images and videos
- Extract 17 COCO keypoints
- Visualize pose detection results
Issue #33 - A8: PoseNet/MoveNet Python Environment Setup
"""
import os
import time
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# COCO Keypoint definitions (17 keypoints)
# Order matches MoveNet's output rows; the list index (0-16) of each name is
# what KEYPOINT_EDGES uses to identify edge endpoints.
KEYPOINT_NAMES = [
'nose',            # 0
'left_eye',        # 1
'right_eye',       # 2
'left_ear',        # 3
'right_ear',       # 4
'left_shoulder',   # 5
'right_shoulder',  # 6
'left_elbow',      # 7
'right_elbow',     # 8
'left_wrist',      # 9
'right_wrist',     # 10
'left_hip',        # 11
'right_hip',       # 12
'left_knee',       # 13
'right_knee',      # 14
'left_ankle',      # 15
'right_ankle'      # 16
]
# Skeleton connections for visualization.
# Keys are (start_index, end_index) pairs into KEYPOINT_NAMES; values are the
# body-part label used to pick a color from EDGE_COLORS.
KEYPOINT_EDGES = {
(0, 1): 'face',
(0, 2): 'face',
(1, 3): 'face',
(2, 4): 'face',
(0, 5): 'torso',
(0, 6): 'torso',
(5, 7): 'left_arm',
(7, 9): 'left_arm',
(6, 8): 'right_arm',
(8, 10): 'right_arm',
(5, 6): 'torso',
(5, 11): 'torso',
(6, 12): 'torso',
(11, 12): 'torso',
(11, 13): 'left_leg',
(13, 15): 'left_leg',
(12, 14): 'right_leg',
(14, 16): 'right_leg',
}
# Colors for different body parts (BGR format for OpenCV)
EDGE_COLORS = {
'face': (255, 255, 0), # Cyan
'torso': (0, 255, 0), # Green
'left_arm': (255, 0, 0), # Blue
'right_arm': (0, 0, 255), # Red
'left_leg': (255, 165, 0), # Orange
'right_leg': (128, 0, 128), # Purple
}
class MoveNetPoseEstimator:
    """
    MoveNet-based human pose estimator.

    Supports two model variants:
    - 'lightning': Faster, lower accuracy (default)
    - 'thunder': Slower, higher accuracy

    Example usage:
        estimator = MoveNetPoseEstimator(model_name='lightning')
        keypoints = estimator.detect_pose(image)
        visualized = estimator.draw_keypoints(image, keypoints)
    """
    # TensorFlow Hub model URLs for each supported variant
    MODEL_URLS = {
        'lightning': 'https://tfhub.dev/google/movenet/singlepose/lightning/4',
        'thunder': 'https://tfhub.dev/google/movenet/singlepose/thunder/4',
    }
    # Square input resolution (pixels) expected by each variant
    INPUT_SIZES = {
        'lightning': 192,
        'thunder': 256,
    }

    def __init__(self, model_name: str = 'lightning'):
        """
        Initialize the MoveNet pose estimator.

        Args:
            model_name: Model variant ('lightning' or 'thunder')

        Raises:
            ValueError: If model_name is not one of the supported variants.
        """
        if model_name not in self.MODEL_URLS:
            raise ValueError(f"Model must be one of: {list(self.MODEL_URLS.keys())}")
        self.model_name = model_name
        self.input_size = self.INPUT_SIZES[model_name]
        print(f"Loading MoveNet {model_name} model...")
        # hub.load fetches (and locally caches) the SavedModel from TF Hub.
        self.model = hub.load(self.MODEL_URLS[model_name])
        self.movenet = self.model.signatures['serving_default']
        print(f"Model loaded successfully. Input size: {self.input_size}x{self.input_size}")

    def preprocess_image(self, image: np.ndarray) -> tf.Tensor:
        """
        Preprocess image for MoveNet inference.

        Args:
            image: Input image (BGR or RGB format, any size)

        Returns:
            int32 tensor of shape (1, input_size, input_size, 3) ready for
            inference.

        NOTE(review): resize_with_pad letterboxes non-square inputs, so the
        normalized keypoints the model returns are relative to the PADDED
        square frame. draw_keypoints scales them by the original image's
        width/height, which is only exact for square inputs — confirm whether
        padding compensation is needed for this use case.
        """
        # Convert BGR to RGB if needed (OpenCV loads as BGR). Any 3-channel
        # input is assumed BGR; other shapes pass through unchanged.
        if len(image.shape) == 3 and image.shape[2] == 3:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image_rgb = image
        # Resize (aspect-ratio preserving, zero-padded) to the model's square
        # input size, adding a leading batch dimension.
        input_image = tf.image.resize_with_pad(
            tf.expand_dims(image_rgb, axis=0),
            self.input_size,
            self.input_size
        )
        # MoveNet's serving signature requires int32 input.
        return tf.cast(input_image, dtype=tf.int32)

    def _run_inference(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess `image` and run the model once.

        Shared by detect_pose() and detect_pose_raw() so the
        preprocess -> infer -> slice pipeline exists in one place.

        Returns:
            Array of shape (17, 3) with [y, x, confidence] per keypoint
            (y/x normalized to [0, 1]).
        """
        input_tensor = self.preprocess_image(image)
        outputs = self.movenet(input_tensor)
        # Output shape is (1, 1, 17, 3): batch, person, keypoint, (y, x, score).
        return outputs['output_0'].numpy()[0, 0, :, :]

    def detect_pose(self, image: np.ndarray) -> Dict:
        """
        Detect pose keypoints in an image.

        Args:
            image: Input image (BGR format from OpenCV)

        Returns:
            Dictionary with keypoint data:
            {
                'keypoints': {
                    'nose': {'x': float, 'y': float, 'confidence': float},
                    ...
                },
                'inference_time_ms': float
            }
        """
        start_time = time.time()
        keypoints_with_scores = self._run_inference(image)
        # Wall-clock time for preprocessing + inference, in milliseconds.
        inference_time = (time.time() - start_time) * 1000
        # Re-key the raw (17, 3) array by COCO keypoint name.
        keypoints_dict = {}
        for i, name in enumerate(KEYPOINT_NAMES):
            y, x, confidence = keypoints_with_scores[i]
            keypoints_dict[name] = {
                'x': float(x),
                'y': float(y),
                'confidence': float(confidence)
            }
        return {
            'keypoints': keypoints_dict,
            'inference_time_ms': inference_time
        }

    def detect_pose_raw(self, image: np.ndarray) -> np.ndarray:
        """
        Detect pose and return raw keypoints array.

        Args:
            image: Input image (BGR format)

        Returns:
            Array of shape (17, 3) with [y, x, confidence] for each keypoint
        """
        return self._run_inference(image)

    def draw_keypoints(
        self,
        image: np.ndarray,
        keypoints: Dict,
        confidence_threshold: float = 0.3,
        circle_radius: int = 5,
        line_thickness: int = 2
    ) -> np.ndarray:
        """
        Draw detected keypoints and skeleton on image.

        Args:
            image: Input image (will be copied, not modified)
            keypoints: Keypoint dictionary from detect_pose()
            confidence_threshold: Minimum confidence to draw keypoint
            circle_radius: Radius of keypoint circles
            line_thickness: Thickness of skeleton lines

        Returns:
            Copy of the image with keypoints and skeleton drawn
        """
        output_image = image.copy()
        height, width = image.shape[:2]
        kps = keypoints['keypoints']
        # Draw skeleton edges first (so keypoints appear on top)
        for (start_idx, end_idx), body_part in KEYPOINT_EDGES.items():
            start_kp = kps[KEYPOINT_NAMES[start_idx]]
            end_kp = kps[KEYPOINT_NAMES[end_idx]]
            # Only draw an edge when BOTH endpoints are confident enough.
            if start_kp['confidence'] > confidence_threshold and end_kp['confidence'] > confidence_threshold:
                # Keypoints are normalized [0, 1]; scale to pixel coordinates.
                start_point = (int(start_kp['x'] * width), int(start_kp['y'] * height))
                end_point = (int(end_kp['x'] * width), int(end_kp['y'] * height))
                color = EDGE_COLORS[body_part]
                cv2.line(output_image, start_point, end_point, color, line_thickness)
        # Draw keypoints: filled yellow circle with a thin black outline.
        for name, kp in kps.items():
            if kp['confidence'] > confidence_threshold:
                x = int(kp['x'] * width)
                y = int(kp['y'] * height)
                cv2.circle(output_image, (x, y), circle_radius, (0, 255, 255), -1)
                cv2.circle(output_image, (x, y), circle_radius, (0, 0, 0), 1)
        return output_image

    def process_video(
        self,
        video_path: str,
        output_path: Optional[str] = None,
        show_preview: bool = False,
        confidence_threshold: float = 0.3
    ) -> List[Dict]:
        """
        Process a video file and extract keypoints from each frame.

        Args:
            video_path: Path to input video file
            output_path: Optional path to save annotated video
            show_preview: Whether to show live preview (press 'q' to quit)
            confidence_threshold: Minimum confidence for visualization

        Returns:
            List of keypoint dictionaries, one per processed frame, each
            augmented with 'frame_id' and 'timestamp' keys.

        Raises:
            ValueError: If the video cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video: {video_path}")
        print(f"Resolution: {width}x{height}, FPS: {fps:.2f}, Frames: {total_frames}")
        # Setup video writer if output path specified
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        all_keypoints = []
        frame_idx = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Detect pose and record per-frame bookkeeping.
                result = self.detect_pose(frame)
                result['frame_id'] = frame_idx
                result['timestamp'] = frame_idx / fps if fps > 0 else 0
                all_keypoints.append(result)
                # Draw and optionally show/save
                annotated_frame = self.draw_keypoints(frame, result, confidence_threshold)
                if writer:
                    writer.write(annotated_frame)
                if show_preview:
                    cv2.imshow('Pose Estimation', annotated_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                frame_idx += 1
                if frame_idx % 30 == 0:
                    print(f"Processed {frame_idx}/{total_frames} frames...")
        finally:
            # Always release capture/writer handles, even if inference raises.
            cap.release()
            if writer:
                writer.release()
            if show_preview:
                cv2.destroyAllWindows()
        print(f"Completed! Processed {frame_idx} frames.")
        # Guard: np.mean([]) on a zero-frame video yields NaN plus a warning.
        if all_keypoints:
            avg_inference = np.mean([r['inference_time_ms'] for r in all_keypoints])
            print(f"Average inference time: {avg_inference:.2f} ms/frame")
        return all_keypoints

    def process_image_file(
        self,
        image_path: str,
        output_path: Optional[str] = None,
        confidence_threshold: float = 0.3
    ) -> Dict:
        """
        Process a single image file.

        Args:
            image_path: Path to input image
            output_path: Optional path to save annotated image
            confidence_threshold: Minimum confidence for visualization

        Returns:
            Keypoint dictionary for the image (see detect_pose()).

        Raises:
            ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")
        result = self.detect_pose(image)
        if output_path:
            annotated = self.draw_keypoints(image, result, confidence_threshold)
            cv2.imwrite(output_path, annotated)
            print(f"Saved annotated image to: {output_path}")
        return result
def main():
    """Demo: Test the pose estimator on a sample image, video, or webcam.

    CLI flags: --model, --image, --video, --webcam, --output.
    """
    import argparse
    parser = argparse.ArgumentParser(description='MoveNet Pose Estimation Demo')
    parser.add_argument('--model', choices=['lightning', 'thunder'], default='lightning',
                        help='Model variant (default: lightning)')
    parser.add_argument('--image', type=str, help='Path to input image')
    parser.add_argument('--video', type=str, help='Path to input video')
    parser.add_argument('--webcam', action='store_true', help='Use webcam')
    parser.add_argument('--output', type=str, help='Output path for annotated image/video')
    args = parser.parse_args()
    # Initialize estimator (downloads/caches the model on first use)
    estimator = MoveNetPoseEstimator(model_name=args.model)
    if args.image:
        # Process image
        print(f"\nProcessing image: {args.image}")
        result = estimator.process_image_file(
            args.image,
            output_path=args.output
        )
        print(f"Inference time: {result['inference_time_ms']:.2f} ms")
        print("\nDetected keypoints:")
        for name, kp in result['keypoints'].items():
            if kp['confidence'] > 0.3:
                print(f"  {name}: ({kp['x']:.3f}, {kp['y']:.3f}) conf={kp['confidence']:.3f}")
    elif args.video:
        # Process video
        print(f"\nProcessing video: {args.video}")
        keypoints = estimator.process_video(
            args.video,
            output_path=args.output,
            show_preview=True
        )
        print(f"\nExtracted keypoints from {len(keypoints)} frames")
    elif args.webcam:
        # Webcam demo
        print("\nStarting webcam demo (press 'q' to quit)...")
        cap = cv2.VideoCapture(0)
        # Fix: the original exited silently when no webcam was available
        # (cap.read() immediately returns False) — fail loudly instead.
        if not cap.isOpened():
            print("Error: could not open webcam (device 0)")
            return
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                result = estimator.detect_pose(frame)
                annotated = estimator.draw_keypoints(frame, result)
                # Add FPS display
                fps_text = f"Inference: {result['inference_time_ms']:.1f} ms"
                cv2.putText(annotated, fps_text, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                cv2.imshow('MoveNet Pose Estimation', annotated)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always release the camera, even on an unexpected exception.
            cap.release()
            cv2.destroyAllWindows()
    else:
        print("Please specify --image, --video, or --webcam")
        print("Example: python pose_estimator.py --image test.jpg --output result.jpg")
# Run the CLI demo only when executed as a script (not on import).
if __name__ == '__main__':
    main()