Jofax commited on
Commit
dfccaa2
·
verified ·
1 Parent(s): 003313d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from ultralytics import YOLO
5
+ from huggingface_hub import hf_hub_download
6
+ import tempfile
7
+ import os
8
+
9
# --- 1. SET UP MODELS ---
# Downloading specialized volleyball models from Davidsv/volley-ref-ai.
# BUGFIX: pre-bind the model names so a download/load failure leaves them as
# None (an explicit, checkable state) instead of undefined names that later
# explode as a confusing NameError inside process_volleyball_video.
court_model = ball_model = pose_model = None
try:
    court_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_court_keypoints.pt")
    ball_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_volleyball_ball.pt")

    court_model = YOLO(court_model_path)
    ball_model = YOLO(ball_model_path)
    pose_model = YOLO("yolo11n-pose.pt")  # General human pose model
except Exception as e:
    # Best-effort startup: log and continue so the Space UI still comes up.
    print(f"Error loading models: {e}")
21
def process_volleyball_video(video_path):
    """Annotate a volleyball video with court, pose and ball detections.

    Runs three YOLO models per frame (court keypoints, human pose, ball),
    overlays simple rule heuristics ("spike" when a wrist rises above its
    shoulder, "net touch" when a wrist is within 10 px of the net line),
    and writes the annotated frames to a temporary .mp4.

    Args:
        video_path: Path to the uploaded video, or None/empty.

    Returns:
        Path to the annotated .mp4, or None when no input was given or the
        video could not be opened.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    # BUGFIX: bail out on unreadable input instead of writing an empty file.
    if not cap.isOpened():
        cap.release()
        return None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # BUGFIX: some containers report 0 fps, which breaks VideoWriter.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

    # Create a temporary file to save the processed video.
    # BUGFIX: close the handle immediately — VideoWriter only needs the path,
    # and the open descriptor was previously leaked.
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    temp_output.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Run Detections
            court_res = court_model(frame, verbose=False)[0]
            pose_res = pose_model(frame, verbose=False)[0]
            ball_res = ball_model(frame, verbose=False)[0]

            annotated_frame = frame.copy()

            # Logic: find the net height using court keypoints.
            # Keypoint index 6 is assumed to be the net top in this court
            # model — TODO confirm against the Davidsv/volley-ref-ai docs.
            net_y = height // 2  # Default fallback: mid-frame
            if court_res.keypoints is not None and len(court_res.keypoints.xy[0]) > 7:
                net_y = int(court_res.keypoints.xy[0][6][1])  # Y-coord of net top

            # Process Players
            if pose_res.keypoints is not None:
                for i, person in enumerate(pose_res.keypoints.xy):
                    # Need wrists (indices 9/10) present to analyse this person.
                    if len(person) < 11:
                        continue

                    # COCO joints: 5=L_Shoulder, 6=R_Shoulder, 9=L_Wrist, 10=R_Wrist
                    l_shoulder, r_shoulder = person[5], person[6]
                    l_wrist, r_wrist = person[9], person[10]

                    # ANALYSIS 1: "Spike" — a wrist above its shoulder
                    # (image y grows downward, hence the '<' comparison).
                    if (l_wrist[1] < l_shoulder[1] or r_wrist[1] < r_shoulder[1]) and l_wrist[1] > 0:
                        cv2.putText(annotated_frame, "SPIKE ATTACK",
                                    (int(l_shoulder[0]), int(l_shoulder[1] - 20)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

                    # ANALYSIS 2: Net touch — wrist within 10 px of the net line.
                    if abs(l_wrist[1] - net_y) < 10 or abs(r_wrist[1] - net_y) < 10:
                        cv2.putText(annotated_frame, "WARNING: NET TOUCH", (50, 50 + (i * 30)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            # Draw detections.
            # BUGFIX: ball detections were computed every frame but never
            # drawn in the original — overlay them too.
            annotated_frame = pose_res.plot(img=annotated_frame)
            annotated_frame = court_res.plot(img=annotated_frame)
            annotated_frame = ball_res.plot(img=annotated_frame)

            out.write(annotated_frame)
    finally:
        # BUGFIX: release capture and writer even if a model call raises,
        # so the video handles are never leaked.
        cap.release()
        out.release()

    return temp_output.name
82
+
83
# --- 3. GRADIO INTERFACE ---
# Single video-in / video-out Gradio app wrapping process_volleyball_video.
_video_input = gr.Video(label="Upload Volleyball Match")
_video_output = gr.Video(label="AI Analysis (Detections & Mistakes)")

interface = gr.Interface(
    fn=process_volleyball_video,
    inputs=_video_input,
    outputs=_video_output,
    title="🏐 AI Volleyball Performance Lab",
    description="This app uses YOLOv11 and specialized Volleyball-Ref-AI models to detect court lines, ball movement, and player form to identify mistakes.",
    theme="soft",
)

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    interface.launch()