| import numpy as np
|
| import pandas as pd
|
| import json
|
| import tensorflow as tf
|
| import mediapipe as mp
|
| from skimage.transform import resize
|
| import matplotlib.pyplot as plt
|
| from mediapipe.framework.formats import landmark_pb2
|
| from PIL import Image
|
|
|
|
|
# Names of the parquet columns to load, read from the "selected_columns"
# entry of inference_args.json. The encoding is pinned to UTF-8 so the file
# is decoded the same way regardless of the platform's locale default.
with open("inference_args.json", "r", encoding="utf-8") as f:
    SEL_COLS = json.load(f)["selected_columns"]
|
|
|
|
|
# Module-level TFLite interpreter shared by every prediction call.
# Tensors are allocated once, right after the model file is loaded.
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()

# Descriptions of the model's input and output tensors; the 'index' entry of
# the first element of each is what predict_from_parquet reads.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
|
|
|
|
|
# Short aliases for the MediaPipe solution modules used for rendering:
# the hands solution (connection topology), the drawing helpers, and the
# default drawing styles.
mp_hands = mp.solutions.hands
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
|
|
|
def load_relevant_data_subset(pq_path):
    """Read a parquet file, keeping only the columns listed in ``SEL_COLS``.

    Args:
        pq_path: Location of the parquet file, passed straight to
            ``pd.read_parquet``.

    Returns:
        The DataFrame produced by ``pd.read_parquet`` restricted to
        ``SEL_COLS``.
    """
    selected_frame = pd.read_parquet(pq_path, columns=SEL_COLS)
    return selected_frame
|
|
|
def draw_hand_landmarks(seq_df):
    """Render the right-hand landmarks of every row as an image.

    For each row of ``seq_df`` the values of the columns whose names match
    ``x_right_hand.*``, ``y_right_hand.*`` and ``z_right_hand.*`` are packed
    into a ``NormalizedLandmarkList`` and drawn, together with the MediaPipe
    hand connections, onto a fresh 600x600x3 zero-filled canvas.

    Args:
        seq_df: DataFrame with one row per frame, containing the right-hand
            coordinate columns described above.

    Returns:
        A list with one 600x600x3 NumPy array per row of ``seq_df`` (an
        empty list when ``seq_df`` has no rows).
    """
    # Select the coordinate columns once for the whole sequence instead of
    # re-running the regex filters on a freshly built row Series per frame.
    x_frames = seq_df.filter(regex="x_right_hand.*").to_numpy()
    y_frames = seq_df.filter(regex="y_right_hand.*").to_numpy()
    z_frames = seq_df.filter(regex="z_right_hand.*").to_numpy()

    images = []
    for x_hand, y_hand, z_hand in zip(x_frames, y_frames, z_frames):
        # np.zeros defaults to float64; the canvas starts out all black.
        right_hand_image = np.zeros((600, 600, 3))
        right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()

        for x, y, z in zip(x_hand, y_hand, z_hand):
            right_hand_landmarks.landmark.add(x=x, y=y, z=z)

        # draw_landmarks paints onto right_hand_image in place.
        # NOTE(review): frames whose coordinates are missing/NaN presumably
        # end up as blank canvases rather than being skipped - confirm.
        mp_drawing.draw_landmarks(
            right_hand_image,
            right_hand_landmarks,
            mp_hands.HAND_CONNECTIONS,
            landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
        )
        images.append(right_hand_image)
    return images
|
|
|
def preprocess_image(image):
    """Turn a rendered image into a single-item float32 batch.

    The image is resized to 64x64 without rescaling its value range, cast
    to ``float32``, divided by 255, and given a leading axis of length 1.

    Args:
        image: Array-like image accepted by ``skimage.transform.resize``.

    Returns:
        A ``float32`` NumPy array holding the resized, scaled image with an
        extra leading axis.
    """
    resized = resize(image, (64, 64), preserve_range=True)
    scaled = resized.astype(np.float32) / 255.0
    batch = np.expand_dims(scaled, axis=0)
    return batch
|
|
|
def predict_from_parquet(parquet_path):
    """Predict a class index from the landmark sequence in a parquet file.

    The selected columns are loaded, every row is rendered to a right-hand
    landmark image, and the image in the middle of the sequence is
    preprocessed and run through the module-level TFLite interpreter.

    Args:
        parquet_path: Path of the parquet file to read.

    Returns:
        The ``np.argmax`` of the model's first output tensor.

    Raises:
        ValueError: If no image was rendered (the sequence has no rows).
    """
    frames = draw_hand_landmarks(load_relevant_data_subset(parquet_path))
    if not frames:
        raise ValueError("No hand image generated.")

    # Only the middle frame is classified; frames are not checked for
    # whether any landmarks were actually drawn on them.
    middle_frame = frames[len(frames) // 2]
    model_input = preprocess_image(middle_frame)

    input_index = input_details[0]['index']
    output_index = output_details[0]['index']

    interpreter.set_tensor(input_index, model_input)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_index)
    return np.argmax(scores)
|
|
|
if __name__ == "__main__":
    import sys

    # Command-line entry point: the first argument is the parquet file to
    # classify; without it, only the usage line is printed.
    if len(sys.argv) >= 2:
        print("Predicted class index:", predict_from_parquet(sys.argv[1]))
    else:
        print("Usage: python tflite_inference.py <parquet_file_path>")
|
|
|