# golf-ball-tracker / golf_ball_tracker.py
# Uploaded with huggingface_hub by notjulietxd (commit f7db320, verified)
"""
Golf Ball Tracker for Mobile Phone Camera
===========================================
Real-time golf ball detection + tracking with:
- YOLO-based detection (exported to ONNX/TFLite for mobile)
- Kalman filter for smooth trajectory tracking
- Ballistic trajectory prediction for when ball is invisible
Usage:
# Load model and track from video file
tracker = GolfBallTracker("path/to/model.onnx")
tracker.track_video("input.mp4", "output.mp4")
# Or from camera (mobile)
tracker.track_camera(camera_id=0)
Mobile Deployment:
- Export YOLO to TFLite: model.export(format="tflite", int8=True)
- For iOS: model.export(format="coreml")
- Use ONNX Runtime for cross-platform inference
"""
import numpy as np
import cv2
from dataclasses import dataclass
from typing import List, Tuple, Optional
from collections import deque
import time
@dataclass
class Detection:
    """A single detected golf ball in one video frame.

    Coordinates are in pixel space of the original frame; (x, y) is the
    box *center*, not the top-left corner.
    """
    x: float  # center x (pixels)
    y: float  # center y (pixels)
    w: float  # width (pixels)
    h: float  # height (pixels)
    confidence: float  # detector score; expected in [0, 1]
    frame_id: int = 0  # index of the frame this detection came from
class KalmanTracker:
    """
    Kalman filter for 2D ball tracking in pixel space.

    State:       [x, y, vx, vy, ax, ay]  (constant-acceleration model)
    Observation: [x, y]

    Intended cycle: call predict() once per frame (time update), then
    update() whenever a measurement is available. After more than
    ``max_missed`` consecutive missed frames the track is considered lost
    and the next measurement starts a fresh track.
    """

    def __init__(self, dt: float = 1.0/30.0):
        """
        Args:
            dt: Time step between frames in seconds (1 / fps).
        """
        self.dt = dt
        n = 6  # state dimension
        m = 2  # measurement dimension
        # State transition matrix (constant acceleration model):
        #   p' = p + v*dt + 0.5*a*dt^2,  v' = v + a*dt,  a' = a
        self.F = np.array([
            [1, 0, dt, 0, 0.5*dt**2, 0],
            [0, 1, 0, dt, 0, 0.5*dt**2],
            [0, 0, 1, 0, dt, 0],
            [0, 0, 0, 1, 0, dt],
            [0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ])
        # Measurement matrix (observe x, y only)
        self.H = np.array([
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0]
        ])
        # Process noise
        q = 0.5  # process noise scaling
        self.Q = q * np.eye(n)
        # Measurement noise (pixels); scaled by detection confidence in update()
        r = 2.0
        self.R = r * np.eye(m)
        # Initial state and covariance (large P = uncertain prior)
        self.x = np.zeros((n, 1))
        self.P = np.eye(n) * 100
        self.initialized = False
        self.missed_frames = 0
        self.max_missed = 10  # max frames without detection before reset

    def predict(self) -> Tuple[float, float]:
        """Propagate the state one time step; return predicted (x, y).

        If the track has gone more than ``max_missed`` frames without a
        measurement, the filter is flagged uninitialized so the next
        update() re-seeds a fresh track (previously ``max_missed`` was
        stored but never acted upon).
        """
        if self.missed_frames > self.max_missed:
            self.initialized = False
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return float(self.x[0, 0]), float(self.x[1, 0])

    def update(self, z_x: float, z_y: float, confidence: float = 1.0):
        """Fuse a position measurement into the state estimate.

        Args:
            z_x, z_y: Measured ball center (pixels).
            confidence: Detector confidence in (0, 1]. Lower confidence
                inflates the measurement noise so weak detections pull the
                estimate less; 1.0 reproduces the unscaled noise (the
                original accepted this argument but ignored it).
        """
        if not self.initialized:
            # (Re)start the track: zero out velocity/acceleration and reset
            # the covariance so dynamics from a stale track don't leak in.
            self.x = np.zeros((6, 1))
            self.x[0, 0] = z_x
            self.x[1, 0] = z_y
            self.P = np.eye(6) * 100
            self.initialized = True
            self.missed_frames = 0
            return
        z = np.array([[z_x], [z_y]])
        # Confidence-scaled measurement noise (guard against conf <= 0).
        R = self.R / max(confidence, 1e-3)
        y = z - self.H @ self.x                    # innovation
        S = self.H @ self.P @ self.H.T + R         # innovation covariance
        K = self.P @ self.H.T @ np.linalg.inv(S)   # Kalman gain
        self.x = self.x + K @ y
        self.P = (np.eye(6) - K @ self.H) @ self.P
        self.missed_frames = 0

    def predict_trajectory(self, n_steps: int = 30) -> List[Tuple[float, float]]:
        """Roll the motion model forward and return future (x, y) points.

        A constant per-frame downward pull is added to the vertical
        *velocity* as a heuristic pixel-space stand-in for gravity (no
        real-world scale is available). The original accumulated the pull
        into the acceleration term, which compounds into cubic (jerk)
        motion rather than ballistic constant-acceleration motion.

        Returns [] when the filter has not been initialized.
        """
        if not self.initialized:
            return []
        trajectory = []
        x_pred = self.x.copy()
        gravity_px = 0.5  # heuristic downward pull per frame (pixels/frame)
        for _ in range(n_steps):
            x_pred = self.F @ x_pred
            x_pred[3, 0] += gravity_px  # vy += g: constant acceleration
            trajectory.append((float(x_pred[0, 0]), float(x_pred[1, 0])))
        return trajectory

    def get_position(self) -> Tuple[float, float]:
        """Current smoothed (x, y) position estimate in pixels."""
        return float(self.x[0, 0]), float(self.x[1, 0])

    def get_velocity(self) -> Tuple[float, float]:
        """Current (vx, vy) velocity estimate in pixels per dt."""
        return float(self.x[2, 0]), float(self.x[3, 0])
class GolfBallTracker:
    """
    Golf ball detection + tracking pipeline.

    Runs a detector on each frame, keeps the highest-confidence detection,
    and (optionally) smooths/predicts the ball position with a Kalman
    filter. Supports multiple backends:
    - Ultralytics YOLO (Python, .pt weights)
    - ONNX Runtime (cross-platform, .onnx)
    - TFLite (mobile, .tflite)
    """

    def __init__(self, model_path: str, conf_threshold: float = 0.25,
                 iou_threshold: float = 0.45, use_kalman: bool = True,
                 fps: float = 30.0):
        """
        Args:
            model_path: Path to .pt / .onnx / .tflite model weights.
            conf_threshold: Minimum detection confidence to keep.
            iou_threshold: IoU threshold used for NMS.
            use_kalman: Enable Kalman smoothing + trajectory prediction.
            fps: Expected frame rate; sets the Kalman filter time step.
        """
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.use_kalman = use_kalman
        self.fps = fps
        self.dt = 1.0 / fps
        self.kalman = KalmanTracker(dt=self.dt) if use_kalman else None
        self.trajectory_history = deque(maxlen=100)  # last 100 (x, y) points
        self.predicted_trajectory = []
        self.frame_count = 0
        # Load model
        self._load_model(model_path)

    def _load_model(self, model_path: str):
        """Load the detection model, auto-detecting the format by extension.

        Raises:
            RuntimeError: If the .pt path is used but ultralytics is missing.
            ValueError: For an unsupported file extension.
        """
        ext = model_path.lower().split('.')[-1]
        if ext == 'pt':
            # PyTorch / Ultralytics
            try:
                from ultralytics import YOLO
                self.model = YOLO(model_path)
                self.backend = 'ultralytics'
                print(f"Loaded Ultralytics model: {model_path}")
            except ImportError as e:
                raise RuntimeError("ultralytics not installed. pip install ultralytics") from e
        elif ext == 'onnx':
            import onnxruntime as ort
            self.session = ort.InferenceSession(model_path)
            self.input_name = self.session.get_inputs()[0].name
            self.backend = 'onnx'
            print(f"Loaded ONNX model: {model_path}")
        elif ext in ('tflite', 'lite'):
            import tensorflow as tf
            self.interpreter = tf.lite.Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            self.backend = 'tflite'
            print(f"Loaded TFLite model: {model_path}")
        else:
            raise ValueError(f"Unsupported model format: {ext}")

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """Run detection on a single BGR frame; return NMS-filtered detections
        in original-frame pixel coordinates."""
        h, w = frame.shape[:2]
        detections = []
        if self.backend == 'ultralytics':
            # Ultralytics handles preprocessing (incl. BGR->RGB) internally.
            results = self.model(frame, conf=self.conf_threshold, iou=self.iou_threshold, verbose=False)
            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0])
                    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                    bw, bh = x2 - x1, y2 - y1
                    detections.append(Detection(cx, cy, bw, bh, conf, self.frame_count))
        elif self.backend == 'onnx':
            # Preprocess: Ultralytics ONNX exports expect RGB, CHW, float [0,1].
            # cv2 frames are BGR, so convert first (the original fed BGR).
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (640, 640))
            img = img.astype(np.float32) / 255.0
            img = np.transpose(img, (2, 0, 1))
            img = np.ascontiguousarray(np.expand_dims(img, axis=0))
            # Run inference
            outputs = self.session.run(None, {self.input_name: img})
            # YOLOv8-style output: (4 + num_classes, num_anchors).
            # Rows 0-3 are cx, cy, w, h; the rest are per-class scores.
            predictions = outputs[0][0]
            for pred in predictions.T:
                # Best class score; identical to pred[4] for a 1-class model
                # but also correct for multi-class exports.
                conf = float(pred[4:].max())
                if conf < self.conf_threshold:
                    continue
                cx, cy, bw, bh = pred[:4]
                # Scale from 640x640 model space back to the original frame
                detections.append(Detection(cx * w / 640, cy * h / 640,
                                            bw * w / 640, bh * h / 640,
                                            conf, self.frame_count))
        elif self.backend == 'tflite':
            # Preprocess (NHWC, float [0,1], RGB)
            input_shape = self.input_details[0]['shape']
            _, inp_h, inp_w, _ = input_shape
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (inp_w, inp_h))
            img = img.astype(np.float32) / 255.0
            img = np.expand_dims(img, axis=0)
            # NOTE(review): an int8-quantized model needs quantized input
            # (see input_details[0]['dtype'] / 'quantization'); this path
            # assumes a float model — confirm against the exported model.
            self.interpreter.set_tensor(self.input_details[0]['index'], img)
            self.interpreter.invoke()
            outputs = self.interpreter.get_tensor(self.output_details[0]['index'])
            # Parse (format varies by model) — assumes rows of
            # [cx, cy, w, h, conf, class] in model-input pixel space.
            for det in outputs[0]:
                if det[4] < self.conf_threshold:
                    continue
                detections.append(Detection(det[0] * w / inp_w, det[1] * h / inp_h,
                                            det[2] * w / inp_w, det[3] * h / inp_h,
                                            det[4], self.frame_count))
        # Non-maximum suppression (simple greedy)
        return self._nms(detections)

    def _nms(self, detections: List[Detection]) -> List[Detection]:
        """Greedy NMS: keep the highest-confidence box, drop overlapping ones."""
        if not detections:
            return []
        detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
        keep = []
        while detections:
            best = detections.pop(0)
            keep.append(best)
            detections = [d for d in detections
                          if self._iou(best, d) < self.iou_threshold]
        return keep

    def _iou(self, a: Detection, b: Detection) -> float:
        """Intersection-over-union of two center-format boxes; 0 if disjoint."""
        ax1, ay1 = a.x - a.w/2, a.y - a.h/2
        ax2, ay2 = a.x + a.w/2, a.y + a.h/2
        bx1, by1 = b.x - b.w/2, b.y - b.h/2
        bx2, by2 = b.x + b.w/2, b.y + b.h/2
        inter_x1 = max(ax1, bx1)
        inter_y1 = max(ay1, by1)
        inter_x2 = min(ax2, bx2)
        inter_y2 = min(ay2, by2)
        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        a_area = a.w * a.h
        b_area = b.w * b.h
        union = a_area + b_area - inter_area
        return inter_area / union if union > 0 else 0

    def update(self, frame: np.ndarray) -> Tuple[Optional[Detection], np.ndarray]:
        """
        Process one frame: detect ball, update tracker, predict trajectory.

        Returns:
            (best_detection_or_none, annotated_frame). When the detector
            misses but the Kalman filter is active, a low-confidence (0.3)
            pseudo-detection at the predicted position is returned instead
            of None.
        """
        self.frame_count += 1
        h, w = frame.shape[:2]
        # Detection
        detections = self.detect(frame)
        # Select best detection (highest confidence)
        best = max(detections, key=lambda d: d.confidence) if detections else None
        if self.kalman:
            # Proper predict -> update cycle: propagate the motion model
            # every frame. (The original only called predict() on missed
            # frames, so the filter never applied F while the ball was
            # visible and could not learn a velocity estimate.)
            if self.kalman.initialized:
                self.kalman.predict()
            if best:
                self.kalman.update(best.x, best.y, best.confidence)
            else:
                self.kalman.missed_frames += 1
                # Coast on the prediction and surface it as a pseudo-detection
                # so downstream code still has a position to work with.
                px, py = self.kalman.get_position()
                best = Detection(px, py, 20, 20, 0.3, self.frame_count)
            # Record the smoothed position and refresh the forecast
            kx, ky = self.kalman.get_position()
            self.trajectory_history.append((kx, ky))
            self.predicted_trajectory = self.kalman.predict_trajectory(n_steps=30)
        elif best:
            self.trajectory_history.append((best.x, best.y))
        return best, self._annotate(frame, best, w, h)

    def _annotate(self, frame: np.ndarray, best: Optional[Detection],
                  w: int, h: int) -> np.ndarray:
        """Draw trajectory history, predicted path, box and HUD on a copy."""
        annotated = frame.copy()
        # Trajectory history: solid green polyline through tracked points
        points = list(self.trajectory_history)
        for prev_pt, cur_pt in zip(points, points[1:]):
            p1 = (int(prev_pt[0]), int(prev_pt[1]))
            p2 = (int(cur_pt[0]), int(cur_pt[1]))
            cv2.line(annotated, p1, p2, (0, 255, 0), 2)
        # Predicted trajectory: dots fading toward red with distance in time
        if self.predicted_trajectory:
            n = len(self.predicted_trajectory)
            for i, (px, py) in enumerate(self.predicted_trajectory):
                if 0 <= px < w and 0 <= py < h:
                    fade = int(255 * (1 - i / n))
                    cv2.circle(annotated, (int(px), int(py)), 2, (0, fade, 255), -1)
        # Current detection box. Note: the coasting pseudo-detection has
        # confidence exactly 0.3, so the strict > keeps it from being drawn.
        if best and best.confidence > 0.3:
            x1 = int(best.x - best.w/2)
            y1 = int(best.y - best.h/2)
            x2 = int(best.x + best.w/2)
            y2 = int(best.y + best.h/2)
            color = (0, 255, 0) if best.confidence > 0.5 else (0, 165, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            cv2.putText(annotated, f"ball {best.confidence:.2f}",
                        (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        # Frame counter HUD
        cv2.putText(annotated, f"Frame: {self.frame_count}", (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        return annotated

    def track_video(self, input_path: str, output_path: str):
        """Process a video file and write an annotated copy.

        Raises:
            IOError: If the input video cannot be opened (the original
                silently produced an empty output file instead).
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                _, annotated = self.update(frame)
                out.write(annotated)
        finally:
            # Release even if update() raises, so the output file is finalized
            cap.release()
            out.release()
        print(f"Saved output to {output_path}")

    def track_camera(self, camera_id: int = 0):
        """Track from a live camera feed; press 'q' to quit.

        Raises:
            IOError: If the camera cannot be opened.
        """
        cap = cv2.VideoCapture(camera_id)
        if not cap.isOpened():
            raise IOError(f"Cannot open camera: {camera_id}")
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                _, annotated = self.update(frame)
                cv2.imshow("Golf Ball Tracker", annotated)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    def get_trajectory(self) -> List[Tuple[float, float]]:
        """Return the tracked (x, y) trajectory points so far."""
        return list(self.trajectory_history)

    def get_predicted_trajectory(self) -> List[Tuple[float, float]]:
        """Return the most recently predicted future trajectory."""
        return self.predicted_trajectory
def export_model_for_mobile(weights_path: str = "/app/golf_ball_runs/golf_ball_detector/weights/best.pt"):
    """
    Export a trained YOLO model for mobile deployment (ONNX, TFLite, CoreML).

    Args:
        weights_path: Path to the trained .pt weights. The previously
            hard-coded path is kept as the default for compatibility.
    """
    from ultralytics import YOLO
    model = YOLO(weights_path)
    # ONNX - works on both Android and iOS (via ONNX Runtime)
    print("Exporting to ONNX...")
    model.export(format="onnx", imgsz=640, simplify=True)
    # TFLite - best for Android; INT8 quantization shrinks and speeds it up
    print("Exporting to TFLite (INT8 for mobile)...")
    model.export(format="tflite", imgsz=640, int8=True)
    # CoreML - best for iOS
    print("Exporting to CoreML...")
    model.export(format="coreml", imgsz=640)
    print("Export complete!")
if __name__ == "__main__":
    # Minimal CLI: "detect <model> <video>" or "export".
    import sys
    if len(sys.argv) < 2:
        print("Usage:")
        print(" python golf_ball_tracker.py detect <model.pt> <video.mp4>")
        print(" python golf_ball_tracker.py export")
        sys.exit(1)
    cmd = sys.argv[1]
    if cmd == "detect":
        if len(sys.argv) < 4:
            print("Usage: python golf_ball_tracker.py detect <model> <video>")
            sys.exit(1)
        tracker = GolfBallTracker(sys.argv[2])
        tracker.track_video(sys.argv[3], "output_tracked.mp4")
    elif cmd == "export":
        export_model_for_mobile()
    else:
        print(f"Unknown command: {cmd}")
        # Exit nonzero so calling scripts can detect the failure
        # (the original fell through and exited 0 on a bad command).
        sys.exit(1)