notjulietxd committed
Commit f7db320 · verified · 1 Parent(s): 4a61760

Upload golf_ball_tracker.py with huggingface_hub

Files changed (1):
  golf_ball_tracker.py +475 -0
golf_ball_tracker.py ADDED
@@ -0,0 +1,475 @@
"""
Golf Ball Tracker for Mobile Phone Camera
=========================================
Real-time golf ball detection and tracking with:
- YOLO-based detection (exportable to ONNX/TFLite for mobile)
- Kalman filter for smooth trajectory tracking
- Ballistic trajectory prediction while the ball is occluded or undetected

Usage:
    # Load model and track from a video file
    tracker = GolfBallTracker("path/to/model.onnx")
    tracker.track_video("input.mp4", "output.mp4")

    # Or from a camera (mobile)
    tracker.track_camera(camera_id=0)

Mobile deployment:
- Export YOLO to TFLite: model.export(format="tflite", int8=True)
- For iOS: model.export(format="coreml")
- Use ONNX Runtime for cross-platform inference
"""

import numpy as np
import cv2
from dataclasses import dataclass
from typing import List, Tuple, Optional
from collections import deque


@dataclass
class Detection:
    """A detected golf ball."""
    x: float  # center x (pixels)
    y: float  # center y (pixels)
    w: float  # width (pixels)
    h: float  # height (pixels)
    confidence: float
    frame_id: int = 0


class KalmanTracker:
    """
    Kalman filter for 2D ball tracking.
    State: [x, y, vx, vy, ax, ay]
    Observation: [x, y]
    """

    def __init__(self, dt: float = 1.0 / 30.0):
        self.dt = dt
        n = 6  # state dimension
        m = 2  # measurement dimension

        # State transition matrix (constant-acceleration model)
        self.F = np.array([
            [1, 0, dt, 0, 0.5 * dt**2, 0],
            [0, 1, 0, dt, 0, 0.5 * dt**2],
            [0, 0, 1, 0, dt, 0],
            [0, 0, 0, 1, 0, dt],
            [0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ])

        # Measurement matrix (observe x, y only)
        self.H = np.array([
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0],
        ])

        # Process noise
        q = 0.5  # process noise scaling
        self.Q = q * np.eye(n)

        # Measurement noise
        r = 2.0  # measurement noise (pixels)
        self.R = r * np.eye(m)

        # Initial state and covariance
        self.x = np.zeros((n, 1))
        self.P = np.eye(n) * 100

        self.initialized = False
        self.missed_frames = 0
        self.max_missed = 10  # max frames without detection before reset

    def predict(self) -> Tuple[float, float]:
        """Advance the state one time step and return the predicted (x, y)."""
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return float(self.x[0, 0]), float(self.x[1, 0])

    def update(self, z_x: float, z_y: float, confidence: float = 1.0):
        """Correct the state with a new (x, y) measurement.

        'confidence' is accepted for API symmetry but not used yet.
        """
        if not self.initialized:
            # First measurement: seed position, leave velocity/acceleration at 0
            self.x[0, 0] = z_x
            self.x[1, 0] = z_y
            self.initialized = True
            self.missed_frames = 0
            return

        z = np.array([[z_x], [z_y]])

        # Innovation
        y = z - self.H @ self.x

        # Innovation covariance
        S = self.H @ self.P @ self.H.T + self.R

        # Kalman gain
        K = self.P @ self.H.T @ np.linalg.inv(S)

        # Update
        self.x = self.x + K @ y
        self.P = (np.eye(6) - K @ self.H) @ self.P

        self.missed_frames = 0

    def predict_trajectory(self, n_steps: int = 30) -> List[Tuple[float, float]]:
        """Predict future trajectory points using a rough ballistic model."""
        if not self.initialized:
            return []

        trajectory = []
        x_pred = self.x.copy()
        F_local = self.F.copy()

        for _ in range(n_steps):
            x_pred = F_local @ x_pred
            # Nudge the y-acceleration component (index 5) to mimic gravity.
            # There is no real-world scale calibration, so this per-frame
            # pixel-space increment is purely heuristic.
            x_pred[5, 0] += 0.5
            trajectory.append((float(x_pred[0, 0]), float(x_pred[1, 0])))

        return trajectory

    def get_position(self) -> Tuple[float, float]:
        return float(self.x[0, 0]), float(self.x[1, 0])

    def get_velocity(self) -> Tuple[float, float]:
        return float(self.x[2, 0]), float(self.x[3, 0])
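

# --- Illustrative sanity check, not part of the tracking pipeline ---
# A minimal sketch of how KalmanTracker behaves on a synthetic parabolic
# arc with noisy observations. The arc parameters and noise level are
# assumptions chosen for this demo, not values used elsewhere in the file.
def _demo_kalman_on_synthetic_arc(n_frames: int = 30) -> None:
    rng = np.random.default_rng(0)
    kf = KalmanTracker(dt=1.0 / 30.0)
    for t in range(n_frames):
        # Synthetic ball path: constant vx, constant downward acceleration
        true_x = 5.0 * t
        true_y = 0.2 * t * t
        if kf.initialized:
            kf.predict()  # advance one frame before correcting
        kf.update(true_x + rng.normal(0, 2.0), true_y + rng.normal(0, 2.0))
    print("smoothed position:", kf.get_position())
    print("estimated velocity (px/s):", kf.get_velocity())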


class GolfBallTracker:
    """
    Golf ball detection + tracking pipeline.

    Supports multiple backends:
    - Ultralytics YOLO (Python)
    - ONNX Runtime (cross-platform)
    - TFLite (mobile)
    """

    def __init__(self, model_path: str, conf_threshold: float = 0.25,
                 iou_threshold: float = 0.45, use_kalman: bool = True,
                 fps: float = 30.0):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.use_kalman = use_kalman
        self.fps = fps
        self.dt = 1.0 / fps

        self.kalman = KalmanTracker(dt=self.dt) if use_kalman else None
        self.trajectory_history = deque(maxlen=100)  # last 100 positions
        self.predicted_trajectory = []
        self.frame_count = 0

        # Load model
        self._load_model(model_path)

    def _load_model(self, model_path: str):
        """Load detection model. Auto-detects format from the file extension."""
        ext = model_path.lower().split('.')[-1]

        if ext == 'pt':
            # PyTorch / Ultralytics
            try:
                from ultralytics import YOLO
                self.model = YOLO(model_path)
                self.backend = 'ultralytics'
                print(f"Loaded Ultralytics model: {model_path}")
            except ImportError:
                raise RuntimeError("ultralytics not installed. pip install ultralytics")

        elif ext == 'onnx':
            import onnxruntime as ort
            self.session = ort.InferenceSession(model_path)
            self.input_name = self.session.get_inputs()[0].name
            self.backend = 'onnx'
            print(f"Loaded ONNX model: {model_path}")

        elif ext in ('tflite', 'lite'):
            import tensorflow as tf
            self.interpreter = tf.lite.Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            self.backend = 'tflite'
            print(f"Loaded TFLite model: {model_path}")

        else:
            raise ValueError(f"Unsupported model format: {ext}")

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """Run detection on a single frame."""
        h, w = frame.shape[:2]
        detections = []

        if self.backend == 'ultralytics':
            results = self.model(frame, conf=self.conf_threshold,
                                 iou=self.iou_threshold, verbose=False)
            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0])
                    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                    bw, bh = x2 - x1, y2 - y1
                    detections.append(Detection(cx, cy, bw, bh, conf, self.frame_count))

        elif self.backend == 'onnx':
            # Preprocess: BGR -> RGB, resize, normalize, NCHW
            # (plain resize, no letterboxing, so the aspect ratio is distorted)
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (640, 640))
            img = img.astype(np.float32) / 255.0
            img = np.transpose(img, (2, 0, 1))
            img = np.expand_dims(img, axis=0)

            # Run inference
            outputs = self.session.run(None, {self.input_name: img})

            # Parse outputs. YOLOv8 ONNX output is (4 + n_classes, n_anchors);
            # for a single-class ball model that is (5, 8400), and row 4 is
            # the class score.
            predictions = outputs[0][0]

            for pred in predictions.T:
                conf = pred[4]
                if conf < self.conf_threshold:
                    continue
                # First 4 values are the box center and size
                cx, cy, bw, bh = pred[:4]
                # Scale back to the original image
                cx = cx * w / 640
                cy = cy * h / 640
                bw = bw * w / 640
                bh = bh * h / 640
                detections.append(Detection(cx, cy, bw, bh, conf, self.frame_count))

        elif self.backend == 'tflite':
            # Preprocess
            input_shape = self.input_details[0]['shape']
            _, inp_h, inp_w, _ = input_shape
            img = cv2.resize(frame, (inp_w, inp_h))
            img = img.astype(np.float32) / 255.0
            img = np.expand_dims(img, axis=0)

            self.interpreter.set_tensor(self.input_details[0]['index'], img)
            self.interpreter.invoke()
            outputs = self.interpreter.get_tensor(self.output_details[0]['index'])

            # Parse (format varies by model)
            for det in outputs[0]:
                # Assuming [x, y, w, h, conf, class] format
                if det[4] < self.conf_threshold:
                    continue
                cx = det[0] * w / inp_w
                cy = det[1] * h / inp_h
                bw = det[2] * w / inp_w
                bh = det[3] * h / inp_h
                detections.append(Detection(cx, cy, bw, bh, det[4], self.frame_count))

        # Non-maximum suppression (simple greedy version)
        detections = self._nms(detections)
        return detections

    def _nms(self, detections: List[Detection]) -> List[Detection]:
        """Greedy NMS: keep the highest-confidence box, drop overlapping ones."""
        if not detections:
            return []

        detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
        keep = []

        while detections:
            best = detections.pop(0)
            keep.append(best)
            detections = [d for d in detections
                          if self._iou(best, d) < self.iou_threshold]

        return keep

    def _iou(self, a: Detection, b: Detection) -> float:
        """Compute IoU between two detections."""
        ax1, ay1 = a.x - a.w / 2, a.y - a.h / 2
        ax2, ay2 = a.x + a.w / 2, a.y + a.h / 2
        bx1, by1 = b.x - b.w / 2, b.y - b.h / 2
        bx2, by2 = b.x + b.w / 2, b.y + b.h / 2

        inter_x1 = max(ax1, bx1)
        inter_y1 = max(ay1, by1)
        inter_x2 = min(ax2, bx2)
        inter_y2 = min(ay2, by2)

        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        a_area = a.w * a.h
        b_area = b.w * b.h
        union = a_area + b_area - inter_area

        return inter_area / union if union > 0 else 0

    def update(self, frame: np.ndarray) -> Tuple[Optional[Detection], np.ndarray]:
        """
        Process one frame: detect ball, update tracker, predict trajectory.
        Returns: (best_detection_or_none, annotated_frame)
        """
        self.frame_count += 1
        h, w = frame.shape[:2]

        # Detection
        detections = self.detect(frame)

        # Select best detection (highest confidence)
        best = max(detections, key=lambda d: d.confidence) if detections else None

        # Kalman predict/update
        if self.kalman:
            if self.kalman.initialized:
                self.kalman.predict()  # always advance one frame first
            if best:
                self.kalman.update(best.x, best.y, best.confidence)
            else:
                self.kalman.missed_frames += 1
                if self.kalman.missed_frames > self.kalman.max_missed:
                    # Ball lost for too long: reset so stale state isn't reused
                    self.kalman.initialized = False
                    self.predicted_trajectory = []
                elif self.kalman.initialized:
                    # Coast on the prediction while the ball is briefly lost
                    px, py = self.kalman.get_position()
                    best = Detection(px, py, 20, 20, 0.3, self.frame_count)

            if self.kalman.initialized:
                # Record the smoothed position and a short predicted trajectory
                kx, ky = self.kalman.get_position()
                self.trajectory_history.append((kx, ky))
                self.predicted_trajectory = self.kalman.predict_trajectory(n_steps=30)
        else:
            if best:
                self.trajectory_history.append((best.x, best.y))

        # Annotate frame
        annotated = frame.copy()

        # Draw trajectory history, fading older segments
        if len(self.trajectory_history) > 1:
            points = list(self.trajectory_history)
            for i in range(1, len(points)):
                p1 = (int(points[i - 1][0]), int(points[i - 1][1]))
                p2 = (int(points[i][0]), int(points[i][1]))
                alpha = int(255 * i / len(points))
                cv2.line(annotated, p1, p2, (0, alpha, 0), 2)

        # Draw predicted trajectory
        if self.predicted_trajectory:
            for i, (px, py) in enumerate(self.predicted_trajectory):
                if 0 <= px < w and 0 <= py < h:
                    alpha = int(255 * (1 - i / len(self.predicted_trajectory)))
                    color = (0, alpha, 255)
                    cv2.circle(annotated, (int(px), int(py)), 2, color, -1)

        # Draw current detection
        if best and best.confidence > 0.3:
            x1 = int(best.x - best.w / 2)
            y1 = int(best.y - best.h / 2)
            x2 = int(best.x + best.w / 2)
            y2 = int(best.y + best.h / 2)
            color = (0, 255, 0) if best.confidence > 0.5 else (0, 165, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            cv2.putText(annotated, f"ball {best.confidence:.2f}",
                        (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        # Frame counter overlay
        cv2.putText(annotated, f"Frame: {self.frame_count}", (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return best, annotated

    def track_video(self, input_path: str, output_path: str):
        """Process a video file."""
        cap = cv2.VideoCapture(input_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            det, annotated = self.update(frame)
            out.write(annotated)

        cap.release()
        out.release()
        print(f"Saved output to {output_path}")

    def track_camera(self, camera_id: int = 0):
        """Track from a live camera (preview window; for desktop testing)."""
        cap = cv2.VideoCapture(camera_id)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            det, annotated = self.update(frame)
            cv2.imshow("Golf Ball Tracker", annotated)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        cap.release()
        cv2.destroyAllWindows()

    def get_trajectory(self) -> List[Tuple[float, float]]:
        """Return tracked trajectory points."""
        return list(self.trajectory_history)

    def get_predicted_trajectory(self) -> List[Tuple[float, float]]:
        """Return predicted future trajectory."""
        return self.predicted_trajectory
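

# --- Illustrative helper, an assumption for demo purposes rather than ---
# --- part of the pipeline above ---
# The Kalman state stores velocity in pixels/second (dt is in seconds in F),
# so a rough speed readout in pixel units is just the vector norm. Converting
# to real-world units would need a pixels-per-meter calibration, which this
# file does not provide.
def estimate_speed_px_per_s(tracker: "GolfBallTracker") -> float:
    """Current smoothed ball speed in pixels/second (0.0 if not tracking)."""
    if tracker.kalman is None or not tracker.kalman.initialized:
        return 0.0
    vx, vy = tracker.kalman.get_velocity()
    return float(np.hypot(vx, vy))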


def export_model_for_mobile():
    """
    Example script to export a trained YOLO model for mobile deployment.
    """
    from ultralytics import YOLO

    model = YOLO("/app/golf_ball_runs/golf_ball_detector/weights/best.pt")

    # ONNX - works on both Android and iOS
    print("Exporting to ONNX...")
    model.export(format="onnx", imgsz=640, simplify=True)

    # TFLite - best for Android
    print("Exporting to TFLite (INT8 for mobile)...")
    model.export(format="tflite", imgsz=640, int8=True)

    # CoreML - best for iOS
    print("Exporting to CoreML...")
    model.export(format="coreml", imgsz=640)

    print("Export complete!")
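

# --- Illustrative sanity check; the default path and expected shape are ---
# --- assumptions based on Ultralytics export conventions ---
# A minimal sketch for verifying an exported ONNX model before shipping it:
# run a dummy tensor through ONNX Runtime and inspect the output shape,
# which for a YOLOv8-style export should be (1, 4 + n_classes, n_anchors).
def check_onnx_export(onnx_path: str = "best.onnx") -> None:
    import onnxruntime as ort
    session = ort.InferenceSession(onnx_path)
    dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
    outputs = session.run(None, {session.get_inputs()[0].name: dummy})
    print(f"{onnx_path}: output shape {outputs[0].shape}")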


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage:")
        print("  python golf_ball_tracker.py detect <model.pt> <video.mp4>")
        print("  python golf_ball_tracker.py export")
        sys.exit(1)

    cmd = sys.argv[1]

    if cmd == "detect":
        if len(sys.argv) < 4:
            print("Usage: python golf_ball_tracker.py detect <model> <video>")
            sys.exit(1)
        tracker = GolfBallTracker(sys.argv[2])
        tracker.track_video(sys.argv[3], "output_tracked.mp4")

    elif cmd == "export":
        export_model_for_mobile()

    else:
        print(f"Unknown command: {cmd}")