| import numpy as np |
| import tensorflow as tf |
| import matplotlib.pyplot as plt |
| import xml.etree.ElementTree as ET |
| import cv2 |
| import glob |
| import os |
| from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score |
|
|
| |
# Paths to the evaluation dataset (Pascal-VOC layout: images and XML
# annotation files sharing basenames) and the TFLite detection model.
test_images_dir = "test_images/"
test_annotations_dir = "test_annotations/"
tflite_model_path = "efficientdet_lite0.tflite"

# Load the TFLite model once and allocate its input/output tensors;
# the interpreter is reused for every image below.
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
|
|
| |
def run_inference(interpreter, image):
    """Run one forward pass of the TFLite detector on a single image.

    Args:
        interpreter: An allocated ``tf.lite.Interpreter``.
        image: BGR image as returned by ``cv2.imread`` (H, W, 3).

    Returns:
        The raw contents of the model's first output tensor.
        NOTE(review): EfficientDet-Lite exports usually emit several
        output tensors (boxes/classes/scores/count); confirm index 0
        really holds the detection scores this script treats it as.
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Model input layout is NHWC: [1, height, width, channels].
    # cv2.resize takes dsize as (width, height), so the order must be
    # swapped relative to the shape tuple — the original code passed
    # (height, width), which silently distorts non-square inputs.
    _, height, width, _ = input_details[0]['shape']
    # cv2.imread returns BGR; the model expects RGB channel order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (width, height))

    # Match the dtype the model declares: quantized EfficientDet-Lite
    # exports take uint8 in [0, 255]; float models take float32 in [0, 1].
    if input_details[0]['dtype'] == np.uint8:
        image = image.astype(np.uint8)
    else:
        image = image.astype(np.float32) / 255.0
    image = np.expand_dims(image, axis=0)  # add the batch dimension

    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()

    return interpreter.get_tensor(output_details[0]['index'])
|
|
| |
def parse_voc_annotation(xml_file):
    """Return 1 if the VOC XML file contains at least one <object>
    element (image-level positive), otherwise 0."""
    root = ET.parse(xml_file).getroot()
    return 1 if root.find("object") is not None else 0
|
|
| |
# Score every test image that has a matching VOC annotation, collecting
# an image-level confidence and a binary ground-truth label per image.
image_files = glob.glob(os.path.join(test_images_dir, "*.jpg"))
y_scores = []  # per-image max detection confidence
y_true = []    # per-image ground truth: 1 if any annotated object

for image_file in image_files:
    # Pair the image with its annotation by shared basename; skip
    # images with no ground truth before paying the decode cost.
    xml_file = os.path.join(
        test_annotations_dir,
        os.path.splitext(os.path.basename(image_file))[0] + ".xml",
    )
    if not os.path.exists(xml_file):
        continue

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread signals failure by returning None (no exception);
        # the original code would crash inside run_inference here.
        continue

    true_label = parse_voc_annotation(xml_file)

    # Image-level score: highest confidence across all raw outputs.
    scores = run_inference(interpreter, image)
    max_score = float(np.max(scores))

    y_scores.append(max_score)
    y_true.append(true_label)

y_scores = np.array(y_scores)
y_true = np.array(y_true)
|
|
| |
# ROC curve over image-level scores. NOTE(review): roc_curve raises if
# y_true contains only one class — the test set must have both
# positive and negative images.
fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

# Precision-recall curve and its summary (average precision).
precision, recall, _ = precision_recall_curve(y_true, y_scores)
average_precision = average_precision_score(y_true, y_scores)
|
|
| |
# --- ROC curve -------------------------------------------------------
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')  # chance diagonal
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.legend()
plt.show()

# --- Precision-recall curve ------------------------------------------
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall, precision, color='green', lw=2, label=f'PR Curve (AP = {average_precision:.2f})')
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve")
ax.legend()
plt.show()
|
|