Spaces:
Sleeping
Sleeping
| import tqdm | |
| import cv2 | |
| import numpy as np | |
| import re | |
| import os | |
| from mediapipe.python.solutions import drawing_utils as mp_drawing | |
| import mediapipe as mp | |
| from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
| from PoseClassification.pose_classifier import PoseClassifier | |
| from PoseClassification.utils import EMADictSmoothing | |
| # from PoseClassification.utils import RepetitionCounter | |
| from PoseClassification.visualize import PoseClassificationVisualizer | |
| import argparse | |
| from PoseClassification.utils import show_image | |
def check_major_current_position(positions_detected: dict, threshold_position) -> str:
    '''
    Return the dominant position among those detected in a frame, or 'none'.

    INPUTS
        positions_detected :
            dict mapping position (class) name to its score, as given by the
            position classifier / pose_classification_filtered,
            e.g. {'pose1': 8.0, 'pose2': 2.0}.
            An empty dict is treated as "no position detected".
        threshold_position :
            scores strictly below this value are considered the "none"
            position. Any value accepted by float() works (e.g. 8.0 or '8').
    OUTPUT
        major_position :
            the name with the highest score, or 'none' when the dict is empty
            or its highest score is strictly below threshold_position.
            On ties, the first such key in dict iteration order wins.
    '''
    if not positions_detected:
        # max() on an empty sequence raises ValueError; if an upstream
        # classifier/filter ever yields no scores, report "none" instead of
        # crashing the processing loop.
        return 'none'
    major_position = max(positions_detected, key=positions_detected.get)
    if positions_detected[major_position] < float(threshold_position):
        return 'none'
    return major_position
def _classified_output_paths(video_path_in: str) -> tuple:
    '''
    Return (output video path, results CSV path) derived from video_path_in.

    The extension is stripped with os.path.splitext. Unlike a bare
    re.sub(r'.mp4', ...), this leaves directory names containing "mp4" alone.
    It also ensures that a non-.mp4 input never maps onto itself, which would
    otherwise overwrite the input file.
    '''
    base, _extension = os.path.splitext(video_path_in)
    return base + '_classified_video.mp4', base + '_classified_results.csv'


def _load_banner_logo(logo_path: str, banner_height: int, frame_width: int):
    '''
    Load the banner logo and scale it to banner_height pixels tall.

    Return the resized BGR image, or None when the logo should be skipped:
    the file cannot be read, the banner has no height, or the scaled logo
    would be wider than the frame. A missing asset therefore never aborts
    the run.
    '''
    if banner_height <= 0:
        return None
    logo = cv2.imread(logo_path)
    if logo is None:  # cv2.imread signals an unreadable file with None
        print(f"Warning: unable to read logo {logo_path}, continuing without it.")
        return None
    logo_height, logo_width = logo.shape[:2]
    logo_width_rescaled = (logo_width * banner_height) // logo_height
    if logo_width_rescaled <= 0 or logo_width_rescaled > frame_width:
        return None
    return cv2.resize(logo, (logo_width_rescaled, banner_height))


def _draw_banner_background(frame, banner_height: int, logo) -> None:
    '''Paint the white top banner on frame in place, with logo in the upper right corner.'''
    frame[0:banner_height, :] = (255, 255, 255)  # White color
    if logo is not None:
        frame[0:logo.shape[0], frame.shape[1] - logo.shape[1]:] = logo


def _draw_banner_text(frame, banner_height: int, pose_name: str, duration_sec: float) -> None:
    '''Write the current position and how long it has been held into the banner (in place).'''
    frame_height, frame_width = frame.shape[:2]
    # Small left margin. The previous `1//50*width` evaluated as
    # `(1//50)*width == 0`, so the text sat flush against the edge.
    text_x = frame_width // 50
    font_scale = 0.9 * (frame_height / frame_width)
    banner_lines = (
        (f"Pose: {pose_name}", banner_height // 3),
        (f"Duration: {int(duration_sec)} seconds", (2 * banner_height) // 3),
    )
    for text, text_y in banner_lines:
        cv2.putText(
            frame,
            text,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (0, 0, 0),  # color
            1,  # Thinner line
            cv2.LINE_AA,
        )


def yoga_position_classifier():
    '''
    Classify yoga positions frame by frame in a video file or the live webcam.

    The video path is read from the command line (positional "video_path").
    The special value "live" reads camera 0 instead, and the outputs are then
    named after 'data/live.mp4'.

    For every frame:
      - pose landmarks are tracked with MediaPipe;
      - they are classified with PoseClassifier against the CSVs in
        'data/yoga_poses_csvs_out' and smoothed with EMADictSmoothing;
      - the major position and how long it has been held are drawn in a
        white banner.

    Annotated frames are shown in a window and written to
    '<input>_classified_video.mp4'. A per-frame debug log
    ("frame_idx,position_dict") is written to '<input>_classified_results.csv'.

    Keys while the window has focus: 'q' quits, 'r' resets the smoothing
    history and the position timer.

    OUTPUT
        frame_list : list of processed frame indices.
        position_list : list of per-frame position dicts (smoothed scores, or
            {'none': 10.0} when no landmarks were detected).
    RAISES
        IOError : the video source cannot be opened.
        ValueError : the input video file reports no frames.
    '''
    import time  # wall-clock timing for the live (webcam) mode only

    # Load arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="string video path in")
    args = parser.parse_args()
    video_path_in = args.video_path
    direct_video = False
    if video_path_in == "live":
        video_path_in = 'data/live.mp4'
        direct_video = True
    video_path_out, results_classification_path_out = _classified_output_paths(video_path_in)

    # Folder with pose class CSVs. That should be the same folder you used while
    # building classifier to output CSVs.
    pose_samples_folder = 'data/yoga_poses_csvs_out'
    pose_embedder = FullBodyPoseEmbedding()
    # Check that you are using the same parameters as during bootstrapping.
    pose_classifier = PoseClassifier(
        pose_samples_folder=pose_samples_folder,
        pose_embedder=pose_embedder,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10)

    def new_classification_filter():
        # EMA smoothing of per-class scores. It is re-created on reset ('r')
        # so the old history no longer drives the result.
        return EMADictSmoothing(window_size=10, alpha=0.2)

    pose_classification_filter = new_classification_filter()
    position_threshold = 8.0

    # Results returned to the caller
    position_list = []
    frame_list = []

    mp_pose = mp.solutions.pose
    pose_tracker = mp_pose.Pose()
    video_cap = None
    out_video = None
    try:
        video_source = 0 if direct_video else video_path_in
        video_cap = cv2.VideoCapture(video_source)
        if not video_cap.isOpened():
            raise IOError(f"Unable to open video source {video_source!r}. Abort.")
        total_frames = None
        if not direct_video:
            video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
            if not video_n_frames > 0:
                raise ValueError('Error in input video frames number : no frame. Abort.')
            total_frames = int(video_n_frames)
        video_fps = video_cap.get(cv2.CAP_PROP_FPS)
        if not video_fps > 0:
            # Some sources (notably webcams) may report 0 FPS. Fall back to
            # a nominal rate instead of dividing by zero.
            video_fps = 30.0
        frame_period = 1.0 / video_fps

        frame_idx = 0
        position_timer = 0.0
        previous_position_major = 'none'
        previous_frame_time = None
        banner_height = 0
        logo = None

        # Open the debug log once per run. It was previously reopened in
        # append mode for every frame, which was slow and mixed in the rows
        # from earlier runs.
        with open(results_classification_path_out, 'w') as results_file, \
                tqdm.tqdm(total=total_frames, position=0, leave=True) as pbar:
            while True:
                success, input_frame = video_cap.read()
                if not success:
                    print("Unable to read input video frame, breaking!")
                    break

                if out_video is None:
                    # Size the writer, banner and logo from the first real
                    # frame. CAP_PROP_* dimensions can be 0 or differ from the
                    # delivered frames, and a size mismatch can make
                    # VideoWriter fail to write frames.
                    first_height, first_width = input_frame.shape[:2]
                    banner_height = first_height // 10
                    logo = _load_banner_logo("src/logo_impredalam.jpg", banner_height, first_width)
                    out_video = cv2.VideoWriter(
                        video_path_out,
                        cv2.VideoWriter_fourcc(*'mp4v'),
                        video_fps,
                        (first_width, first_height))

                # Run pose tracker (MediaPipe expects RGB, OpenCV delivers BGR)
                input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                result = pose_tracker.process(image=input_frame_rgb)
                pose_landmarks = result.pose_landmarks

                # Prepare the output frame: banner first, so landmarks are drawn over it
                output_frame = input_frame.copy()
                _draw_banner_background(output_frame, banner_height, logo)

                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS)
                    # Convert normalized landmarks to pixel coordinates
                    frame_height, frame_width = output_frame.shape[:2]
                    landmarks = np.array(
                        [
                            [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                            for lmk in pose_landmarks.landmark
                        ],
                        dtype=np.float32)
                    assert landmarks.shape == (33, 3,), "Unexpected landmarks shape: {}".format(landmarks.shape)
                    # Classify the pose on the current frame, then smooth with EMA
                    pose_classification = pose_classifier(landmarks)
                    current_position = pose_classification_filter(pose_classification)
                else:
                    current_position = {'none': 10.0}
                current_position_major = check_major_current_position(current_position, position_threshold)

                # Time covered by this frame. In live mode, frames arrive at
                # whatever rate processing allows, so use the wall clock.
                if direct_video:
                    now = time.monotonic()
                    elapsed = 0.0 if previous_frame_time is None else now - previous_frame_time
                    previous_frame_time = now
                else:
                    elapsed = frame_period

                # Position timer: accumulate while the major position is unchanged
                if current_position_major == previous_position_major:
                    position_timer += elapsed
                else:
                    previous_position_major = current_position_major
                    position_timer = 0.0

                _draw_banner_text(output_frame, banner_height, current_position_major, position_timer)

                out_video.write(output_frame)
                cv2.imshow("Yoga position", output_frame)

                # Record results (returned lists + debug log)
                position_list.append(current_position)
                frame_list.append(frame_idx)
                results_file.write(f'{frame_idx},{current_position}\n')

                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                elif key == ord("r"):
                    # Overwriting current_position alone had no effect: it is
                    # recomputed from the filter on the next frame. So drop
                    # the smoothing history and restart the timer instead.
                    pose_classification_filter = new_classification_filter()
                    position_timer = 0.0
                    print("Position reset !")
                frame_idx += 1
                pbar.update()
    finally:
        pose_tracker.close()
        if video_cap is not None:
            video_cap.release()
        # Close output video (also on error, so the file is finalized)
        if out_video is not None:
            out_video.release()
        cv2.destroyAllWindows()
    return frame_list, position_list
if __name__ == "__main__":
    # Script entry point: the classifier reads the video path (or "live")
    # from argv itself. Its returned (frame_list, position_list) is only
    # useful to importers, so it is discarded here.
    yoga_position_classifier()