| """ |
| Golf Ball Tracker for Mobile Phone Camera |
| =========================================== |
| Real-time golf ball detection + tracking with: |
| - YOLO-based detection (exported to ONNX/TFLite for mobile) |
| - Kalman filter for smooth trajectory tracking |
| - Ballistic trajectory prediction for when ball is invisible |
| |
| Usage: |
| # Load model and track from video file |
| tracker = GolfBallTracker("path/to/model.onnx") |
| tracker.track_video("input.mp4", "output.mp4") |
| |
| # Or from camera (mobile) |
| tracker.track_camera(camera_id=0) |
| |
| Mobile Deployment: |
| - Export YOLO to TFLite: model.export(format="tflite", int8=True) |
| - For iOS: model.export(format="coreml") |
| - Use ONNX Runtime for cross-platform inference |
| """ |
|
|
| import numpy as np |
| import cv2 |
| from dataclasses import dataclass |
| from typing import List, Tuple, Optional |
| from collections import deque |
| import time |
|
|
|
|
@dataclass
class Detection:
    """A detected golf ball.

    Coordinates are in pixels of the source frame: (x, y) is the box
    center and (w, h) the full box width/height (see how ``detect`` and
    ``_iou`` construct/consume these fields).
    """
    x: float  # box center x, pixels
    y: float  # box center y, pixels
    w: float  # box width, pixels
    h: float  # box height, pixels
    confidence: float  # detector confidence score
    frame_id: int = 0  # index of the frame this detection came from
|
|
|
|
class KalmanTracker:
    """
    Constant-acceleration Kalman filter for 2D ball tracking.

    State:       [x, y, vx, vy, ax, ay]^T  (pixels, pixels/frame-time)
    Observation: [x, y]^T                  (pixel position only)
    """

    def __init__(self, dt: float = 1.0 / 30.0):
        """
        Args:
            dt: Time step between frames in seconds (default 30 fps).
        """
        self.dt = dt
        self.n = 6  # state dimension
        m = 2       # measurement dimension

        # State transition: constant-acceleration kinematics.
        self.F = np.array([
            [1, 0, dt, 0, 0.5 * dt**2, 0],
            [0, 1, 0, dt, 0, 0.5 * dt**2],
            [0, 0, 1, 0, dt, 0],
            [0, 0, 0, 1, 0, dt],
            [0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ])

        # Measurement matrix: we observe position only.
        self.H = np.array([
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0],
        ])

        # Process noise (model uncertainty).
        q = 0.5
        self.Q = q * np.eye(self.n)

        # Measurement noise (detector jitter, in pixels).
        r = 2.0
        self.R = r * np.eye(m)

        # State estimate and covariance; large P = uninformative prior.
        self.x = np.zeros((self.n, 1))
        self.P = np.eye(self.n) * 100

        self.initialized = False
        self.missed_frames = 0   # consecutive frames without a measurement
        self.max_missed = 10     # coasting budget before the track is dropped

    def predict(self) -> Tuple[float, float]:
        """Advance the state one time step; return predicted (x, y)."""
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return float(self.x[0, 0]), float(self.x[1, 0])

    def update(self, z_x: float, z_y: float, confidence: float = 1.0):
        """Fuse a position measurement into the state estimate.

        The first measurement simply initializes the state; subsequent
        ones run the standard Kalman correction.

        Args:
            z_x, z_y: Measured ball center in pixels.
            confidence: Detector confidence; currently unused, reserved
                for measurement-noise scaling.
        """
        if not self.initialized:
            self.x[0, 0] = z_x
            self.x[1, 0] = z_y
            self.initialized = True
            self.missed_frames = 0
            return

        z = np.array([[z_x], [z_y]])

        # Innovation (measurement residual).
        y = z - self.H @ self.x

        # Innovation covariance.
        S = self.H @ self.P @ self.H.T + self.R

        # Kalman gain.
        K = self.P @ self.H.T @ np.linalg.inv(S)

        # Correct state and covariance.
        self.x = self.x + K @ y
        self.P = (np.eye(self.n) - K @ self.H) @ self.P

        self.missed_frames = 0

    def predict_trajectory(self, n_steps: int = 30) -> List[Tuple[float, float]]:
        """Roll the current state forward to predict future positions.

        Returns up to ``n_steps`` (x, y) points; empty if no measurement
        has been received yet.
        """
        if not self.initialized:
            return []

        trajectory = []
        x_pred = self.x.copy()  # roll forward a copy; leave the live state alone

        for _ in range(n_steps):
            # F is read-only here, so the live transition matrix is safe to reuse.
            x_pred = self.F @ x_pred

            # Crude gravity ramp on the vertical acceleration component
            # (0.5 px per step; image y grows downward). NOTE(review):
            # units are pixel-space, not metric — confirm against camera calibration.
            x_pred[5, 0] += 0.5
            trajectory.append((float(x_pred[0, 0]), float(x_pred[1, 0])))

        return trajectory

    def get_position(self) -> Tuple[float, float]:
        """Current estimated (x, y) position in pixels."""
        return float(self.x[0, 0]), float(self.x[1, 0])

    def get_velocity(self) -> Tuple[float, float]:
        """Current estimated (vx, vy)."""
        return float(self.x[2, 0]), float(self.x[3, 0])
|
|
|
|
class GolfBallTracker:
    """
    Golf ball detection + tracking pipeline.

    Runs a detector on each frame, keeps the best hit, optionally smooths
    it with a Kalman filter, and renders tracked/predicted trajectories
    onto an annotated copy of the frame.

    Supports multiple backends:
    - Ultralytics YOLO (Python)
    - ONNX Runtime (cross-platform)
    - TFLite (mobile)
    """

    def __init__(self, model_path: str, conf_threshold: float = 0.25,
                 iou_threshold: float = 0.45, use_kalman: bool = True,
                 fps: float = 30.0):
        """
        Args:
            model_path: Path to a .pt, .onnx, or .tflite model.
            conf_threshold: Minimum detection confidence to keep.
            iou_threshold: IoU threshold used for NMS.
            use_kalman: Enable Kalman smoothing + trajectory prediction.
            fps: Expected frame rate (sets the filter time step).
        """
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.use_kalman = use_kalman
        self.fps = fps
        self.dt = 1.0 / fps

        self.kalman = KalmanTracker(dt=self.dt) if use_kalman else None
        self.trajectory_history = deque(maxlen=100)  # recent tracked points
        self.predicted_trajectory = []               # future ballistic points
        self.frame_count = 0

        self._load_model(model_path)

    def _load_model(self, model_path: str):
        """Load the detection model; backend is chosen by file extension.

        Raises:
            RuntimeError: .pt model requested but ultralytics missing.
            ValueError: Unrecognized model extension.
        """
        ext = model_path.lower().split('.')[-1]

        if ext == 'pt':
            try:
                from ultralytics import YOLO
                self.model = YOLO(model_path)
                self.backend = 'ultralytics'
                print(f"Loaded Ultralytics model: {model_path}")
            except ImportError as err:
                raise RuntimeError("ultralytics not installed. pip install ultralytics") from err

        elif ext == 'onnx':
            import onnxruntime as ort
            self.session = ort.InferenceSession(model_path)
            self.input_name = self.session.get_inputs()[0].name
            self.backend = 'onnx'
            print(f"Loaded ONNX model: {model_path}")

        elif ext in ('tflite', 'lite'):
            import tensorflow as tf
            self.interpreter = tf.lite.Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            self.backend = 'tflite'
            print(f"Loaded TFLite model: {model_path}")

        else:
            raise ValueError(f"Unsupported model format: {ext}")

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """Run detection on a single BGR frame; returns NMS-filtered hits."""
        h, w = frame.shape[:2]
        detections = []

        if self.backend == 'ultralytics':
            results = self.model(frame, conf=self.conf_threshold, iou=self.iou_threshold, verbose=False)
            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0])
                    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                    bw, bh = x2 - x1, y2 - y1
                    detections.append(Detection(cx, cy, bw, bh, conf, self.frame_count))

        elif self.backend == 'onnx':
            # Letterbox-free resize to the fixed 640x640 export size.
            img = cv2.resize(frame, (640, 640))
            img = img.astype(np.float32) / 255.0
            img = np.transpose(img, (2, 0, 1))   # HWC -> CHW
            img = np.expand_dims(img, axis=0)    # add batch dim

            outputs = self.session.run(None, {self.input_name: img})

            # Assumes output layout (attrs, anchors) with row 4 = confidence
            # — matches a single-class YOLO export; verify for other models.
            predictions = outputs[0][0]

            for pred in predictions.T:
                conf = float(pred[4])
                if conf < self.conf_threshold:
                    continue

                cx, cy, bw, bh = pred[:4]
                # Rescale from model input space back to frame pixels.
                cx = float(cx) * w / 640
                cy = float(cy) * h / 640
                bw = float(bw) * w / 640
                bh = float(bh) * h / 640
                detections.append(Detection(cx, cy, bw, bh, conf, self.frame_count))

        elif self.backend == 'tflite':
            # TFLite models keep NHWC input; read the expected size from the model.
            input_shape = self.input_details[0]['shape']
            _, inp_h, inp_w, _ = input_shape
            img = cv2.resize(frame, (inp_w, inp_h))
            img = img.astype(np.float32) / 255.0
            img = np.expand_dims(img, axis=0)

            self.interpreter.set_tensor(self.input_details[0]['index'], img)
            self.interpreter.invoke()
            outputs = self.interpreter.get_tensor(self.output_details[0]['index'])

            # Assumes rows of [cx, cy, w, h, conf] in input-space pixels
            # — TODO confirm against the actual export's output signature.
            for det in outputs[0]:
                if det[4] < self.conf_threshold:
                    continue
                cx = float(det[0]) * w / inp_w
                cy = float(det[1]) * h / inp_h
                bw = float(det[2]) * w / inp_w
                bh = float(det[3]) * h / inp_h
                detections.append(Detection(cx, cy, bw, bh, float(det[4]), self.frame_count))

        detections = self._nms(detections)
        return detections

    def _nms(self, detections: List[Detection]) -> List[Detection]:
        """Greedy non-maximum suppression by descending confidence."""
        if not detections:
            return []

        detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
        keep = []

        while detections:
            best = detections.pop(0)
            keep.append(best)
            # Drop everything that overlaps the kept box too much.
            detections = [d for d in detections
                          if self._iou(best, d) < self.iou_threshold]

        return keep

    def _iou(self, a: Detection, b: Detection) -> float:
        """Intersection-over-union of two center-format boxes."""
        ax1, ay1 = a.x - a.w / 2, a.y - a.h / 2
        ax2, ay2 = a.x + a.w / 2, a.y + a.h / 2
        bx1, by1 = b.x - b.w / 2, b.y - b.h / 2
        bx2, by2 = b.x + b.w / 2, b.y + b.h / 2

        inter_x1 = max(ax1, bx1)
        inter_y1 = max(ay1, by1)
        inter_x2 = min(ax2, bx2)
        inter_y2 = min(ay2, by2)

        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        a_area = a.w * a.h
        b_area = b.w * b.h
        union = a_area + b_area - inter_area

        return inter_area / union if union > 0 else 0

    def update(self, frame: np.ndarray) -> Tuple[Optional[Detection], np.ndarray]:
        """
        Process one frame: detect ball, update tracker, predict trajectory.

        Returns:
            (best_detection_or_none, annotated_frame). The detection may be
            a low-confidence (0.3) Kalman "coasting" box while the ball is
            briefly invisible; after ``max_missed`` consecutive misses the
            track is dropped and None is returned until re-acquisition.
        """
        self.frame_count += 1
        h, w = frame.shape[:2]

        detections = self.detect(frame)

        # Track a single ball: keep the most confident detection.
        best = max(detections, key=lambda d: d.confidence) if detections else None

        if self.kalman:
            # Standard Kalman cycle: advance the motion model every frame
            # (previously predict only ran on missed frames, so velocity
            # was never propagated while the ball stayed visible).
            if self.kalman.initialized:
                self.kalman.predict()

            if best:
                self.kalman.update(best.x, best.y, best.confidence)
            else:
                self.kalman.missed_frames += 1
                if self.kalman.initialized and self.kalman.missed_frames <= self.kalman.max_missed:
                    # Coast on the predicted position for a limited number
                    # of frames; report it as a low-confidence detection.
                    px, py = self.kalman.get_position()
                    best = Detection(px, py, 20, 20, 0.3, self.frame_count)
                elif self.kalman.missed_frames > self.kalman.max_missed:
                    # Coasting budget exhausted: drop the track so the next
                    # real detection re-initializes the filter cleanly.
                    self.kalman.initialized = False
                    self.predicted_trajectory = []

            if self.kalman.initialized:
                kx, ky = self.kalman.get_position()
                self.trajectory_history.append((kx, ky))
                self.predicted_trajectory = self.kalman.predict_trajectory(n_steps=30)
        else:
            if best:
                self.trajectory_history.append((best.x, best.y))

        annotated = frame.copy()

        # Draw the tracked trajectory as a green polyline.
        if len(self.trajectory_history) > 1:
            points = list(self.trajectory_history)
            for i in range(1, len(points)):
                p1 = (int(points[i - 1][0]), int(points[i - 1][1]))
                p2 = (int(points[i][0]), int(points[i][1]))
                cv2.line(annotated, p1, p2, (0, 255, 0), 2)

        # Draw the predicted future trajectory as fading orange dots.
        if self.predicted_trajectory:
            for i, (px, py) in enumerate(self.predicted_trajectory):
                if 0 <= px < w and 0 <= py < h:
                    alpha = int(255 * (1 - i / len(self.predicted_trajectory)))
                    color = (0, int(alpha), 255)
                    cv2.circle(annotated, (int(px), int(py)), 2, color, -1)

        # Draw the detection box (green if confident, orange if borderline).
        if best and best.confidence > 0.3:
            x1 = int(best.x - best.w / 2)
            y1 = int(best.y - best.h / 2)
            x2 = int(best.x + best.w / 2)
            y2 = int(best.y + best.h / 2)
            color = (0, 255, 0) if best.confidence > 0.5 else (0, 165, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            cv2.putText(annotated, f"ball {best.confidence:.2f}",
                        (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        cv2.putText(annotated, f"Frame: {self.frame_count}", (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return best, annotated

    def track_video(self, input_path: str, output_path: str):
        """Process a video file and write an annotated copy.

        Raises:
            IOError: If the input video cannot be opened.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {input_path}")

        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                det, annotated = self.update(frame)
                out.write(annotated)
        finally:
            # Always release handles, even if update() raises mid-stream.
            cap.release()
            out.release()
        print(f"Saved output to {output_path}")

    def track_camera(self, camera_id: int = 0):
        """Track from a live camera until 'q' is pressed.

        Raises:
            IOError: If the camera cannot be opened.
        """
        cap = cv2.VideoCapture(camera_id)
        if not cap.isOpened():
            raise IOError(f"Cannot open camera: {camera_id}")

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                det, annotated = self.update(frame)
                cv2.imshow("Golf Ball Tracker", annotated)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    def get_trajectory(self) -> List[Tuple[float, float]]:
        """Return tracked trajectory points (most recent 100)."""
        return list(self.trajectory_history)

    def get_predicted_trajectory(self) -> List[Tuple[float, float]]:
        """Return the predicted future trajectory from the last update."""
        return self.predicted_trajectory
|
|
|
|
def export_model_for_mobile(weights_path: str = "/app/golf_ball_runs/golf_ball_detector/weights/best.pt"):
    """
    Export a trained YOLO model for mobile deployment.

    Produces ONNX (cross-platform), INT8 TFLite (Android), and CoreML
    (iOS) artifacts next to the weights file.

    Args:
        weights_path: Path to the trained .pt weights. Defaults to the
            original hard-coded training-run location for backward
            compatibility.
    """
    from ultralytics import YOLO

    model = YOLO(weights_path)

    print("Exporting to ONNX...")
    model.export(format="onnx", imgsz=640, simplify=True)

    print("Exporting to TFLite (INT8 for mobile)...")
    model.export(format="tflite", imgsz=640, int8=True)

    print("Exporting to CoreML...")
    model.export(format="coreml", imgsz=640)

    print("Export complete!")
|
|
|
|
if __name__ == "__main__":
    import sys

    argv = sys.argv

    # No subcommand given: show usage and bail out.
    if len(argv) < 2:
        print("Usage:")
        print(" python golf_ball_tracker.py detect <model.pt> <video.mp4>")
        print(" python golf_ball_tracker.py export")
        sys.exit(1)

    command = argv[1]

    if command == "export":
        export_model_for_mobile()
    elif command == "detect":
        if len(argv) < 4:
            print("Usage: python golf_ball_tracker.py detect <model> <video>")
            sys.exit(1)
        GolfBallTracker(argv[2]).track_video(argv[3], "output_tracked.mp4")
    else:
        print(f"Unknown command: {command}")
|